package hub import ( "bytes" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/gorilla/websocket" "websocket-relay/internal/logging" ) // helper: start a test server with a running Hub, return the server and hub func setupTestServer(t *testing.T) (*httptest.Server, *Hub) { t.Helper() logger := logging.NewLogger("debug", &bytes.Buffer{}) h := New(logger) go h.Run() server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket)) return server, h } // helper: dial a WebSocket connection to the test server (default room) func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn { t.Helper() wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("Failed to dial WebSocket: %v", err) } return conn } // helper: dial a WebSocket connection to a specific room func dialWSWithRoom(t *testing.T, server *httptest.Server, room string) *websocket.Conn { t.Helper() wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "?room=" + room conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("Failed to dial WebSocket (room=%q): %v", room, err) } return conn } // helper: wait until hub reaches expected client count or timeout func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { if h.ClientCount() == expected { return } time.Sleep(5 * time.Millisecond) } t.Fatalf("Timed out waiting for %d clients, got %d", expected, h.ClientCount()) } func TestIntegration_SingleClientConnect(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn := dialWS(t, server) defer conn.Close() waitForClients(t, h, 1, time.Second) if count := h.ClientCount(); count != 1 { t.Errorf("Expected 1 client, got %d", count) } } func TestIntegration_MultipleClientsConnect(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() const numClients = 5 conns := make([]*websocket.Conn, numClients) for i := 0; i < numClients; i++ { conns[i] = dialWS(t, server) defer conns[i].Close() } waitForClients(t, h, numClients, time.Second) if count := h.ClientCount(); count != numClients { t.Errorf("Expected %d clients, got %d", numClients, count) } } func TestIntegration_BroadcastMessage(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() // Connect two clients conn1 := dialWS(t, server) defer conn1.Close() conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) // Send a message from client 1 testMsg := "hello from client 1" if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil { t.Fatalf("Failed to send message: %v", err) } // Both clients should receive the broadcast for i, conn := range []*websocket.Conn{conn1, conn2} { conn.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn.ReadMessage() if err != nil { t.Fatalf("Client %d failed to read message: %v", i+1, err) } if string(msg) != testMsg { t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg)) } } } func TestIntegration_BroadcastToManyClients(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() const numClients = 10 conns := make([]*websocket.Conn, numClients) for i := 0; i < numClients; i++ { conns[i] = dialWS(t, server) defer conns[i].Close() } waitForClients(t, h, numClients, time.Second) // Send from first client testMsg := "broadcast to all" if err := conns[0].WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil { t.Fatalf("Failed to send message: %v", err) } // All clients should receive it for i, conn := range conns { conn.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn.ReadMessage() if err != nil { t.Fatalf("Client %d failed to read: %v", i, err) } if string(msg) != testMsg { t.Errorf("Client %d expected %q, got %q", i, testMsg, string(msg)) } } } func TestIntegration_ClientDisconnect(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn1 := dialWS(t, server) conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) // Disconnect client 1 conn1.Close() waitForClients(t, h, 1, time.Second) if count := h.ClientCount(); count != 1 { t.Errorf("Expected 1 client after disconnect, got %d", count) } } func TestIntegration_MessageAfterDisconnect(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn1 := dialWS(t, server) conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) // Disconnect client 1 conn1.Close() waitForClients(t, h, 1, time.Second) // Send a message from client 2 — should still work testMsg := "after disconnect" if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil { t.Fatalf("Failed to send message: %v", err) } // Client 2 should receive its own message back conn2.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn2.ReadMessage() if err != nil { t.Fatalf("Failed to read message: %v", err) } if string(msg) != testMsg { t.Errorf("Expected %q, got %q", testMsg, string(msg)) } } func TestIntegration_MultipleMessages(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn1 := dialWS(t, server) defer conn1.Close() conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) messages := []string{"first", "second", "third"} for _, msg := range messages { if err := conn1.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { t.Fatalf("Failed to send %q: %v", msg, err) } } // Client 2 should receive all messages in order for _, expected := range messages { conn2.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn2.ReadMessage() if err != nil { t.Fatalf("Failed to read message: %v", err) } if string(msg) != expected { t.Errorf("Expected %q, got %q", expected, string(msg)) } } } func TestIntegration_ConcurrentSenders(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() const numClients = 5 conns := make([]*websocket.Conn, numClients) for i := 0; i < numClients; i++ { conns[i] = dialWS(t, server) defer conns[i].Close() } waitForClients(t, h, numClients, time.Second) // Each client sends one message concurrently var wg sync.WaitGroup for i := 0; i < numClients; i++ { wg.Add(1) go func(idx int) { defer wg.Done() msg := []byte(strings.Repeat("x", idx+1)) // unique length per sender conns[idx].WriteMessage(websocket.TextMessage, msg) }(i) } wg.Wait() // Each client should receive exactly numClients messages (one from each sender) for i, conn := range conns { received := 0 conn.SetReadDeadline(time.Now().Add(2 * time.Second)) for received < numClients { _, _, err := conn.ReadMessage() if err != nil { t.Fatalf("Client %d: read error after %d messages: %v", i, received, err) } received++ } } } func TestIntegration_GracefulShutdownClosesClients(t *testing.T) { server, h := setupTestServer(t) defer server.Close() conn := dialWS(t, server) defer conn.Close() waitForClients(t, h, 1, time.Second) // Trigger shutdown h.Shutdown() // Client should receive a close frame conn.SetReadDeadline(time.Now().Add(time.Second)) _, _, err := conn.ReadMessage() if err == nil { t.Fatal("Expected error after shutdown, got nil") } // Verify it's a close error with GoingAway code if closeErr, ok := err.(*websocket.CloseError); ok { if closeErr.Code != websocket.CloseGoingAway { t.Errorf("Expected CloseGoingAway (%d), got %d", websocket.CloseGoingAway, closeErr.Code) } } // Any error is acceptable — the key is the connection is no longer usable } func TestIntegration_EmptyMessage(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn1 := dialWS(t, server) defer conn1.Close() conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) // Send an empty message if err := conn1.WriteMessage(websocket.TextMessage, []byte("")); err != nil { t.Fatalf("Failed to send empty message: %v", err) } // Client 2 should receive the empty message conn2.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn2.ReadMessage() if err != nil { t.Fatalf("Failed to read message: %v", err) } if string(msg) != "" { t.Errorf("Expected empty message, got %q", string(msg)) } } func TestIntegration_LargeMessage(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() conn1 := dialWS(t, server) defer conn1.Close() conn2 := dialWS(t, server) defer conn2.Close() waitForClients(t, h, 2, time.Second) // Send a 64KB message largeMsg := strings.Repeat("A", 64*1024) if err := conn1.WriteMessage(websocket.TextMessage, []byte(largeMsg)); err != nil { t.Fatalf("Failed to send large message: %v", err) } conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) _, msg, err := conn2.ReadMessage() if err != nil { t.Fatalf("Failed to read large message: %v", err) } if len(msg) != 64*1024 { t.Errorf("Expected message length %d, got %d", 64*1024, len(msg)) } } func TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() // Connect 2 clients to room-a, 1 client to room-b connA1 := dialWSWithRoom(t, server, "room-a") defer connA1.Close() connA2 := dialWSWithRoom(t, server, "room-a") defer connA2.Close() connB1 := dialWSWithRoom(t, server, "room-b") defer connB1.Close() waitForClients(t, h, 3, time.Second) // Send a message from client A1 testMsg := "hello room-a" if err := connA1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil { t.Fatalf("Failed to send message: %v", err) } // Both room-a clients should receive the message for i, conn := range []*websocket.Conn{connA1, connA2} { conn.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := conn.ReadMessage() if err != nil { t.Fatalf("Room-a client %d failed to read message: %v", i+1, err) } if string(msg) != testMsg { t.Errorf("Room-a client %d expected %q, got %q", i+1, testMsg, string(msg)) } } // Room-b client should NOT receive the message connB1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) _, _, err := connB1.ReadMessage() if err == nil { t.Fatal("Room-b client should NOT have received a message from room-a") } // Timeout error is expected — message was correctly not delivered } func TestIntegration_RoomIsolation_MultipleRoomsIndependent(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() // Connect clients to two different rooms connA := dialWSWithRoom(t, server, "alpha") defer connA.Close() connB := dialWSWithRoom(t, server, "beta") defer connB.Close() waitForClients(t, h, 2, time.Second) // Send message from room alpha msgAlpha := "alpha message" if err := connA.WriteMessage(websocket.TextMessage, []byte(msgAlpha)); err != nil { t.Fatalf("Failed to send alpha message: %v", err) } // Room alpha client receives its own message connA.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := connA.ReadMessage() if err != nil { t.Fatalf("Alpha client failed to read: %v", err) } if string(msg) != msgAlpha { t.Errorf("Alpha client expected %q, got %q", msgAlpha, string(msg)) } // Send message from room beta msgBeta := "beta message" if err := connB.WriteMessage(websocket.TextMessage, []byte(msgBeta)); err != nil { t.Fatalf("Failed to send beta message: %v", err) } // Room beta client receives its own message connB.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err = connB.ReadMessage() if err != nil { t.Fatalf("Beta client failed to read: %v", err) } if string(msg) != msgBeta { t.Errorf("Beta client expected %q, got %q", msgBeta, string(msg)) } // Verify alpha did NOT receive beta's message connA.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) _, _, err = connA.ReadMessage() if err == nil { t.Fatal("Alpha client should NOT have received beta's message") } // Verify beta did NOT receive alpha's message connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) _, _, err = connB.ReadMessage() if err == nil { t.Fatal("Beta client should NOT have received alpha's message") } } func TestIntegration_BroadcastToEmptyRoom(t *testing.T) { server, h := setupTestServer(t) defer server.Close() defer h.Shutdown() // Connect a client to room-a and a separate client to verify hub is functional connA := dialWSWithRoom(t, server, "room-a") defer connA.Close() waitForClients(t, h, 1, time.Second) // Send directly to broadcast channel targeting a non-existent room. // This should be handled gracefully (no panic, no delivery). h.broadcast <- broadcastMsg{room: "non-existent", data: []byte("ghost message")} // Give the hub time to process time.Sleep(50 * time.Millisecond) // Now send a real message to room-a to confirm hub is still functional h.broadcast <- broadcastMsg{room: "room-a", data: []byte("real message")} connA.SetReadDeadline(time.Now().Add(time.Second)) _, msg, err := connA.ReadMessage() if err != nil { t.Fatalf("Room-a client failed to read real message after empty-room broadcast: %v", err) } if string(msg) != "real message" { t.Errorf("Expected %q, got %q", "real message", string(msg)) } }