refactor(hub): introduce room types and update Hub struct

- Add client struct with conn and room fields
- Add broadcastMsg struct with room and data fields
- Change Hub.clients to Hub.rooms map[string]map[*websocket.Conn]bool
- Add Hub.connRoom map[*websocket.Conn]string for reverse lookup
- Change broadcast channel type to chan broadcastMsg
- Change register channel type to chan client
- Update New() to initialize rooms and connRoom maps
- Update ClientCount() to use len(h.connRoom)
- Add RoomCount() method
- Update Run() loop for room-segmented register/unregister/broadcast
- Update HandleWebSocket to extract room from query param
- Backward compatible: clients without ?room use default empty room
- Update TestNew to verify rooms and connRoom maps initialized
- Add TestRoomCount to verify initial room count is 0
- Fix TestBroadcastChannel to use broadcastMsg type

All existing unit and integration tests pass (16 hub tests + 21 other).

🤖 Assisted by the code-assist SOP
This commit is contained in:
savinmax 2026-06-13 13:09:25 +02:00
parent a40d16dae0
commit 03f379c73c
2 changed files with 89 additions and 29 deletions

View File

@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
} }
// client represents a WebSocket connection associated with a room.
type client struct {
conn *websocket.Conn
room string
}
// broadcastMsg represents a message to be broadcast to all clients in a room.
type broadcastMsg struct {
room string
data []byte
}
type Hub struct { type Hub struct {
clients map[*websocket.Conn]bool rooms map[string]map[*websocket.Conn]bool
broadcast chan []byte connRoom map[*websocket.Conn]string
register chan *websocket.Conn broadcast chan broadcastMsg
register chan client
unregister chan *websocket.Conn unregister chan *websocket.Conn
stop chan struct{} stop chan struct{}
mu sync.RWMutex mu sync.RWMutex
@ -25,9 +38,10 @@ type Hub struct {
func New(logger *logging.Logger) *Hub { func New(logger *logging.Logger) *Hub {
return &Hub{ return &Hub{
clients: make(map[*websocket.Conn]bool), rooms: make(map[string]map[*websocket.Conn]bool),
broadcast: make(chan []byte), connRoom: make(map[*websocket.Conn]string),
register: make(chan *websocket.Conn), broadcast: make(chan broadcastMsg),
register: make(chan client),
unregister: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn),
stop: make(chan struct{}), stop: make(chan struct{}),
logger: logger, logger: logger,
@ -39,56 +53,80 @@ func (h *Hub) Run() {
select { select {
case <-h.stop: case <-h.stop:
h.mu.Lock() h.mu.Lock()
for conn := range h.clients { for conn, room := range h.connRoom {
conn.WriteMessage(websocket.CloseMessage, conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")) websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
conn.Close() conn.Close()
delete(h.clients, conn) if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
} }
h.mu.Unlock() h.mu.Unlock()
metrics.ConnectedClients.Set(0) metrics.ConnectedClients.Set(0)
h.logger.Info("Hub stopped, all clients disconnected") h.logger.Info("Hub stopped, all clients disconnected")
return return
case conn := <-h.register: case c := <-h.register:
h.mu.Lock() h.mu.Lock()
h.clients[conn] = true if h.rooms[c.room] == nil {
h.rooms[c.room] = make(map[*websocket.Conn]bool)
}
h.rooms[c.room][c.conn] = true
h.connRoom[c.conn] = c.room
h.mu.Unlock() h.mu.Unlock()
metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.ConnectionsTotal.Inc() metrics.ConnectionsTotal.Inc()
h.logger.Infof("Client connected. Total: %d", len(h.clients)) h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom))
case conn := <-h.unregister: case conn := <-h.unregister:
h.mu.Lock() h.mu.Lock()
if _, ok := h.clients[conn]; ok { if room, ok := h.connRoom[conn]; ok {
delete(h.clients, conn) if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
conn.Close() conn.Close()
} }
h.mu.Unlock() h.mu.Unlock()
metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.DisconnectionsTotal.Inc() metrics.DisconnectionsTotal.Inc()
h.logger.Infof("Client disconnected. Total: %d", len(h.clients)) h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom))
case message := <-h.broadcast: case message := <-h.broadcast:
metrics.MessagesTotal.Inc() metrics.MessagesTotal.Inc()
h.mu.RLock() h.mu.RLock()
var failed []*websocket.Conn var failed []*websocket.Conn
for conn := range h.clients { if clients, ok := h.rooms[message.room]; ok {
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { for conn := range clients {
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
failed = append(failed, conn) failed = append(failed, conn)
} }
} }
}
h.mu.RUnlock() h.mu.RUnlock()
// Remove failed clients properly so metrics stay consistent // Remove failed clients properly so metrics stay consistent
for _, conn := range failed { for _, conn := range failed {
h.mu.Lock() h.mu.Lock()
if _, ok := h.clients[conn]; ok { if room, ok := h.connRoom[conn]; ok {
delete(h.clients, conn) if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
conn.Close() conn.Close()
metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.DisconnectionsTotal.Inc() metrics.DisconnectionsTotal.Inc()
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients)) h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom))
} }
h.mu.Unlock() h.mu.Unlock()
} }
@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
return return
} }
h.register <- conn room := r.URL.Query().Get("room")
h.register <- client{conn: conn, room: room}
go func() { go func() {
defer func() { defer func() {
@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
break break
} }
h.broadcast <- message h.broadcast <- broadcastMsg{room: room, data: message}
} }
}() }()
} }
@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
func (h *Hub) ClientCount() int { func (h *Hub) ClientCount() int {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
return len(h.clients) return len(h.connRoom)
}
// RoomCount returns the number of active rooms.
func (h *Hub) RoomCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.rooms)
} }

View File

@ -17,8 +17,11 @@ func TestNew(t *testing.T) {
if h == nil { if h == nil {
t.Fatal("New returned nil") t.Fatal("New returned nil")
} }
if h.clients == nil { if h.rooms == nil {
t.Error("clients map not initialized") t.Error("rooms map not initialized")
}
if h.connRoom == nil {
t.Error("connRoom map not initialized")
} }
if h.broadcast == nil { if h.broadcast == nil {
t.Error("broadcast channel not initialized") t.Error("broadcast channel not initialized")
@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) {
} }
} }
func TestRoomCount(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms, got %d", count)
}
}
func TestBroadcastChannel(t *testing.T) { func TestBroadcastChannel(t *testing.T) {
h := New(newTestLogger()) h := New(newTestLogger())
go h.Run() go h.Run()
defer h.Shutdown() defer h.Shutdown()
select { select {
case h.broadcast <- []byte("test"): case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
// Channel is working // Channel is working
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
t.Error("broadcast channel blocked") t.Error("broadcast channel blocked")