savinmax 5bd08409dc feat(hub): extract room from URL path instead of query parameter
Change HandleWebSocket to use r.URL.Path as the room identifier instead
of r.URL.Query().Get("room"). This enables clean URL-based room routing
(e.g., ws://host/room-a) without query strings.

Update test helpers (dialTestHub, dialWSWithRoom) to connect via path
segments and fix direct broadcast channel tests to use path-style room
names (with leading slash).

All existing tests pass — clients connecting to / get the default room.

🤖 Assisted by the code-assist SOP
2026-06-13 13:22:33 +02:00

283 lines
6.8 KiB
Go

package hub
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"websocket-relay/internal/logging"
)
// newTestLogger returns a logger for use in tests. It is constructed with the
// "debug" setting and writes into a fresh in-memory buffer that nobody reads,
// so log output produced during a test is simply discarded.
func newTestLogger() *logging.Logger {
	sink := new(bytes.Buffer)
	return logging.NewLogger("debug", sink)
}
// dialTestHub starts an httptest server whose handler is h.HandleWebSocket and
// dials a WebSocket connection to it at the path "/" + room. Pass room without
// a leading slash (e.g. "room-a"); an empty room dials "/".
//
// It returns only the client-side connection. The server is shut down through
// t.Cleanup, and the connection is closed there as well so it is not leaked
// when a test aborts before reaching its own Close; callers may still close
// the connection earlier. A failed dial aborts the calling test via t.Fatalf.
func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn {
	t.Helper()

	srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
	t.Cleanup(srv.Close)

	// srv.URL uses the http scheme; replacing that prefix yields a ws:// URL.
	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/" + room
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("failed to dial WebSocket: %v", err)
	}

	// Cleanups run last-in-first-out, so the client is closed before the
	// server. Closing a connection the caller already closed only returns an
	// error, which is deliberately ignored here.
	t.Cleanup(func() { _ = conn.Close() })
	return conn
}
// TestNew verifies that New returns a non-nil hub whose rooms, connRoom,
// broadcast and stop fields are all non-nil.
func TestNew(t *testing.T) {
	h := New(newTestLogger())
	if h == nil {
		t.Fatal("New returned nil")
	}

	// Each entry pairs a "field is nil" result with the error to report.
	checks := []struct {
		missing bool
		msg     string
	}{
		{h.rooms == nil, "rooms map not initialized"},
		{h.connRoom == nil, "connRoom map not initialized"},
		{h.broadcast == nil, "broadcast channel not initialized"},
		{h.stop == nil, "stop channel not initialized"},
	}
	for _, c := range checks {
		if c.missing {
			t.Error(c.msg)
		}
	}
}
// TestClientCount checks that a running hub with no connections reports
// zero clients.
func TestClientCount(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	got := h.ClientCount()
	if got != 0 {
		t.Errorf("Expected 0 clients, got %d", got)
	}
}
// TestRoomCount checks that a running hub with no connections reports
// zero rooms.
func TestRoomCount(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	got := h.RoomCount()
	if got != 0 {
		t.Errorf("Expected 0 rooms, got %d", got)
	}
}
// TestBroadcastChannel checks that a running hub accepts a message on its
// broadcast channel within 100ms, even when no client is connected.
func TestBroadcastChannel(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	msg := broadcastMsg{room: "", data: []byte("test")}
	timeout := time.After(100 * time.Millisecond)

	select {
	case <-timeout:
		t.Error("broadcast channel blocked")
	case h.broadcast <- msg:
		// The hub took the message, so the channel is being drained.
	}
}
// TestShutdown checks that Run returns within one second of Shutdown being
// called.
func TestShutdown(t *testing.T) {
	h := New(newTestLogger())

	// finished is closed once Run has returned.
	finished := make(chan struct{})
	go func() {
		h.Run()
		close(finished)
	}()

	// Give the goroutine a moment to enter Run before asking it to stop.
	time.Sleep(10 * time.Millisecond)
	h.Shutdown()

	deadline := time.After(1 * time.Second)
	select {
	case <-deadline:
		t.Fatal("Hub.Run() did not return after Shutdown")
	case <-finished:
		// Run returned as expected.
	}
}
// TestRegisterClient connects one client to "test-room" and expects the hub
// to report exactly one client and one room.
func TestRegisterClient(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	client := dialTestHub(t, h, "test-room")
	defer client.Close()

	// Registration is processed asynchronously; wait for it to land.
	time.Sleep(50 * time.Millisecond)

	clients := h.ClientCount()
	if clients != 1 {
		t.Errorf("Expected 1 client, got %d", clients)
	}
	rooms := h.RoomCount()
	if rooms != 1 {
		t.Errorf("Expected 1 room, got %d", rooms)
	}
}
// TestUnregisterClient connects one client, closes it from the client side,
// and expects the hub to end up with zero clients and zero rooms.
func TestUnregisterClient(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	// settle gives the hub time to process a pending register/unregister.
	settle := func() { time.Sleep(50 * time.Millisecond) }

	client := dialTestHub(t, h, "test-room")
	settle()
	if got := h.ClientCount(); got != 1 {
		t.Errorf("Expected 1 client after register, got %d", got)
	}

	// Closing the client side is what triggers the unregister.
	client.Close()
	settle()

	if got := h.ClientCount(); got != 0 {
		t.Errorf("Expected 0 clients after unregister, got %d", got)
	}
	if got := h.RoomCount(); got != 0 {
		t.Errorf("Expected 0 rooms after last client leaves, got %d", got)
	}
}
// TestRegisterMultipleRooms connects two clients to "room-a" and one to
// "room-b", then expects three clients spread over two rooms.
func TestRegisterMultipleRooms(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	for _, room := range []string{"room-a", "room-a", "room-b"} {
		conn := dialTestHub(t, h, room)
		// Deferred on purpose: every connection must stay open until the
		// test function returns.
		defer conn.Close()
	}

	// Wait for all three registrations to be processed.
	time.Sleep(50 * time.Millisecond)

	clients, rooms := h.ClientCount(), h.RoomCount()
	if clients != 3 {
		t.Errorf("Expected 3 clients, got %d", clients)
	}
	if rooms != 2 {
		t.Errorf("Expected 2 rooms, got %d", rooms)
	}
}
// TestUnregisterCleansUpEmptyRoom puts two clients in "shared-room" and
// disconnects them one at a time: the room must still be counted while one
// client is left and must no longer be counted once the last one is gone.
func TestUnregisterCleansUpEmptyRoom(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	// settle gives the hub time to process pending register/unregister work.
	settle := func() { time.Sleep(50 * time.Millisecond) }

	first := dialTestHub(t, h, "shared-room")
	second := dialTestHub(t, h, "shared-room")
	settle()
	if rooms := h.RoomCount(); rooms != 1 {
		t.Errorf("Expected 1 room, got %d", rooms)
	}

	// One client leaves: the room is still occupied.
	first.Close()
	settle()
	if clients := h.ClientCount(); clients != 1 {
		t.Errorf("Expected 1 client after first disconnect, got %d", clients)
	}
	if rooms := h.RoomCount(); rooms != 1 {
		t.Errorf("Expected room to still exist with 1 client, got %d rooms", rooms)
	}

	// The last client leaves: the room should be removed.
	second.Close()
	settle()
	if clients := h.ClientCount(); clients != 0 {
		t.Errorf("Expected 0 clients after both disconnect, got %d", clients)
	}
	if rooms := h.RoomCount(); rooms != 0 {
		t.Errorf("Expected room to be cleaned up, got %d rooms", rooms)
	}
}
// TestUnregisterUnknownConnNoPanic sends a connection that was never
// registered straight to the hub's unregister channel and expects the hub to
// keep running with a client count of zero.
func TestUnregisterUnknownConnNoPanic(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	// This handler bypasses HandleWebSocket: it upgrades the request itself
	// and hands the server-side connection directly to unregister.
	handler := func(w http.ResponseWriter, r *http.Request) {
		upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
		serverConn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		h.unregister <- serverConn
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	client, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	defer client.Close()

	// Leave time for the hub to handle the unregister; it must not panic.
	time.Sleep(50 * time.Millisecond)

	if got := h.ClientCount(); got != 0 {
		t.Errorf("Expected 0 clients, got %d", got)
	}
}
// TestBroadcastRoomIsolation connects one client to "room-a" and one to
// "room-b", broadcasts a message addressed to "/room-a", and checks that only
// the room-a client receives it.
//
// The room-b side is verified by requiring its read to fail with a timeout.
// Any other failure (for example a dropped connection) would not show that the
// message was withheld, so it is reported as a test failure.
func TestBroadcastRoomIsolation(t *testing.T) {
	h := New(newTestLogger())
	go h.Run()
	defer h.Shutdown()

	connA := dialTestHub(t, h, "room-a")
	defer connA.Close()
	connB := dialTestHub(t, h, "room-b")
	defer connB.Close()

	// Allow registers to be processed.
	time.Sleep(50 * time.Millisecond)

	// The room key includes the leading slash of the URL path. The send is
	// guarded by a timeout so a stuck hub fails the test instead of hanging it.
	select {
	case h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("for-a-only")}:
	case <-time.After(time.Second):
		t.Fatal("broadcast channel blocked")
	}

	// Room-a client should receive the message.
	if err := connA.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("failed to set room-a read deadline: %v", err)
	}
	_, msg, err := connA.ReadMessage()
	if err != nil {
		t.Fatalf("Room-a client failed to read: %v", err)
	}
	if got := string(msg); got != "for-a-only" {
		t.Errorf("Expected %q, got %q", "for-a-only", got)
	}

	// Room-b client should NOT receive the message: with a short deadline the
	// read is expected to time out.
	if err := connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
		t.Fatalf("failed to set room-b read deadline: %v", err)
	}
	_, _, err = connB.ReadMessage()
	if err == nil {
		t.Fatal("Room-b client should NOT have received room-a's message")
	}
	// Network timeouts expose Timeout() bool (see net.Error); asserting on
	// that method avoids needing any extra import.
	if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
		t.Fatalf("Room-b read should have timed out, got: %v", err)
	}
}