diff --git a/internal/hub/hub.go b/internal/hub/hub.go index d29accd..7560696 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -77,10 +77,11 @@ func (h *Hub) Run() { } h.rooms[c.room][c.conn] = true h.connRoom[c.conn] = c.room + count := len(h.connRoom) h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.connRoom))) + metrics.ConnectedClients.Inc() metrics.ConnectionsTotal.Inc() - h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom)) + h.logger.Infof("Client connected (room=%q). Total: %d", c.room, count) case conn := <-h.unregister: h.mu.Lock() @@ -92,12 +93,16 @@ func (h *Hub) Run() { } } delete(h.connRoom, conn) + count := len(h.connRoom) + h.mu.Unlock() + conn.Close() + metrics.ConnectedClients.Dec() + metrics.DisconnectionsTotal.Inc() + h.logger.Infof("Client disconnected (room=%q). Total: %d", room, count) + } else { + h.mu.Unlock() conn.Close() } - h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.connRoom))) - metrics.DisconnectionsTotal.Inc() - h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom)) case message := <-h.broadcast: metrics.MessagesTotal.Inc() @@ -123,12 +128,15 @@ func (h *Hub) Run() { } } delete(h.connRoom, conn) + count := len(h.connRoom) + h.mu.Unlock() conn.Close() - metrics.ConnectedClients.Set(float64(len(h.connRoom))) + metrics.ConnectedClients.Dec() metrics.DisconnectionsTotal.Inc() - h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom)) + h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, count) + } else { + h.mu.Unlock() } - h.mu.Unlock() } } } diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index 5629a6f..dd4f6df 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -2,9 +2,13 @@ package hub import ( "bytes" + "net/http" + "net/http/httptest" + "strings" "testing" "time" + "github.com/gorilla/websocket" "websocket-relay/internal/logging" ) @@ -12,6 +16,22 @@ func newTestLogger() *logging.Logger { return logging.NewLogger("debug", &bytes.Buffer{}) } +// dialTestHub starts an httptest server for the given hub and dials a +// WebSocket connection to it with the given room query parameter. +// Returns the client-side connection and a cleanup function. +func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket)) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "?room=" + room + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial WebSocket: %v", err) + } + return conn +} + func TestNew(t *testing.T) { h := New(newTestLogger()) if h == nil { @@ -85,3 +105,143 @@ func TestShutdown(t *testing.T) { t.Fatal("Hub.Run() did not return after Shutdown") } } + +func TestRegisterClient(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + conn := dialTestHub(t, h, "test-room") + defer conn.Close() + + // Allow register to be processed + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 1 { + t.Errorf("Expected 1 client, got %d", count) + } + if count := h.RoomCount(); count != 1 { + t.Errorf("Expected 1 room, got %d", count) + } +} + +func TestUnregisterClient(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + conn := dialTestHub(t, h, "test-room") + + // Allow register to be processed + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 1 { + t.Errorf("Expected 1 client after register, got %d", count) + } + + // Close the client-side connection to trigger unregister + conn.Close() + + // Allow unregister to be processed + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 0 { + t.Errorf("Expected 0 clients after unregister, got %d", count) + } + if count := h.RoomCount(); count != 0 { + t.Errorf("Expected 0 rooms after last client leaves, got %d", count) + } +} + +func TestRegisterMultipleRooms(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + conn1 := dialTestHub(t, h, "room-a") + defer conn1.Close() + conn2 := dialTestHub(t, h, "room-a") + defer conn2.Close() + conn3 := dialTestHub(t, h, "room-b") + defer conn3.Close() + + // Allow all registers to be processed + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 3 { + t.Errorf("Expected 3 clients, got %d", count) + } + if count := h.RoomCount(); count != 2 { + t.Errorf("Expected 2 rooms, got %d", count) + } +} + +func TestUnregisterCleansUpEmptyRoom(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + conn1 := dialTestHub(t, h, "shared-room") + conn2 := dialTestHub(t, h, "shared-room") + + // Allow registers to be processed + time.Sleep(50 * time.Millisecond) + + if count := h.RoomCount(); count != 1 { + t.Errorf("Expected 1 room, got %d", count) + } + + // Remove first client — room should still exist + conn1.Close() + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 1 { + t.Errorf("Expected 1 client after first disconnect, got %d", count) + } + if count := h.RoomCount(); count != 1 { + t.Errorf("Expected room to still exist with 1 client, got %d rooms", count) + } + + // Remove second client — room should be cleaned up + conn2.Close() + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 0 { + t.Errorf("Expected 0 clients after both disconnect, got %d", count) + } + if count := h.RoomCount(); count != 0 { + t.Errorf("Expected room to be cleaned up, got %d rooms", count) + } +} + +func TestUnregisterUnknownConnNoPanic(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + // Create a raw WebSocket connection that is NOT registered with the hub + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + // Send directly to unregister without ever registering + h.unregister <- conn + })) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer conn.Close() + + // Allow unregister to be processed — should not panic + time.Sleep(50 * time.Millisecond) + + if count := h.ClientCount(); count != 0 { + t.Errorf("Expected 0 clients, got %d", count) + } +}