feat(websocket): add message types, connection hub, and connection handler
Add the core WebSocket infrastructure: - Message type definitions for the wire protocol (subscribe, unsubscribe, auth, error, push events) - In-memory connection hub that tracks per-user connections and routes messages to subscribed clients - Connection wrapper with auth-after-connect flow: connections start unauthenticated, client sends JWT as first message, only then can subscribe to event topics Includes auth timeout (30s), shared cancellation context for read/write loops, hub map cleanup on last connection removal, and proper error delivery before closing on auth failure.
This commit is contained in:
parent
4f9355c915
commit
9255fe07a9
|
|
@ -0,0 +1,271 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.vikunja.io/api/pkg/log"
|
||||||
|
"code.vikunja.io/api/pkg/modules/auth"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
writeTimeout = 10 * time.Second
|
||||||
|
pingInterval = 30 * time.Second
|
||||||
|
authTimeout = 30 * time.Second
|
||||||
|
sendBufSize = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connection wraps a single WebSocket connection.
|
||||||
|
type Connection struct {
|
||||||
|
ws *websocket.Conn
|
||||||
|
hub *Hub
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
userID int64
|
||||||
|
authenticated bool
|
||||||
|
subscriptions map[string]bool
|
||||||
|
|
||||||
|
send chan OutgoingMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnection creates a new unauthenticated Connection.
|
||||||
|
func NewConnection(ws *websocket.Conn, hub *Hub) *Connection {
|
||||||
|
return &Connection{
|
||||||
|
ws: ws,
|
||||||
|
hub: hub,
|
||||||
|
authenticated: false,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, sendBufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds an event subscription.
|
||||||
|
func (c *Connection) Subscribe(event string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.subscriptions[event] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes an event subscription.
|
||||||
|
func (c *Connection) Unsubscribe(event string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
delete(c.subscriptions, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSubscribed checks if the connection is subscribed to an event.
|
||||||
|
func (c *Connection) IsSubscribed(event string) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.subscriptions[event]
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticated returns whether the connection is authenticated.
|
||||||
|
func (c *Connection) IsAuthenticated() bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.authenticated
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserID returns the authenticated user's ID.
|
||||||
|
func (c *Connection) UserID() int64 {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.userID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLoop reads messages from the WebSocket and handles auth/subscribe/unsubscribe.
|
||||||
|
func (c *Connection) ReadLoop(ctx context.Context, cancel context.CancelFunc) {
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
if c.IsAuthenticated() {
|
||||||
|
c.hub.Unregister(c)
|
||||||
|
}
|
||||||
|
c.ws.Close(websocket.StatusNormalClosure, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Close the connection if auth doesn't happen within the timeout
|
||||||
|
authTimer := time.AfterFunc(authTimeout, func() {
|
||||||
|
if !c.IsAuthenticated() {
|
||||||
|
log.Debugf("WebSocket: closing unauthenticated connection after timeout")
|
||||||
|
c.ws.Close(websocket.StatusPolicyViolation, "auth timeout")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer authTimer.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, data, err := c.ws.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
|
||||||
|
log.Debugf("WebSocket: connection closed normally for user %d", c.UserID())
|
||||||
|
} else {
|
||||||
|
log.Debugf("WebSocket: read error for user %d: %v", c.UserID(), err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg IncomingMessage
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
log.Warningf("WebSocket: invalid message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.handleMessage(ctx, msg) {
|
||||||
|
return // close connection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMessage processes an incoming message. Returns false if connection should be closed.
|
||||||
|
func (c *Connection) handleMessage(ctx context.Context, msg IncomingMessage) bool {
|
||||||
|
switch msg.Action {
|
||||||
|
case ActionAuth:
|
||||||
|
return c.handleAuth(ctx, msg.Token)
|
||||||
|
case ActionSubscribe:
|
||||||
|
if !c.IsAuthenticated() {
|
||||||
|
c.sendError("auth_required", "")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !isValidEvent(msg.Event) {
|
||||||
|
c.sendError("invalid_event", msg.Event)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Subscribe(msg.Event)
|
||||||
|
log.Debugf("WebSocket: user %d subscribed to %s", c.UserID(), msg.Event)
|
||||||
|
case ActionUnsubscribe:
|
||||||
|
if !c.IsAuthenticated() {
|
||||||
|
c.sendError("auth_required", "")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Unsubscribe(msg.Event)
|
||||||
|
log.Debugf("WebSocket: user %d unsubscribed from %s", c.UserID(), msg.Event)
|
||||||
|
default:
|
||||||
|
log.Warningf("WebSocket: unknown action %q", msg.Action)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) handleAuth(ctx context.Context, token string) bool {
|
||||||
|
if c.IsAuthenticated() {
|
||||||
|
c.sendError("already_authenticated", "")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := auth.GetUserIDFromToken(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("WebSocket: auth failed: %v", err)
|
||||||
|
// Write the error directly to the websocket since ReadLoop will close the
|
||||||
|
// connection immediately after we return false, before WriteLoop can drain the channel.
|
||||||
|
c.writeMessageDirect(ctx, OutgoingMessage{Error: "invalid_token"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.userID = userID
|
||||||
|
c.authenticated = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
c.hub.Register(c)
|
||||||
|
|
||||||
|
// Send auth success
|
||||||
|
select {
|
||||||
|
case c.send <- OutgoingMessage{Action: ActionAuthSuccess, Success: true}:
|
||||||
|
default:
|
||||||
|
log.Warningf("WebSocket: send buffer full for user %d", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("WebSocket: user %d authenticated", userID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMessageDirect writes a message directly to the websocket, bypassing the send channel.
|
||||||
|
// Use this when the message must be sent before the connection is closed.
|
||||||
|
func (c *Connection) writeMessageDirect(ctx context.Context, msg OutgoingMessage) {
|
||||||
|
writeCtx, cancel := context.WithTimeout(ctx, writeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("WebSocket: marshal error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := c.ws.Write(writeCtx, websocket.MessageText, data); err != nil {
|
||||||
|
log.Debugf("WebSocket: direct write error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) sendError(errMsg, event string) {
|
||||||
|
select {
|
||||||
|
case c.send <- OutgoingMessage{Error: errMsg, Event: event}:
|
||||||
|
default:
|
||||||
|
log.Warningf("WebSocket: send buffer full, dropping error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteLoop drains the send channel and writes messages to the WebSocket.
|
||||||
|
// It also sends periodic pings.
|
||||||
|
func (c *Connection) WriteLoop(ctx context.Context, cancel context.CancelFunc) {
|
||||||
|
defer cancel()
|
||||||
|
ticker := time.NewTicker(pingInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg, ok := <-c.send:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeCtx, cancel := context.WithTimeout(ctx, writeTimeout)
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Errorf("WebSocket: marshal error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = c.ws.Write(writeCtx, websocket.MessageText, data)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("WebSocket: write error for user %d: %v", c.UserID(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ticker.C:
|
||||||
|
pingCtx, cancel := context.WithTimeout(ctx, writeTimeout)
|
||||||
|
err := c.ws.Ping(pingCtx)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("WebSocket: ping error for user %d: %v", c.UserID(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validEvents is the set of event names clients are allowed to subscribe to.
|
||||||
|
var validEvents = map[string]bool{
|
||||||
|
"notification.created": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidEvent(event string) bool {
|
||||||
|
return validEvents[event]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnectionSubscribeUnsubscribe(t *testing.T) {
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
authenticated: true,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Subscribe("notification.created")
|
||||||
|
assert.True(t, conn.IsSubscribed("notification.created"))
|
||||||
|
|
||||||
|
conn.Unsubscribe("notification.created")
|
||||||
|
assert.False(t, conn.IsSubscribed("notification.created"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionIsSubscribedReturnsFalseForUnknownEvent(t *testing.T) {
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
authenticated: true,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, conn.IsSubscribed("something"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionAcceptsValidEvent(t *testing.T) {
|
||||||
|
hub := NewHub()
|
||||||
|
conn := &Connection{
|
||||||
|
hub: hub,
|
||||||
|
userID: 1,
|
||||||
|
authenticated: true,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
hub.Register(conn)
|
||||||
|
|
||||||
|
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"})
|
||||||
|
|
||||||
|
assert.True(t, conn.IsSubscribed("notification.created"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionRejectsInvalidEvent(t *testing.T) {
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
authenticated: true,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notifications"})
|
||||||
|
|
||||||
|
msg := <-conn.send
|
||||||
|
assert.Equal(t, "invalid_event", msg.Error)
|
||||||
|
assert.False(t, conn.IsSubscribed("notifications"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionRejectsActionsBeforeAuth(t *testing.T) {
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 0, // not authenticated
|
||||||
|
authenticated: false,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to subscribe before auth - should be rejected
|
||||||
|
conn.handleMessage(context.Background(), IncomingMessage{Action: ActionSubscribe, Event: "notification.created"})
|
||||||
|
|
||||||
|
// Should have sent an error
|
||||||
|
msg := <-conn.send
|
||||||
|
assert.Equal(t, "auth_required", msg.Error)
|
||||||
|
assert.False(t, conn.IsSubscribed("notification.created"))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"code.vikunja.io/api/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hub maintains the set of active connections and delivers messages to them.
|
||||||
|
type Hub struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
connections map[int64][]*Connection // userID -> connections
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHub creates a new Hub.
|
||||||
|
func NewHub() *Hub {
|
||||||
|
return &Hub{
|
||||||
|
connections: make(map[int64][]*Connection),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a connection to the hub.
|
||||||
|
func (h *Hub) Register(conn *Connection) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.connections[conn.userID] = append(h.connections[conn.userID], conn)
|
||||||
|
log.Debugf("WebSocket: registered connection for user %d (total: %d)", conn.userID, len(h.connections[conn.userID]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister removes a connection from the hub.
|
||||||
|
func (h *Hub) Unregister(conn *Connection) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
conns := h.connections[conn.userID]
|
||||||
|
for i, c := range conns {
|
||||||
|
if c == conn {
|
||||||
|
h.connections[conn.userID] = append(conns[:i], conns[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
remaining := len(h.connections[conn.userID])
|
||||||
|
if remaining == 0 {
|
||||||
|
delete(h.connections, conn.userID)
|
||||||
|
}
|
||||||
|
log.Debugf("WebSocket: unregistered connection for user %d (remaining: %d)", conn.userID, remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishForUser sends an event to all connections of a specific user that are subscribed to the given event.
|
||||||
|
func (h *Hub) PublishForUser(userID int64, event string, data any) {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
conns := h.connections[userID]
|
||||||
|
msg := OutgoingMessage{
|
||||||
|
Event: event,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conn := range conns {
|
||||||
|
if !conn.IsSubscribed(event) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case conn.send <- msg:
|
||||||
|
default:
|
||||||
|
log.Warningf("WebSocket: send buffer full for user %d, dropping message", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHubRegisterUnregister(t *testing.T) {
|
||||||
|
h := NewHub()
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
h.Register(conn)
|
||||||
|
assert.Len(t, h.connections[1], 1)
|
||||||
|
|
||||||
|
h.Unregister(conn)
|
||||||
|
assert.Empty(t, h.connections[1])
|
||||||
|
_, exists := h.connections[1]
|
||||||
|
assert.False(t, exists, "map entry should be deleted when last connection is removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHubPublishToSubscribedConnection(t *testing.T) {
|
||||||
|
h := NewHub()
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
h.Register(conn)
|
||||||
|
conn.subscriptions["notification.created"] = true
|
||||||
|
|
||||||
|
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
|
||||||
|
|
||||||
|
msg := <-conn.send
|
||||||
|
assert.Equal(t, "notification.created", msg.Event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHubPublishSkipsUnsubscribedConnection(t *testing.T) {
|
||||||
|
h := NewHub()
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 1,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
h.Register(conn)
|
||||||
|
// Not subscribed to "notification.created"
|
||||||
|
|
||||||
|
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
|
||||||
|
|
||||||
|
assert.Empty(t, conn.send)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHubPublishSkipsOtherUsers(t *testing.T) {
|
||||||
|
h := NewHub()
|
||||||
|
conn := &Connection{
|
||||||
|
userID: 2,
|
||||||
|
subscriptions: make(map[string]bool),
|
||||||
|
send: make(chan OutgoingMessage, 16),
|
||||||
|
}
|
||||||
|
h.Register(conn)
|
||||||
|
conn.subscriptions["notification.created"] = true
|
||||||
|
|
||||||
|
h.PublishForUser(1, "notification.created", map[string]string{"id": "1"})
|
||||||
|
|
||||||
|
assert.Empty(t, conn.send)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"code.vikunja.io/api/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
log.InitLogger()
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Client actions
|
||||||
|
ActionAuth = "auth"
|
||||||
|
ActionSubscribe = "subscribe"
|
||||||
|
ActionUnsubscribe = "unsubscribe"
|
||||||
|
|
||||||
|
// Server actions
|
||||||
|
ActionAuthSuccess = "auth.success"
|
||||||
|
ActionUnsubscribed = "unsubscribed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IncomingMessage represents a message from the client.
|
||||||
|
type IncomingMessage struct {
|
||||||
|
Action string `json:"action"`
|
||||||
|
// Token is set for auth action.
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
// Event is set for subscribe/unsubscribe actions.
|
||||||
|
Event string `json:"event,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OutgoingMessage represents a message from the server to the client.
|
||||||
|
// Exactly one of Event, Error, or Action will be set.
|
||||||
|
type OutgoingMessage struct {
|
||||||
|
// Event identifies the event type. On push messages (e.g. "notification.created"),
|
||||||
|
// it carries the event name. On error responses for subscribe/unsubscribe,
|
||||||
|
// it identifies which event caused the error.
|
||||||
|
Event string `json:"event,omitempty"`
|
||||||
|
// Error is set for error responses (e.g. "forbidden").
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
// Action is set for server-initiated actions (e.g. "auth.success", "unsubscribed").
|
||||||
|
Action string `json:"action,omitempty"`
|
||||||
|
// Success is set for auth.success action.
|
||||||
|
Success bool `json:"success,omitempty"`
|
||||||
|
// Reason provides context for server-initiated actions.
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
// Data carries the event payload.
|
||||||
|
Data any `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
// Vikunja is a to-do list application to facilitate your life.
|
||||||
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIncomingAuthMessageDeserialization(t *testing.T) {
|
||||||
|
raw := `{"action":"auth","token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}`
|
||||||
|
var msg IncomingMessage
|
||||||
|
err := json.Unmarshal([]byte(raw), &msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, ActionAuth, msg.Action)
|
||||||
|
assert.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", msg.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIncomingSubscribeMessageDeserialization(t *testing.T) {
|
||||||
|
raw := `{"action":"subscribe","event":"notification.created"}`
|
||||||
|
var msg IncomingMessage
|
||||||
|
err := json.Unmarshal([]byte(raw), &msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, ActionSubscribe, msg.Action)
|
||||||
|
assert.Equal(t, "notification.created", msg.Event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutgoingEventSerialization(t *testing.T) {
|
||||||
|
msg := OutgoingMessage{
|
||||||
|
Event: "notification.created",
|
||||||
|
Data: map[string]string{"hello": "world"},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"event":"notification.created"`)
|
||||||
|
assert.NotContains(t, string(data), `"topic"`)
|
||||||
|
assert.Contains(t, string(data), `"hello":"world"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutgoingErrorSerialization(t *testing.T) {
|
||||||
|
msg := OutgoingMessage{
|
||||||
|
Error: "forbidden",
|
||||||
|
Event: "project.tasks",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"error":"forbidden"`)
|
||||||
|
assert.Contains(t, string(data), `"event":"project.tasks"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutgoingAuthSuccessSerialization(t *testing.T) {
|
||||||
|
msg := OutgoingMessage{
|
||||||
|
Action: ActionAuthSuccess,
|
||||||
|
Success: true,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"action":"auth.success"`)
|
||||||
|
assert.Contains(t, string(data), `"success":true`)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue