diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 0bf8f5c..d29accd 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } +// client represents a WebSocket connection associated with a room. +type client struct { + conn *websocket.Conn + room string +} + +// broadcastMsg represents a message to be broadcast to all clients in a room. +type broadcastMsg struct { + room string + data []byte +} + type Hub struct { - clients map[*websocket.Conn]bool - broadcast chan []byte - register chan *websocket.Conn + rooms map[string]map[*websocket.Conn]bool + connRoom map[*websocket.Conn]string + broadcast chan broadcastMsg + register chan client unregister chan *websocket.Conn stop chan struct{} mu sync.RWMutex @@ -25,9 +38,10 @@ type Hub struct { func New(logger *logging.Logger) *Hub { return &Hub{ - clients: make(map[*websocket.Conn]bool), - broadcast: make(chan []byte), - register: make(chan *websocket.Conn), + rooms: make(map[string]map[*websocket.Conn]bool), + connRoom: make(map[*websocket.Conn]string), + broadcast: make(chan broadcastMsg), + register: make(chan client), unregister: make(chan *websocket.Conn), stop: make(chan struct{}), logger: logger, @@ -39,43 +53,61 @@ func (h *Hub) Run() { select { case <-h.stop: h.mu.Lock() - for conn := range h.clients { + for conn, room := range h.connRoom { conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")) conn.Close() - delete(h.clients, conn) + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) } h.mu.Unlock() metrics.ConnectedClients.Set(0) h.logger.Info("Hub stopped, all clients disconnected") return - case conn := <-h.register: + case c := <-h.register: h.mu.Lock() - h.clients[conn] = true + if h.rooms[c.room] == nil { + h.rooms[c.room] = make(map[*websocket.Conn]bool) + } + h.rooms[c.room][c.conn] = true + h.connRoom[c.conn] = c.room h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.ConnectionsTotal.Inc() - h.logger.Infof("Client connected. Total: %d", len(h.clients)) + h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom)) case conn := <-h.unregister: h.mu.Lock() - if _, ok := h.clients[conn]; ok { - delete(h.clients, conn) + if room, ok := h.connRoom[conn]; ok { + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) conn.Close() } h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.DisconnectionsTotal.Inc() - h.logger.Infof("Client disconnected. Total: %d", len(h.clients)) + h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom)) case message := <-h.broadcast: metrics.MessagesTotal.Inc() h.mu.RLock() var failed []*websocket.Conn - for conn := range h.clients { - if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { - failed = append(failed, conn) + if clients, ok := h.rooms[message.room]; ok { + for conn := range clients { + if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil { + failed = append(failed, conn) + } } } h.mu.RUnlock() @@ -83,12 +115,18 @@ func (h *Hub) Run() { // Remove failed clients properly so metrics stay consistent for _, conn := range failed { h.mu.Lock() - if _, ok := h.clients[conn]; ok { - delete(h.clients, conn) + if room, ok := h.connRoom[conn]; ok { + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) conn.Close() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.DisconnectionsTotal.Inc() - h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients)) + h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom)) } h.mu.Unlock() } @@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { return } - h.register <- conn + room := r.URL.Query().Get("room") + + h.register <- client{conn: conn, room: room} go func() { defer func() { @@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { if err != nil { break } - h.broadcast <- message + h.broadcast <- broadcastMsg{room: room, data: message} } }() } @@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() - return len(h.clients) + return len(h.connRoom) +} + +// RoomCount returns the number of active rooms. +func (h *Hub) RoomCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.rooms) } diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index c40a705..5629a6f 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -17,8 +17,11 @@ func TestNew(t *testing.T) { if h == nil { t.Fatal("New returned nil") } - if h.clients == nil { - t.Error("clients map not initialized") + if h.rooms == nil { + t.Error("rooms map not initialized") + } + if h.connRoom == nil { + t.Error("connRoom map not initialized") } if h.broadcast == nil { t.Error("broadcast channel not initialized") @@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) { } } +func TestRoomCount(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + if count := h.RoomCount(); count != 0 { + t.Errorf("Expected 0 rooms, got %d", count) + } +} + func TestBroadcastChannel(t *testing.T) { h := New(newTestLogger()) go h.Run() defer h.Shutdown() select { - case h.broadcast <- []byte("test"): + case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}: // Channel is working case <-time.After(100 * time.Millisecond): t.Error("broadcast channel blocked")