websocket-relay/internal/hub/hub_integration_test.go
savinmax c226bdab7f feat(hub): update graceful shutdown to iterate rooms for multi-room cleanup
Refactor the stop case in Hub.Run() to iterate h.rooms directly
instead of h.connRoom. For each room, iterate all connections and
send CloseGoingAway frame before closing. After the loop, reset both
maps (h.rooms, h.connRoom) in one shot rather than deleting entries
incrementally. This is cleaner and avoids modifying a map during
iteration.

Add TestIntegration_GracefulShutdownMultiRoom to verify clients in
separate rooms all receive close frames during shutdown.

🤖 Assisted by the code-assist SOP
2026-06-13 13:26:03 +02:00

573 lines
15 KiB
Go

package hub
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"websocket-relay/internal/logging"
)
// setupTestServer starts a Hub event loop in a background goroutine and an
// httptest.Server whose every request is handled by the hub's
// HandleWebSocket. Log output is written to a throwaway buffer so tests stay
// quiet. The caller owns both values: close the server and shut the hub down.
func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
	t.Helper()
	var sink bytes.Buffer
	relay := New(logging.NewLogger("debug", &sink))
	go relay.Run()
	srv := httptest.NewServer(http.HandlerFunc(relay.HandleWebSocket))
	return srv, relay
}
// dialWS opens a client WebSocket to the server's root path (the default
// room) and aborts the test if the handshake fails.
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
	t.Helper()
	// httptest serves http(s):// URLs; replacing the "http" prefix with "ws"
	// gives the matching ws(s):// endpoint.
	endpoint := "ws" + strings.TrimPrefix(server.URL, "http")
	client, _, dialErr := websocket.DefaultDialer.Dial(endpoint, nil)
	if dialErr != nil {
		t.Fatalf("Failed to dial WebSocket: %v", dialErr)
	}
	return client
}
// dialWSWithRoom opens a client WebSocket to "/<room>" on the test server and
// aborts the test if the handshake fails. The room name travels as the URL
// path; other tests address rooms as "/room-a", so the hub appears to key
// rooms by request path.
func dialWSWithRoom(t *testing.T, server *httptest.Server, room string) *websocket.Conn {
	t.Helper()
	base := "ws" + strings.TrimPrefix(server.URL, "http")
	client, _, dialErr := websocket.DefaultDialer.Dial(base+"/"+room, nil)
	if dialErr != nil {
		t.Fatalf("Failed to dial WebSocket (room=%q): %v", room, dialErr)
	}
	return client
}
// waitForClients polls h.ClientCount every 5ms until it equals expected,
// failing the test if that does not happen within timeout.
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
	t.Helper()
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(5 * time.Millisecond) {
		if h.ClientCount() == expected {
			return
		}
	}
	t.Fatalf("Timed out waiting for %d clients, got %d", expected, h.ClientCount())
}
// TestIntegration_SingleClientConnect checks that a single dialed client is
// registered by the hub.
func TestIntegration_SingleClientConnect(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	client := dialWS(t, srv)
	defer client.Close()

	waitForClients(t, relay, 1, time.Second)
	if got := relay.ClientCount(); got != 1 {
		t.Errorf("Expected 1 client, got %d", got)
	}
}
// TestIntegration_MultipleClientsConnect checks that the hub registers every
// one of several concurrently open clients.
func TestIntegration_MultipleClientsConnect(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	const numClients = 5
	clients := make([]*websocket.Conn, numClients)
	for i := range clients {
		clients[i] = dialWS(t, srv)
		defer clients[i].Close()
	}

	waitForClients(t, relay, numClients, time.Second)
	if got := relay.ClientCount(); got != numClients {
		t.Errorf("Expected %d clients, got %d", numClients, got)
	}
}
// TestIntegration_BroadcastMessage checks that a message sent by one client
// in the default room is delivered to both clients, the sender included.
func TestIntegration_BroadcastMessage(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	sender := dialWS(t, srv)
	defer sender.Close()
	peer := dialWS(t, srv)
	defer peer.Close()
	waitForClients(t, relay, 2, time.Second)

	const testMsg = "hello from client 1"
	if err := sender.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// The sender is expected to receive its own message back as well.
	receivers := []*websocket.Conn{sender, peer}
	for idx := range receivers {
		c := receivers[idx]
		c.SetReadDeadline(time.Now().Add(time.Second))
		_, payload, err := c.ReadMessage()
		if err != nil {
			t.Fatalf("Client %d failed to read message: %v", idx+1, err)
		}
		if got := string(payload); got != testMsg {
			t.Errorf("Client %d expected %q, got %q", idx+1, testMsg, got)
		}
	}
}
// TestIntegration_BroadcastToManyClients checks that one message fans out to
// all ten clients of the default room, the sender included.
func TestIntegration_BroadcastToManyClients(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	const numClients = 10
	clients := make([]*websocket.Conn, numClients)
	for i := range clients {
		clients[i] = dialWS(t, srv)
		defer clients[i].Close()
	}
	waitForClients(t, relay, numClients, time.Second)

	const testMsg = "broadcast to all"
	if err := clients[0].WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// Client indices in failure messages are 0-based; index 0 is the sender.
	for i := range clients {
		c := clients[i]
		c.SetReadDeadline(time.Now().Add(time.Second))
		_, payload, err := c.ReadMessage()
		if err != nil {
			t.Fatalf("Client %d failed to read: %v", i, err)
		}
		if got := string(payload); got != testMsg {
			t.Errorf("Client %d expected %q, got %q", i, testMsg, got)
		}
	}
}
// TestIntegration_ClientDisconnect checks that closing one client's
// connection makes the hub drop it while the other client stays registered.
func TestIntegration_ClientDisconnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	// Deferred in addition to the explicit Close below so conn1 is released
	// even if the test aborts (t.Fatalf) before reaching it; the redundant
	// Close's error is ignored.
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()
	waitForClients(t, h, 2, time.Second)

	// Disconnect client 1
	conn1.Close()
	waitForClients(t, h, 1, time.Second)
	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client after disconnect, got %d", count)
	}
}
// TestIntegration_MessageAfterDisconnect checks that the hub keeps relaying
// for the remaining client after another client disconnects: the remaining
// client receives its own message back.
func TestIntegration_MessageAfterDisconnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	// Deferred in addition to the explicit Close below so conn1 is released
	// even if the test aborts (t.Fatalf) before reaching it; the redundant
	// Close's error is ignored.
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()
	waitForClients(t, h, 2, time.Second)

	// Disconnect client 1
	conn1.Close()
	waitForClients(t, h, 1, time.Second)

	// Send a message from client 2 — should still work
	testMsg := "after disconnect"
	if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// Client 2 should receive its own message back
	conn2.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err := conn2.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read message: %v", err)
	}
	if string(msg) != testMsg {
		t.Errorf("Expected %q, got %q", testMsg, string(msg))
	}
}
// TestIntegration_MultipleMessages checks that several messages sent
// back-to-back by one client reach another client in send order.
func TestIntegration_MultipleMessages(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	sender := dialWS(t, srv)
	defer sender.Close()
	receiver := dialWS(t, srv)
	defer receiver.Close()
	waitForClients(t, relay, 2, time.Second)

	batch := [...]string{"first", "second", "third"}
	for _, text := range batch {
		err := sender.WriteMessage(websocket.TextMessage, []byte(text))
		if err != nil {
			t.Fatalf("Failed to send %q: %v", text, err)
		}
	}

	// Delivery order must match send order.
	for i := 0; i < len(batch); i++ {
		receiver.SetReadDeadline(time.Now().Add(time.Second))
		_, got, err := receiver.ReadMessage()
		if err != nil {
			t.Fatalf("Failed to read message: %v", err)
		}
		if want := batch[i]; string(got) != want {
			t.Errorf("Expected %q, got %q", want, string(got))
		}
	}
}
// TestIntegration_ConcurrentSenders has every client send one message at the
// same time and checks that each client receives every sender's message
// exactly once (its own included). Sender i sends i+1 'x' bytes, so the
// payload length identifies the sender.
func TestIntegration_ConcurrentSenders(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	const numClients = 5
	conns := make([]*websocket.Conn, numClients)
	for i := 0; i < numClients; i++ {
		conns[i] = dialWS(t, server)
		defer conns[i].Close()
	}
	waitForClients(t, h, numClients, time.Second)

	// Each client sends one message concurrently. Every goroutine writes to
	// its own connection, so each conn has a single writer.
	var wg sync.WaitGroup
	for i := 0; i < numClients; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			msg := []byte(strings.Repeat("x", idx+1)) // unique length per sender
			if err := conns[idx].WriteMessage(websocket.TextMessage, msg); err != nil {
				// Errorf (not Fatalf) is the variant safe to call off the test goroutine.
				t.Errorf("Client %d: failed to send: %v", idx, err)
			}
		}(i)
	}
	wg.Wait()
	if t.Failed() {
		// A failed send would only surface below as a confusing read timeout.
		t.FailNow()
	}

	// Each client should receive exactly one message from each sender.
	for i, conn := range conns {
		seen := make(map[int]bool, numClients)
		conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		for received := 0; received < numClients; received++ {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				t.Fatalf("Client %d: read error after %d messages: %v", i, received, err)
			}
			sender := len(msg) - 1
			if sender < 0 || sender >= numClients || string(msg) != strings.Repeat("x", len(msg)) {
				t.Fatalf("Client %d: unexpected message %q", i, msg)
			}
			if seen[sender] {
				t.Fatalf("Client %d: received message from sender %d twice", i, sender)
			}
			seen[sender] = true
		}
		// numClients distinct senders out of numClients reads => all were seen.
	}
}
// TestIntegration_GracefulShutdownClosesClients checks that Hub.Shutdown
// terminates a connected client's connection, and that a close frame, if
// one arrives, carries CloseGoingAway.
func TestIntegration_GracefulShutdownClosesClients(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	conn := dialWS(t, server)
	defer conn.Close()
	waitForClients(t, h, 1, time.Second)

	// Trigger shutdown
	h.Shutdown()

	// Client should receive a close frame
	conn.SetReadDeadline(time.Now().Add(time.Second))
	_, _, err := conn.ReadMessage()
	if err == nil {
		t.Fatal("Expected error after shutdown, got nil")
	}
	// A read timeout means the hub never closed the connection — the exact
	// failure this test must catch — so it cannot count as "an error".
	if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
		t.Fatalf("Read timed out: hub did not close the client on shutdown: %v", err)
	}
	// Verify it's a close error with GoingAway code
	if closeErr, ok := err.(*websocket.CloseError); ok {
		if closeErr.Code != websocket.CloseGoingAway {
			t.Errorf("Expected CloseGoingAway (%d), got %d", websocket.CloseGoingAway, closeErr.Code)
		}
	}
	// Other non-timeout errors (e.g. a reset connection) are accepted: the
	// connection was torn down and is no longer usable.
}
// TestIntegration_EmptyMessage checks that a zero-length text message is
// relayed to the other client.
func TestIntegration_EmptyMessage(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	sender := dialWS(t, srv)
	defer sender.Close()
	receiver := dialWS(t, srv)
	defer receiver.Close()
	waitForClients(t, relay, 2, time.Second)

	if err := sender.WriteMessage(websocket.TextMessage, []byte("")); err != nil {
		t.Fatalf("Failed to send empty message: %v", err)
	}

	receiver.SetReadDeadline(time.Now().Add(time.Second))
	_, payload, err := receiver.ReadMessage()
	switch {
	case err != nil:
		t.Fatalf("Failed to read message: %v", err)
	case len(payload) != 0:
		t.Errorf("Expected empty message, got %q", string(payload))
	}
}
// TestIntegration_LargeMessage checks that a 64 KiB text message is relayed
// to the other client with its full length intact.
func TestIntegration_LargeMessage(t *testing.T) {
	srv, relay := setupTestServer(t)
	defer srv.Close()
	defer relay.Shutdown()

	sender := dialWS(t, srv)
	defer sender.Close()
	receiver := dialWS(t, srv)
	defer receiver.Close()
	waitForClients(t, relay, 2, time.Second)

	const size = 64 * 1024
	payload := []byte(strings.Repeat("A", size))
	if err := sender.WriteMessage(websocket.TextMessage, payload); err != nil {
		t.Fatalf("Failed to send large message: %v", err)
	}

	receiver.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, got, err := receiver.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read large message: %v", err)
	}
	if n := len(got); n != size {
		t.Errorf("Expected message length %d, got %d", size, n)
	}
}
// TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom checks that a
// message sent in room-a reaches both room-a clients and nothing in room-b.
func TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	// Connect 2 clients to room-a, 1 client to room-b
	connA1 := dialWSWithRoom(t, server, "room-a")
	defer connA1.Close()
	connA2 := dialWSWithRoom(t, server, "room-a")
	defer connA2.Close()
	connB1 := dialWSWithRoom(t, server, "room-b")
	defer connB1.Close()
	waitForClients(t, h, 3, time.Second)

	// Send a message from client A1
	testMsg := "hello room-a"
	if err := connA1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// Both room-a clients should receive the message
	for i, conn := range []*websocket.Conn{connA1, connA2} {
		conn.SetReadDeadline(time.Now().Add(time.Second))
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("Room-a client %d failed to read message: %v", i+1, err)
		}
		if string(msg) != testMsg {
			t.Errorf("Room-a client %d expected %q, got %q", i+1, testMsg, string(msg))
		}
	}

	// Room-b client should NOT receive the message
	connB1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	_, _, err := connB1.ReadMessage()
	if err == nil {
		t.Fatal("Room-b client should NOT have received a message from room-a")
	}
	// Only a read timeout shows that nothing was delivered; any other error
	// means the connection broke, which would hide a real failure.
	if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
		t.Fatalf("Room-b client: expected read timeout, got %T: %v", err, err)
	}
}
// TestIntegration_RoomIsolation_MultipleRoomsIndependent checks that two
// single-client rooms each echo their own traffic and never see the other's.
func TestIntegration_RoomIsolation_MultipleRoomsIndependent(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	// Connect clients to two different rooms
	connA := dialWSWithRoom(t, server, "alpha")
	defer connA.Close()
	connB := dialWSWithRoom(t, server, "beta")
	defer connB.Close()
	waitForClients(t, h, 2, time.Second)

	// Send message from room alpha
	msgAlpha := "alpha message"
	if err := connA.WriteMessage(websocket.TextMessage, []byte(msgAlpha)); err != nil {
		t.Fatalf("Failed to send alpha message: %v", err)
	}

	// Room alpha client receives its own message
	connA.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err := connA.ReadMessage()
	if err != nil {
		t.Fatalf("Alpha client failed to read: %v", err)
	}
	if string(msg) != msgAlpha {
		t.Errorf("Alpha client expected %q, got %q", msgAlpha, string(msg))
	}

	// Send message from room beta
	msgBeta := "beta message"
	if err := connB.WriteMessage(websocket.TextMessage, []byte(msgBeta)); err != nil {
		t.Fatalf("Failed to send beta message: %v", err)
	}

	// Room beta client receives its own message
	connB.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err = connB.ReadMessage()
	if err != nil {
		t.Fatalf("Beta client failed to read: %v", err)
	}
	if string(msg) != msgBeta {
		t.Errorf("Beta client expected %q, got %q", msgBeta, string(msg))
	}

	// In both negative checks below, only a read timeout shows that nothing
	// was delivered; any other error means the connection broke.

	// Verify alpha did NOT receive beta's message
	connA.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	_, _, err = connA.ReadMessage()
	if err == nil {
		t.Fatal("Alpha client should NOT have received beta's message")
	}
	if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
		t.Fatalf("Alpha client: expected read timeout, got %T: %v", err, err)
	}

	// Verify beta did NOT receive alpha's message
	connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	_, _, err = connB.ReadMessage()
	if err == nil {
		t.Fatal("Beta client should NOT have received alpha's message")
	}
	if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
		t.Fatalf("Beta client: expected read timeout, got %T: %v", err, err)
	}
}
// TestIntegration_BroadcastToEmptyRoom checks that a broadcast addressed to a
// room with no members is dropped without disturbing the hub. It pushes
// broadcastMsg values directly onto h.broadcast. Room keys are written with a
// leading "/", matching the request path dialWSWithRoom uses.
func TestIntegration_BroadcastToEmptyRoom(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	// A single client in room-a, used afterwards to confirm the hub still
	// delivers messages.
	connA := dialWSWithRoom(t, server, "room-a")
	defer connA.Close()
	waitForClients(t, h, 1, time.Second)

	// send pushes msg onto the hub's broadcast channel. It fails the test
	// instead of hanging the whole test binary if the hub stops receiving.
	send := func(msg broadcastMsg) {
		t.Helper()
		select {
		case h.broadcast <- msg:
		case <-time.After(time.Second):
			t.Fatalf("Timed out sending to broadcast channel (room=%q)", msg.room)
		}
	}

	// Target a non-existent room: this must be handled gracefully (no panic,
	// no delivery).
	send(broadcastMsg{room: "/non-existent", data: []byte("ghost message")})
	// Give the hub time to process
	time.Sleep(50 * time.Millisecond)

	// Now send a real message to room-a to confirm hub is still functional
	send(broadcastMsg{room: "/room-a", data: []byte("real message")})
	connA.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err := connA.ReadMessage()
	if err != nil {
		t.Fatalf("Room-a client failed to read real message after empty-room broadcast: %v", err)
	}
	if string(msg) != "real message" {
		t.Errorf("Expected %q, got %q", "real message", string(msg))
	}
}
func TestIntegration_GracefulShutdownMultiRoom(t *testing.T) {
logger := logging.NewLogger("debug", &bytes.Buffer{})
h := New(logger)
go h.Run()
s := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
defer s.Close()
// Connect clients to different rooms
rooms := []string{"room-a", "room-b", "room-c"}
conns := make([]*websocket.Conn, 0, len(rooms))
for _, room := range rooms {
ws := dialWSWithRoom(t, s, room)
conns = append(conns, ws)
}
waitForClients(t, h, 3, time.Second)
// Verify all clients are connected to separate rooms
h.mu.RLock()
roomCount := len(h.rooms)
connCount := len(h.connRoom)
h.mu.RUnlock()
if roomCount != 3 {
t.Fatalf("expected 3 rooms, got %d", roomCount)
}
if connCount != 3 {
t.Fatalf("expected 3 connections, got %d", connCount)
}
// Shutdown the hub — all clients should receive close frame
h.Shutdown()
var wg sync.WaitGroup
for i, ws := range conns {
wg.Add(1)
go func(idx int, c *websocket.Conn) {
defer wg.Done()
c.SetReadDeadline(time.Now().Add(2 * time.Second))
_, _, err := c.ReadMessage()
if err == nil {
t.Errorf("client %d: expected error (close frame), got nil", idx)
return
}
closeErr, ok := err.(*websocket.CloseError)
if !ok {
t.Errorf("client %d: expected CloseError, got %T: %v", idx, err, err)
return
}
if closeErr.Code != websocket.CloseGoingAway {
t.Errorf("client %d: expected CloseGoingAway (%d), got %d",
idx, websocket.CloseGoingAway, closeErr.Code)
}
}(i, ws)
}
wg.Wait()
// Verify maps are cleared
h.mu.RLock()
roomsAfter := len(h.rooms)
connsAfter := len(h.connRoom)
h.mu.RUnlock()
if roomsAfter != 0 {
t.Errorf("expected 0 rooms after shutdown, got %d", roomsAfter)
}
if connsAfter != 0 {
t.Errorf("expected 0 connections after shutdown, got %d", connsAfter)
}
}