refactor(hub): introduce room types and update Hub struct

- Add client struct with conn and room fields
- Add broadcastMsg struct with room and data fields
- Change Hub.clients to Hub.rooms map[string]map[*websocket.Conn]bool
- Add Hub.connRoom map[*websocket.Conn]string for reverse lookup
- Change broadcast channel type to chan broadcastMsg
- Change register channel type to chan client
- Update New() to initialize rooms and connRoom maps
- Update ClientCount() to use len(h.connRoom)
- Add RoomCount() method
- Update Run() loop for room-segmented register/unregister/broadcast
- Update HandleWebSocket to extract room from query param
- Backward compatible: clients without ?room use default empty room
- Update TestNew to verify rooms and connRoom maps initialized
- Add TestRoomCount to verify initial room count is 0
- Fix TestBroadcastChannel to use broadcastMsg type

All existing unit and integration tests pass (16 hub tests + 21 other).

🤖 Assisted by the code-assist SOP
This commit is contained in:
savinmax 2026-06-13 13:09:25 +02:00
parent a40d16dae0
commit 03f379c73c
2 changed files with 89 additions and 29 deletions

View File

@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// client represents a WebSocket connection associated with a room.
type client struct {
conn *websocket.Conn
room string
}
// broadcastMsg represents a message to be broadcast to all clients in a room.
type broadcastMsg struct {
room string
data []byte
}
type Hub struct {
clients map[*websocket.Conn]bool
broadcast chan []byte
register chan *websocket.Conn
rooms map[string]map[*websocket.Conn]bool
connRoom map[*websocket.Conn]string
broadcast chan broadcastMsg
register chan client
unregister chan *websocket.Conn
stop chan struct{}
mu sync.RWMutex
@ -25,9 +38,10 @@ type Hub struct {
func New(logger *logging.Logger) *Hub {
return &Hub{
clients: make(map[*websocket.Conn]bool),
broadcast: make(chan []byte),
register: make(chan *websocket.Conn),
rooms: make(map[string]map[*websocket.Conn]bool),
connRoom: make(map[*websocket.Conn]string),
broadcast: make(chan broadcastMsg),
register: make(chan client),
unregister: make(chan *websocket.Conn),
stop: make(chan struct{}),
logger: logger,
@ -39,43 +53,61 @@ func (h *Hub) Run() {
select {
case <-h.stop:
h.mu.Lock()
for conn := range h.clients {
for conn, room := range h.connRoom {
conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
conn.Close()
delete(h.clients, conn)
if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
}
h.mu.Unlock()
metrics.ConnectedClients.Set(0)
h.logger.Info("Hub stopped, all clients disconnected")
return
case conn := <-h.register:
case c := <-h.register:
h.mu.Lock()
h.clients[conn] = true
if h.rooms[c.room] == nil {
h.rooms[c.room] = make(map[*websocket.Conn]bool)
}
h.rooms[c.room][c.conn] = true
h.connRoom[c.conn] = c.room
h.mu.Unlock()
metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.ConnectionsTotal.Inc()
h.logger.Infof("Client connected. Total: %d", len(h.clients))
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom))
case conn := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[conn]; ok {
delete(h.clients, conn)
if room, ok := h.connRoom[conn]; ok {
if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
conn.Close()
}
h.mu.Unlock()
metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.DisconnectionsTotal.Inc()
h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom))
case message := <-h.broadcast:
metrics.MessagesTotal.Inc()
h.mu.RLock()
var failed []*websocket.Conn
for conn := range h.clients {
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
failed = append(failed, conn)
if clients, ok := h.rooms[message.room]; ok {
for conn := range clients {
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
failed = append(failed, conn)
}
}
}
h.mu.RUnlock()
@ -83,12 +115,18 @@ func (h *Hub) Run() {
// Remove failed clients properly so metrics stay consistent
for _, conn := range failed {
h.mu.Lock()
if _, ok := h.clients[conn]; ok {
delete(h.clients, conn)
if room, ok := h.connRoom[conn]; ok {
if clients, ok := h.rooms[room]; ok {
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
conn.Close()
metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
metrics.DisconnectionsTotal.Inc()
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom))
}
h.mu.Unlock()
}
@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
h.register <- conn
room := r.URL.Query().Get("room")
h.register <- client{conn: conn, room: room}
go func() {
defer func() {
@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
if err != nil {
break
}
h.broadcast <- message
h.broadcast <- broadcastMsg{room: room, data: message}
}
}()
}
@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
func (h *Hub) ClientCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.clients)
return len(h.connRoom)
}
// RoomCount returns the number of active rooms.
func (h *Hub) RoomCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.rooms)
}

View File

@ -17,8 +17,11 @@ func TestNew(t *testing.T) {
if h == nil {
t.Fatal("New returned nil")
}
if h.clients == nil {
t.Error("clients map not initialized")
if h.rooms == nil {
t.Error("rooms map not initialized")
}
if h.connRoom == nil {
t.Error("connRoom map not initialized")
}
if h.broadcast == nil {
t.Error("broadcast channel not initialized")
@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) {
}
}
func TestRoomCount(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms, got %d", count)
}
}
func TestBroadcastChannel(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
select {
case h.broadcast <- []byte("test"):
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
// Channel is working
case <-time.After(100 * time.Millisecond):
t.Error("broadcast channel blocked")