package hub

import (
	"bytes"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"websocket-relay/internal/logging"
)

// hubSettleTimeout bounds how long a test waits for the hub to reflect a
// register or unregister before its assertions run.
const hubSettleTimeout = time.Second

// newTestLogger returns a logger created with the "debug" setting whose
// output goes to a throwaway in-memory buffer, keeping test output quiet.
func newTestLogger() *logging.Logger {
	return logging.NewLogger("debug", &bytes.Buffer{})
}

// waitForCondition polls cond until it returns true or hubSettleTimeout
// elapses, and reports whether cond was satisfied. Tests use it instead of a
// fixed sleep so they finish as soon as the hub catches up and tolerate slow
// machines; callers still run their own assertions afterwards so failure
// messages stay specific.
func waitForCondition(cond func() bool) bool {
	deadline := time.Now().Add(hubSettleTimeout)
	for !cond() {
		if time.Now().After(deadline) {
			return false
		}
		time.Sleep(time.Millisecond)
	}
	return true
}

// dialTestHub starts an httptest server whose handler is h.HandleWebSocket
// and dials a WebSocket connection to it with the given room query parameter.
// It returns the client-side connection; the caller is responsible for
// closing it. The server itself is shut down automatically via t.Cleanup.
// The room value is appended to the URL verbatim, so it must already be safe
// to use in a query string.
func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn {
	t.Helper()

	srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
	t.Cleanup(srv.Close)

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "?room=" + room
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("failed to dial WebSocket: %v", err)
	}
	return conn
}

// TestNew verifies that New returns a non-nil hub with its maps and channels
// initialized.
func TestNew(t *testing.T) {
	h := New(newTestLogger())
	if h == nil {
		t.Fatal("New returned nil")
	}
	if h.rooms == nil {
		t.Error("rooms map not initialized")
	}
	if h.connRoom == nil {
		t.Error("connRoom map not initialized")
	}
	if h.broadcast == nil {
		t.Error("broadcast channel not initialized")
	}
	if h.stop == nil {
		t.Error("stop channel not initialized")
	}
}

// TestClientCount verifies that a freshly started hub reports zero clients.
func TestClientCount(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	if count := h.ClientCount(); count != 0 {
		t.Errorf("Expected 0 clients, got %d", count)
	}
}

// TestRoomCount verifies that a freshly started hub reports zero rooms.
func TestRoomCount(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	if count := h.RoomCount(); count != 0 {
		t.Errorf("Expected 0 rooms, got %d", count)
	}
}

// TestBroadcastChannel verifies that a running hub accepts a message on its
// broadcast channel within 100ms.
func TestBroadcastChannel(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	select {
	case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
		// Channel is working
	case <-time.After(100 * time.Millisecond):
		t.Error("broadcast channel blocked")
	}
}

// TestShutdown verifies that Run returns within one second of Shutdown being
// called.
func TestShutdown(t *testing.T) {
	h := New(newTestLogger())

	done := make(chan struct{})
	go func() {
		h.Run()
		close(done)
	}()

	// Ensure Run is processing before shutdown. There is no observable state
	// to poll for here, so a short fixed pause is kept.
	time.Sleep(10 * time.Millisecond)
	h.Shutdown()

	select {
	case <-done:
		// Hub.Run() returned successfully
	case <-time.After(1 * time.Second):
		t.Fatal("Hub.Run() did not return after Shutdown")
	}
}

// TestRegisterClient verifies that a single dialed connection shows up as one
// client in one room.
func TestRegisterClient(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	conn := dialTestHub(t, h, "test-room")
	defer conn.Close()

	// Allow register to be processed
	waitForCondition(func() bool {
		return h.ClientCount() == 1 && h.RoomCount() == 1
	})

	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client, got %d", count)
	}
	if count := h.RoomCount(); count != 1 {
		t.Errorf("Expected 1 room, got %d", count)
	}
}

// TestUnregisterClient verifies that closing the client side of a connection
// brings both the client count and the room count back to zero.
func TestUnregisterClient(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	conn := dialTestHub(t, h, "test-room")

	// Allow register to be processed
	waitForCondition(func() bool { return h.ClientCount() == 1 })

	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client after register, got %d", count)
	}

	// Close the client-side connection to trigger unregister
	conn.Close()

	// Allow unregister to be processed
	waitForCondition(func() bool {
		return h.ClientCount() == 0 && h.RoomCount() == 0
	})

	if count := h.ClientCount(); count != 0 {
		t.Errorf("Expected 0 clients after unregister, got %d", count)
	}
	if count := h.RoomCount(); count != 0 {
		t.Errorf("Expected 0 rooms after last client leaves, got %d", count)
	}
}

// TestRegisterMultipleRooms verifies that three connections spread over two
// room names are counted as three clients in two rooms.
func TestRegisterMultipleRooms(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	conn1 := dialTestHub(t, h, "room-a")
	defer conn1.Close()
	conn2 := dialTestHub(t, h, "room-a")
	defer conn2.Close()
	conn3 := dialTestHub(t, h, "room-b")
	defer conn3.Close()

	// Allow all registers to be processed
	waitForCondition(func() bool {
		return h.ClientCount() == 3 && h.RoomCount() == 2
	})

	if count := h.ClientCount(); count != 3 {
		t.Errorf("Expected 3 clients, got %d", count)
	}
	if count := h.RoomCount(); count != 2 {
		t.Errorf("Expected 2 rooms, got %d", count)
	}
}

// TestUnregisterCleansUpEmptyRoom verifies that a room shared by two clients
// survives the first disconnect and is removed after the second.
func TestUnregisterCleansUpEmptyRoom(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	conn1 := dialTestHub(t, h, "shared-room")
	conn2 := dialTestHub(t, h, "shared-room")

	// Allow registers to be processed. Waiting for both clients (not just the
	// room) ensures a late registration cannot be mistaken for the state after
	// the first disconnect below.
	waitForCondition(func() bool { return h.ClientCount() == 2 })

	if count := h.RoomCount(); count != 1 {
		t.Errorf("Expected 1 room, got %d", count)
	}

	// Remove first client — room should still exist
	conn1.Close()
	waitForCondition(func() bool { return h.ClientCount() == 1 })

	if count := h.ClientCount(); count != 1 {
		t.Errorf("Expected 1 client after first disconnect, got %d", count)
	}
	if count := h.RoomCount(); count != 1 {
		t.Errorf("Expected room to still exist with 1 client, got %d rooms", count)
	}

	// Remove second client — room should be cleaned up
	conn2.Close()
	waitForCondition(func() bool {
		return h.ClientCount() == 0 && h.RoomCount() == 0
	})

	if count := h.ClientCount(); count != 0 {
		t.Errorf("Expected 0 clients after both disconnect, got %d", count)
	}
	if count := h.RoomCount(); count != 0 {
		t.Errorf("Expected room to be cleaned up, got %d rooms", count)
	}
}

// TestUnregisterUnknownConnNoPanic verifies that sending a connection the hub
// never registered to the unregister channel does not panic and leaves the
// client count at zero.
func TestUnregisterUnknownConnNoPanic(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	// Create a raw WebSocket connection that is NOT registered with the hub
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		// Send directly to unregister without ever registering
		h.unregister <- conn
	}))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()

	// Allow unregister to be processed — should not panic. The expected count
	// is zero both before and after, so there is nothing to poll for and a
	// fixed pause is kept.
	time.Sleep(50 * time.Millisecond)

	if count := h.ClientCount(); count != 0 {
		t.Errorf("Expected 0 clients, got %d", count)
	}
}

// TestBroadcastRoomIsolation verifies that a message broadcast to one room is
// delivered to that room's client and not to a client in a different room.
func TestBroadcastRoomIsolation(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	connA := dialTestHub(t, h, "room-a")
	defer connA.Close()
	connB := dialTestHub(t, h, "room-b")
	defer connB.Close()

	// Allow registers to be processed; broadcasting before both clients are
	// registered would make the result meaningless.
	if !waitForCondition(func() bool { return h.ClientCount() == 2 }) {
		t.Fatalf("Expected 2 registered clients, got %d", h.ClientCount())
	}

	// Send message to room-a via broadcast channel
	h.broadcast <- broadcastMsg{room: "room-a", data: []byte("for-a-only")}

	// Room-a client should receive the message
	if err := connA.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("Room-a client failed to set read deadline: %v", err)
	}
	_, msg, err := connA.ReadMessage()
	if err != nil {
		t.Fatalf("Room-a client failed to read: %v", err)
	}
	if string(msg) != "for-a-only" {
		t.Errorf("Expected %q, got %q", "for-a-only", string(msg))
	}

	// Room-b client should NOT receive the message (short timeout)
	if err := connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
		t.Fatalf("Room-b client failed to set read deadline: %v", err)
	}
	_, _, err = connB.ReadMessage()
	if err == nil {
		t.Fatal("Room-b client should NOT have received room-a's message")
	}
}