Refactor the stop case in Hub.Run() to iterate h.rooms directly
instead of h.connRoom. For each room, iterate all of its connections and
send each one a CloseGoingAway frame before closing it. After the loop,
reset both maps (h.rooms, h.connRoom) in one shot rather than deleting
entries incrementally. This is cleaner: no entries are deleted from a map
while it is being ranged over (legal in Go, but harder to reason about).
Add TestIntegration_GracefulShutdownMultiRoom to verify that clients in
separate rooms all receive close frames during shutdown.
🤖 Assisted by the code-assist SOP
573 lines
15 KiB
Go
573 lines
15 KiB
Go
package hub
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"websocket-relay/internal/logging"
|
|
)
|
|
|
|
// helper: start a test server with a running Hub, return the server and hub
|
|
func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
|
|
t.Helper()
|
|
logger := logging.NewLogger("debug", &bytes.Buffer{})
|
|
h := New(logger)
|
|
go h.Run()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
|
return server, h
|
|
}
|
|
|
|
// helper: dial a WebSocket connection to the test server (default room)
|
|
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
|
|
t.Helper()
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to dial WebSocket: %v", err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
// helper: dial a WebSocket connection to a specific room
|
|
func dialWSWithRoom(t *testing.T, server *httptest.Server, room string) *websocket.Conn {
|
|
t.Helper()
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/" + room
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to dial WebSocket (room=%q): %v", room, err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
// helper: wait until hub reaches expected client count or timeout
|
|
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
|
|
t.Helper()
|
|
deadline := time.Now().Add(timeout)
|
|
for time.Now().Before(deadline) {
|
|
if h.ClientCount() == expected {
|
|
return
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
t.Fatalf("Timed out waiting for %d clients, got %d", expected, h.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_SingleClientConnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn := dialWS(t, server)
|
|
defer conn.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MultipleClientsConnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 5
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
if count := h.ClientCount(); count != numClients {
|
|
t.Errorf("Expected %d clients, got %d", numClients, count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_BroadcastMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
// Connect two clients
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send a message from client 1
|
|
testMsg := "hello from client 1"
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Both clients should receive the broadcast
|
|
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_BroadcastToManyClients(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 10
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
// Send from first client
|
|
testMsg := "broadcast to all"
|
|
if err := conns[0].WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// All clients should receive it
|
|
for i, conn := range conns {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d failed to read: %v", i, err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Client %d expected %q, got %q", i, testMsg, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_ClientDisconnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Disconnect client 1
|
|
conn1.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client after disconnect, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MessageAfterDisconnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Disconnect client 1
|
|
conn1.Close()
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
// Send a message from client 2 — should still work
|
|
testMsg := "after disconnect"
|
|
if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Client 2 should receive its own message back
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Expected %q, got %q", testMsg, string(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MultipleMessages(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
messages := []string{"first", "second", "third"}
|
|
|
|
for _, msg := range messages {
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
|
|
t.Fatalf("Failed to send %q: %v", msg, err)
|
|
}
|
|
}
|
|
|
|
// Client 2 should receive all messages in order
|
|
for _, expected := range messages {
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != expected {
|
|
t.Errorf("Expected %q, got %q", expected, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_ConcurrentSenders(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 5
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
// Each client sends one message concurrently
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numClients; i++ {
|
|
wg.Add(1)
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
msg := []byte(strings.Repeat("x", idx+1)) // unique length per sender
|
|
conns[idx].WriteMessage(websocket.TextMessage, msg)
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
// Each client should receive exactly numClients messages (one from each sender)
|
|
for i, conn := range conns {
|
|
received := 0
|
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
for received < numClients {
|
|
_, _, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d: read error after %d messages: %v", i, received, err)
|
|
}
|
|
received++
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_GracefulShutdownClosesClients(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
|
|
conn := dialWS(t, server)
|
|
defer conn.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
// Trigger shutdown
|
|
h.Shutdown()
|
|
|
|
// Client should receive a close frame
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, _, err := conn.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Expected error after shutdown, got nil")
|
|
}
|
|
|
|
// Verify it's a close error with GoingAway code
|
|
if closeErr, ok := err.(*websocket.CloseError); ok {
|
|
if closeErr.Code != websocket.CloseGoingAway {
|
|
t.Errorf("Expected CloseGoingAway (%d), got %d", websocket.CloseGoingAway, closeErr.Code)
|
|
}
|
|
}
|
|
// Any error is acceptable — the key is the connection is no longer usable
|
|
}
|
|
|
|
func TestIntegration_EmptyMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send an empty message
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte("")); err != nil {
|
|
t.Fatalf("Failed to send empty message: %v", err)
|
|
}
|
|
|
|
// Client 2 should receive the empty message
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != "" {
|
|
t.Errorf("Expected empty message, got %q", string(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_LargeMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send a 64KB message
|
|
largeMsg := strings.Repeat("A", 64*1024)
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(largeMsg)); err != nil {
|
|
t.Fatalf("Failed to send large message: %v", err)
|
|
}
|
|
|
|
conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read large message: %v", err)
|
|
}
|
|
if len(msg) != 64*1024 {
|
|
t.Errorf("Expected message length %d, got %d", 64*1024, len(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
// Connect 2 clients to room-a, 1 client to room-b
|
|
connA1 := dialWSWithRoom(t, server, "room-a")
|
|
defer connA1.Close()
|
|
connA2 := dialWSWithRoom(t, server, "room-a")
|
|
defer connA2.Close()
|
|
connB1 := dialWSWithRoom(t, server, "room-b")
|
|
defer connB1.Close()
|
|
|
|
waitForClients(t, h, 3, time.Second)
|
|
|
|
// Send a message from client A1
|
|
testMsg := "hello room-a"
|
|
if err := connA1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Both room-a clients should receive the message
|
|
for i, conn := range []*websocket.Conn{connA1, connA2} {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Room-a client %d failed to read message: %v", i+1, err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Room-a client %d expected %q, got %q", i+1, testMsg, string(msg))
|
|
}
|
|
}
|
|
|
|
// Room-b client should NOT receive the message
|
|
connB1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
|
_, _, err := connB1.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Room-b client should NOT have received a message from room-a")
|
|
}
|
|
// Timeout error is expected — message was correctly not delivered
|
|
}
|
|
|
|
func TestIntegration_RoomIsolation_MultipleRoomsIndependent(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
// Connect clients to two different rooms
|
|
connA := dialWSWithRoom(t, server, "alpha")
|
|
defer connA.Close()
|
|
connB := dialWSWithRoom(t, server, "beta")
|
|
defer connB.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send message from room alpha
|
|
msgAlpha := "alpha message"
|
|
if err := connA.WriteMessage(websocket.TextMessage, []byte(msgAlpha)); err != nil {
|
|
t.Fatalf("Failed to send alpha message: %v", err)
|
|
}
|
|
|
|
// Room alpha client receives its own message
|
|
connA.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := connA.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Alpha client failed to read: %v", err)
|
|
}
|
|
if string(msg) != msgAlpha {
|
|
t.Errorf("Alpha client expected %q, got %q", msgAlpha, string(msg))
|
|
}
|
|
|
|
// Send message from room beta
|
|
msgBeta := "beta message"
|
|
if err := connB.WriteMessage(websocket.TextMessage, []byte(msgBeta)); err != nil {
|
|
t.Fatalf("Failed to send beta message: %v", err)
|
|
}
|
|
|
|
// Room beta client receives its own message
|
|
connB.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err = connB.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Beta client failed to read: %v", err)
|
|
}
|
|
if string(msg) != msgBeta {
|
|
t.Errorf("Beta client expected %q, got %q", msgBeta, string(msg))
|
|
}
|
|
|
|
// Verify alpha did NOT receive beta's message
|
|
connA.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
|
_, _, err = connA.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Alpha client should NOT have received beta's message")
|
|
}
|
|
|
|
// Verify beta did NOT receive alpha's message
|
|
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
|
_, _, err = connB.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Beta client should NOT have received alpha's message")
|
|
}
|
|
}
|
|
|
|
func TestIntegration_BroadcastToEmptyRoom(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
// Connect a client to room-a and a separate client to verify hub is functional
|
|
connA := dialWSWithRoom(t, server, "room-a")
|
|
defer connA.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
// Send directly to broadcast channel targeting a non-existent room.
|
|
// This should be handled gracefully (no panic, no delivery).
|
|
h.broadcast <- broadcastMsg{room: "/non-existent", data: []byte("ghost message")}
|
|
|
|
// Give the hub time to process
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Now send a real message to room-a to confirm hub is still functional
|
|
h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("real message")}
|
|
|
|
connA.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := connA.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Room-a client failed to read real message after empty-room broadcast: %v", err)
|
|
}
|
|
if string(msg) != "real message" {
|
|
t.Errorf("Expected %q, got %q", "real message", string(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_GracefulShutdownMultiRoom(t *testing.T) {
|
|
logger := logging.NewLogger("debug", &bytes.Buffer{})
|
|
h := New(logger)
|
|
go h.Run()
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
|
defer s.Close()
|
|
|
|
// Connect clients to different rooms
|
|
rooms := []string{"room-a", "room-b", "room-c"}
|
|
conns := make([]*websocket.Conn, 0, len(rooms))
|
|
for _, room := range rooms {
|
|
ws := dialWSWithRoom(t, s, room)
|
|
conns = append(conns, ws)
|
|
}
|
|
|
|
waitForClients(t, h, 3, time.Second)
|
|
|
|
// Verify all clients are connected to separate rooms
|
|
h.mu.RLock()
|
|
roomCount := len(h.rooms)
|
|
connCount := len(h.connRoom)
|
|
h.mu.RUnlock()
|
|
if roomCount != 3 {
|
|
t.Fatalf("expected 3 rooms, got %d", roomCount)
|
|
}
|
|
if connCount != 3 {
|
|
t.Fatalf("expected 3 connections, got %d", connCount)
|
|
}
|
|
|
|
// Shutdown the hub — all clients should receive close frame
|
|
h.Shutdown()
|
|
|
|
var wg sync.WaitGroup
|
|
for i, ws := range conns {
|
|
wg.Add(1)
|
|
go func(idx int, c *websocket.Conn) {
|
|
defer wg.Done()
|
|
c.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, _, err := c.ReadMessage()
|
|
if err == nil {
|
|
t.Errorf("client %d: expected error (close frame), got nil", idx)
|
|
return
|
|
}
|
|
closeErr, ok := err.(*websocket.CloseError)
|
|
if !ok {
|
|
t.Errorf("client %d: expected CloseError, got %T: %v", idx, err, err)
|
|
return
|
|
}
|
|
if closeErr.Code != websocket.CloseGoingAway {
|
|
t.Errorf("client %d: expected CloseGoingAway (%d), got %d",
|
|
idx, websocket.CloseGoingAway, closeErr.Code)
|
|
}
|
|
}(i, ws)
|
|
}
|
|
wg.Wait()
|
|
|
|
// Verify maps are cleared
|
|
h.mu.RLock()
|
|
roomsAfter := len(h.rooms)
|
|
connsAfter := len(h.connRoom)
|
|
h.mu.RUnlock()
|
|
if roomsAfter != 0 {
|
|
t.Errorf("expected 0 rooms after shutdown, got %d", roomsAfter)
|
|
}
|
|
if connsAfter != 0 {
|
|
t.Errorf("expected 0 connections after shutdown, got %d", connsAfter)
|
|
}
|
|
}
|