package hub

import (
	"bytes"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"websocket-relay/internal/logging"
)

// setupTestServer starts a Hub (running h.Run in its own goroutine) and an
// httptest server that routes every request to h.HandleWebSocket. Log output
// is written to an in-memory buffer that is never inspected. The caller is
// responsible for calling server.Close and h.Shutdown.
func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
	t.Helper()
	logger := logging.NewLogger("debug", &bytes.Buffer{})
	h := New(logger)
	go h.Run()
	server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
	return server, h
}

// dialWS opens a WebSocket client connection to server by rewriting its
// http:// URL to ws://. It fails the test immediately if the handshake fails.
// The caller owns the returned connection and must close it.
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
	t.Helper()
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("Failed to dial WebSocket: %v", err)
	}
	return conn
}

// waitForClients polls h.ClientCount every 5ms until it equals expected,
// failing the test if that does not happen within timeout. Registration and
// unregistration happen asynchronously, so tests must wait rather than assert
// the count immediately after dialing or closing.
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if h.ClientCount() == expected {
			return
		}
		time.Sleep(5 * time.Millisecond)
	}
	t.Fatalf("Timed out waiting for %d clients, got %d", expected, h.ClientCount())
}

// TestIntegration_SingleClientConnect checks that one dialed client is
// registered with the hub.
func TestIntegration_SingleClientConnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn := dialWS(t, server)
	defer conn.Close()

	waitForClients(t, h, 1, time.Second)

	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client, got %d", count)
	}
}

// TestIntegration_MultipleClientsConnect checks that several simultaneously
// dialed clients are all registered.
func TestIntegration_MultipleClientsConnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	const numClients = 5
	conns := make([]*websocket.Conn, numClients)
	for i := 0; i < numClients; i++ {
		conns[i] = dialWS(t, server)
		// The receiver conns[i] is evaluated now, so each defer closes its own conn.
		defer conns[i].Close()
	}

	waitForClients(t, h, numClients, time.Second)

	if count := h.ClientCount(); count != numClients {
		t.Errorf("Expected %d clients, got %d", numClients, count)
	}
}

// TestIntegration_BroadcastMessage checks that a message sent by one client
// is delivered to every connected client, including the sender.
func TestIntegration_BroadcastMessage(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	testMsg := "hello from client 1"
	if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	for i, conn := range []*websocket.Conn{conn1, conn2} {
		conn.SetReadDeadline(time.Now().Add(time.Second))
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("Client %d failed to read message: %v", i+1, err)
		}
		if string(msg) != testMsg {
			t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
		}
	}
}

// TestIntegration_BroadcastToManyClients checks broadcast fan-out to ten
// clients.
func TestIntegration_BroadcastToManyClients(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	const numClients = 10
	conns := make([]*websocket.Conn, numClients)
	for i := 0; i < numClients; i++ {
		conns[i] = dialWS(t, server)
		defer conns[i].Close()
	}

	waitForClients(t, h, numClients, time.Second)

	testMsg := "broadcast to all"
	if err := conns[0].WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	for i, conn := range conns {
		conn.SetReadDeadline(time.Now().Add(time.Second))
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("Client %d failed to read: %v", i, err)
		}
		if string(msg) != testMsg {
			t.Errorf("Client %d expected %q, got %q", i, testMsg, string(msg))
		}
	}
}

// TestIntegration_ClientDisconnect checks that closing a client connection
// removes it from the hub.
func TestIntegration_ClientDisconnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	// Guarantees cleanup if the test fails before the explicit Close below;
	// the second Close only returns an error, which is ignored here.
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	conn1.Close()

	waitForClients(t, h, 1, time.Second)

	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client after disconnect, got %d", count)
	}
}

// TestIntegration_MessageAfterDisconnect checks that broadcasting keeps
// working for the remaining client after another client disconnects.
func TestIntegration_MessageAfterDisconnect(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	// Guarantees cleanup if the test fails before the explicit Close below.
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	conn1.Close()
	waitForClients(t, h, 1, time.Second)

	testMsg := "after disconnect"
	if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// The sender is also a broadcast recipient, so it gets its own message back.
	conn2.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err := conn2.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read message: %v", err)
	}
	if string(msg) != testMsg {
		t.Errorf("Expected %q, got %q", testMsg, string(msg))
	}
}

// TestIntegration_MultipleMessages checks that consecutive messages from one
// sender arrive at another client in send order.
func TestIntegration_MultipleMessages(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	messages := []string{"first", "second", "third"}
	for _, msg := range messages {
		if err := conn1.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
			t.Fatalf("Failed to send %q: %v", msg, err)
		}
	}

	for _, expected := range messages {
		conn2.SetReadDeadline(time.Now().Add(time.Second))
		_, msg, err := conn2.ReadMessage()
		if err != nil {
			t.Fatalf("Failed to read message: %v", err)
		}
		if string(msg) != expected {
			t.Errorf("Expected %q, got %q", expected, string(msg))
		}
	}
}

// TestIntegration_ConcurrentSenders has every client send one message at the
// same time. It checks that each client receives exactly one message from
// every sender. Senders are told apart by payload length: sender i sends
// i+1 'x' bytes.
func TestIntegration_ConcurrentSenders(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	const numClients = 5
	conns := make([]*websocket.Conn, numClients)
	for i := 0; i < numClients; i++ {
		conns[i] = dialWS(t, server)
		defer conns[i].Close()
	}

	waitForClients(t, h, numClients, time.Second)

	var wg sync.WaitGroup
	for i := 0; i < numClients; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			msg := []byte(strings.Repeat("x", idx+1))
			// Errorf (not Fatalf) because Fatal must not be called off the test goroutine.
			if err := conns[idx].WriteMessage(websocket.TextMessage, msg); err != nil {
				t.Errorf("Client %d failed to send message: %v", idx, err)
			}
		}(i)
	}
	wg.Wait()
	// Stop here rather than waiting out read deadlines for messages never sent.
	if t.Failed() {
		t.FailNow()
	}

	for i, conn := range conns {
		conn.SetReadDeadline(time.Now().Add(2 * time.Second))
		perSender := make(map[int]int, numClients) // payload length -> count
		for received := 0; received < numClients; received++ {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				t.Fatalf("Client %d: read error after %d messages: %v", i, received, err)
			}
			if string(msg) != strings.Repeat("x", len(msg)) {
				t.Errorf("Client %d: unexpected payload %q", i, msg)
			}
			perSender[len(msg)]++
		}
		for sender := 0; sender < numClients; sender++ {
			if got := perSender[sender+1]; got != 1 {
				t.Errorf("Client %d: expected 1 message from sender %d, got %d", i, sender, got)
			}
		}
	}
}

// TestIntegration_GracefulShutdownClosesClients checks that h.Shutdown makes
// a connected client's next read fail. If that failure is a WebSocket close
// error, its code must be CloseGoingAway.
func TestIntegration_GracefulShutdownClosesClients(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()

	conn := dialWS(t, server)
	defer conn.Close()

	waitForClients(t, h, 1, time.Second)

	h.Shutdown()

	conn.SetReadDeadline(time.Now().Add(time.Second))
	_, _, err := conn.ReadMessage()
	if err == nil {
		t.Fatal("Expected error after shutdown, got nil")
	}

	// Any error is acceptable (the connection is no longer usable), but a
	// close frame, if one was received, must carry the GoingAway code.
	var closeErr *websocket.CloseError
	if errors.As(err, &closeErr) && closeErr.Code != websocket.CloseGoingAway {
		t.Errorf("Expected CloseGoingAway (%d), got %d", websocket.CloseGoingAway, closeErr.Code)
	}
}

// TestIntegration_EmptyMessage checks that a zero-length text message is
// relayed.
func TestIntegration_EmptyMessage(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	if err := conn1.WriteMessage(websocket.TextMessage, []byte("")); err != nil {
		t.Fatalf("Failed to send empty message: %v", err)
	}

	conn2.SetReadDeadline(time.Now().Add(time.Second))
	_, msg, err := conn2.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read message: %v", err)
	}
	if string(msg) != "" {
		t.Errorf("Expected empty message, got %q", string(msg))
	}
}

// TestIntegration_LargeMessage checks that a 64 KiB message is relayed intact,
// verifying both its length and its content.
func TestIntegration_LargeMessage(t *testing.T) {
	server, h := setupTestServer(t)
	defer server.Close()
	defer h.Shutdown()

	conn1 := dialWS(t, server)
	defer conn1.Close()
	conn2 := dialWS(t, server)
	defer conn2.Close()

	waitForClients(t, h, 2, time.Second)

	const size = 64 * 1024
	largeMsg := strings.Repeat("A", size)
	if err := conn1.WriteMessage(websocket.TextMessage, []byte(largeMsg)); err != nil {
		t.Fatalf("Failed to send large message: %v", err)
	}

	conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, msg, err := conn2.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read large message: %v", err)
	}
	if len(msg) != size {
		t.Fatalf("Expected message length %d, got %d", size, len(msg))
	}
	if string(msg) != largeMsg {
		t.Error("Large message payload was altered in transit")
	}
}