From 03f379c73c9499c7d4f8cd5ad2108610889255bf Mon Sep 17 00:00:00 2001 From: savinmax Date: Sat, 13 Jun 2026 13:09:25 +0200 Subject: [PATCH] refactor(hub): introduce room types and update Hub struct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add client struct with conn and room fields - Add broadcastMsg struct with room and data fields - Change Hub.clients to Hub.rooms map[string]map[*websocket.Conn]bool - Add Hub.connRoom map[*websocket.Conn]string for reverse lookup - Change broadcast channel type to chan broadcastMsg - Change register channel type to chan client - Update New() to initialize rooms and connRoom maps - Update ClientCount() to use len(h.connRoom) - Add RoomCount() method - Update Run() loop for room-segmented register/unregister/broadcast - Update HandleWebSocket to extract room from query param - Backward compatible: clients without ?room use default empty room - Update TestNew to verify rooms and connRoom maps initialized - Add TestRoomCount to verify initial room count is 0 - Fix TestBroadcastChannel to use broadcastMsg type All existing unit and integration tests pass (16 hub tests + 21 other). 🤖 Assisted by the code-assist SOP --- internal/hub/hub.go | 99 +++++++++++++++++++++++++++++----------- internal/hub/hub_test.go | 19 ++++++-- 2 files changed, 89 insertions(+), 29 deletions(-) diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 0bf8f5c..d29accd 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } +// client represents a WebSocket connection associated with a room. +type client struct { + conn *websocket.Conn + room string +} + +// broadcastMsg represents a message to be broadcast to all clients in a room. +type broadcastMsg struct { + room string + data []byte +} + type Hub struct { - clients map[*websocket.Conn]bool - broadcast chan []byte - register chan *websocket.Conn + rooms map[string]map[*websocket.Conn]bool + connRoom map[*websocket.Conn]string + broadcast chan broadcastMsg + register chan client unregister chan *websocket.Conn stop chan struct{} mu sync.RWMutex @@ -25,9 +38,10 @@ type Hub struct { func New(logger *logging.Logger) *Hub { return &Hub{ - clients: make(map[*websocket.Conn]bool), - broadcast: make(chan []byte), - register: make(chan *websocket.Conn), + rooms: make(map[string]map[*websocket.Conn]bool), + connRoom: make(map[*websocket.Conn]string), + broadcast: make(chan broadcastMsg), + register: make(chan client), unregister: make(chan *websocket.Conn), stop: make(chan struct{}), logger: logger, @@ -39,43 +53,61 @@ func (h *Hub) Run() { select { case <-h.stop: h.mu.Lock() - for conn := range h.clients { + for conn, room := range h.connRoom { conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")) conn.Close() - delete(h.clients, conn) + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) } h.mu.Unlock() metrics.ConnectedClients.Set(0) h.logger.Info("Hub stopped, all clients disconnected") return - case conn := <-h.register: + case c := <-h.register: h.mu.Lock() - h.clients[conn] = true + if h.rooms[c.room] == nil { + h.rooms[c.room] = make(map[*websocket.Conn]bool) + } + h.rooms[c.room][c.conn] = true + h.connRoom[c.conn] = c.room h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.ConnectionsTotal.Inc() - h.logger.Infof("Client connected. Total: %d", len(h.clients)) + h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom)) case conn := <-h.unregister: h.mu.Lock() - if _, ok := h.clients[conn]; ok { - delete(h.clients, conn) + if room, ok := h.connRoom[conn]; ok { + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) conn.Close() } h.mu.Unlock() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.DisconnectionsTotal.Inc() - h.logger.Infof("Client disconnected. Total: %d", len(h.clients)) + h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom)) case message := <-h.broadcast: metrics.MessagesTotal.Inc() h.mu.RLock() var failed []*websocket.Conn - for conn := range h.clients { - if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { - failed = append(failed, conn) + if clients, ok := h.rooms[message.room]; ok { + for conn := range clients { + if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil { + failed = append(failed, conn) + } } } h.mu.RUnlock() @@ -83,12 +115,18 @@ func (h *Hub) Run() { // Remove failed clients properly so metrics stay consistent for _, conn := range failed { h.mu.Lock() - if _, ok := h.clients[conn]; ok { - delete(h.clients, conn) + if room, ok := h.connRoom[conn]; ok { + if clients, ok := h.rooms[room]; ok { + delete(clients, conn) + if len(clients) == 0 { + delete(h.rooms, room) + } + } + delete(h.connRoom, conn) conn.Close() - metrics.ConnectedClients.Set(float64(len(h.clients))) + metrics.ConnectedClients.Set(float64(len(h.connRoom))) metrics.DisconnectionsTotal.Inc() - h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients)) + h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom)) } h.mu.Unlock() } @@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { return } - h.register <- conn + room := r.URL.Query().Get("room") + + h.register <- client{conn: conn, room: room} go func() { defer func() { @@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { if err != nil { break } - h.broadcast <- message + h.broadcast <- broadcastMsg{room: room, data: message} } }() } @@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() - return len(h.clients) + return len(h.connRoom) +} + +// RoomCount returns the number of active rooms. +func (h *Hub) RoomCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.rooms) } diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index c40a705..5629a6f 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -17,8 +17,11 @@ func TestNew(t *testing.T) { if h == nil { t.Fatal("New returned nil") } - if h.clients == nil { - t.Error("clients map not initialized") + if h.rooms == nil { + t.Error("rooms map not initialized") + } + if h.connRoom == nil { + t.Error("connRoom map not initialized") } if h.broadcast == nil { t.Error("broadcast channel not initialized") @@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) { } } +func TestRoomCount(t *testing.T) { + h := New(newTestLogger()) + go h.Run() + defer h.Shutdown() + + if count := h.RoomCount(); count != 0 { + t.Errorf("Expected 0 rooms, got %d", count) + } +} + func TestBroadcastChannel(t *testing.T) { h := New(newTestLogger()) go h.Run() defer h.Shutdown() select { - case h.broadcast <- []byte("test"): + case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}: // Channel is working case <-time.After(100 * time.Millisecond): t.Error("broadcast channel blocked")