refactor(hub): introduce room types and update Hub struct
- Add client struct with conn and room fields
- Add broadcastMsg struct with room and data fields
- Change Hub.clients to Hub.rooms map[string]map[*websocket.Conn]bool
- Add Hub.connRoom map[*websocket.Conn]string for reverse lookup
- Change broadcast channel type to chan broadcastMsg
- Change register channel type to chan client
- Update New() to initialize rooms and connRoom maps
- Update ClientCount() to use len(h.connRoom)
- Add RoomCount() method
- Update Run() loop for room-segmented register/unregister/broadcast
- Update HandleWebSocket to extract room from query param
- Backward compatible: clients without ?room use default empty room
- Update TestNew to verify rooms and connRoom maps initialized
- Add TestRoomCount to verify initial room count is 0
- Fix TestBroadcastChannel to use broadcastMsg type
All existing unit and integration tests pass (16 hub tests + 21 other).
🤖 Assisted by the code-assist SOP
This commit is contained in:
parent
a40d16dae0
commit
03f379c73c
@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// client represents a WebSocket connection associated with a room.
|
||||
type client struct {
|
||||
conn *websocket.Conn
|
||||
room string
|
||||
}
|
||||
|
||||
// broadcastMsg represents a message to be broadcast to all clients in a room.
|
||||
type broadcastMsg struct {
|
||||
room string
|
||||
data []byte
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
clients map[*websocket.Conn]bool
|
||||
broadcast chan []byte
|
||||
register chan *websocket.Conn
|
||||
rooms map[string]map[*websocket.Conn]bool
|
||||
connRoom map[*websocket.Conn]string
|
||||
broadcast chan broadcastMsg
|
||||
register chan client
|
||||
unregister chan *websocket.Conn
|
||||
stop chan struct{}
|
||||
mu sync.RWMutex
|
||||
@ -25,9 +38,10 @@ type Hub struct {
|
||||
|
||||
func New(logger *logging.Logger) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcast: make(chan []byte),
|
||||
register: make(chan *websocket.Conn),
|
||||
rooms: make(map[string]map[*websocket.Conn]bool),
|
||||
connRoom: make(map[*websocket.Conn]string),
|
||||
broadcast: make(chan broadcastMsg),
|
||||
register: make(chan client),
|
||||
unregister: make(chan *websocket.Conn),
|
||||
stop: make(chan struct{}),
|
||||
logger: logger,
|
||||
@ -39,43 +53,61 @@ func (h *Hub) Run() {
|
||||
select {
|
||||
case <-h.stop:
|
||||
h.mu.Lock()
|
||||
for conn := range h.clients {
|
||||
for conn, room := range h.connRoom {
|
||||
conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
||||
conn.Close()
|
||||
delete(h.clients, conn)
|
||||
if clients, ok := h.rooms[room]; ok {
|
||||
delete(clients, conn)
|
||||
if len(clients) == 0 {
|
||||
delete(h.rooms, room)
|
||||
}
|
||||
}
|
||||
delete(h.connRoom, conn)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(0)
|
||||
h.logger.Info("Hub stopped, all clients disconnected")
|
||||
return
|
||||
|
||||
case conn := <-h.register:
|
||||
case c := <-h.register:
|
||||
h.mu.Lock()
|
||||
h.clients[conn] = true
|
||||
if h.rooms[c.room] == nil {
|
||||
h.rooms[c.room] = make(map[*websocket.Conn]bool)
|
||||
}
|
||||
h.rooms[c.room][c.conn] = true
|
||||
h.connRoom[c.conn] = c.room
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||
metrics.ConnectionsTotal.Inc()
|
||||
h.logger.Infof("Client connected. Total: %d", len(h.clients))
|
||||
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom))
|
||||
|
||||
case conn := <-h.unregister:
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[conn]; ok {
|
||||
delete(h.clients, conn)
|
||||
if room, ok := h.connRoom[conn]; ok {
|
||||
if clients, ok := h.rooms[room]; ok {
|
||||
delete(clients, conn)
|
||||
if len(clients) == 0 {
|
||||
delete(h.rooms, room)
|
||||
}
|
||||
}
|
||||
delete(h.connRoom, conn)
|
||||
conn.Close()
|
||||
}
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
|
||||
h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom))
|
||||
|
||||
case message := <-h.broadcast:
|
||||
metrics.MessagesTotal.Inc()
|
||||
h.mu.RLock()
|
||||
var failed []*websocket.Conn
|
||||
for conn := range h.clients {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
failed = append(failed, conn)
|
||||
if clients, ok := h.rooms[message.room]; ok {
|
||||
for conn := range clients {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
|
||||
failed = append(failed, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
@ -83,12 +115,18 @@ func (h *Hub) Run() {
|
||||
// Remove failed clients properly so metrics stay consistent
|
||||
for _, conn := range failed {
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[conn]; ok {
|
||||
delete(h.clients, conn)
|
||||
if room, ok := h.connRoom[conn]; ok {
|
||||
if clients, ok := h.rooms[room]; ok {
|
||||
delete(clients, conn)
|
||||
if len(clients) == 0 {
|
||||
delete(h.rooms, room)
|
||||
}
|
||||
}
|
||||
delete(h.connRoom, conn)
|
||||
conn.Close()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
|
||||
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom))
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.register <- conn
|
||||
room := r.URL.Query().Get("room")
|
||||
|
||||
h.register <- client{conn: conn, room: room}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
h.broadcast <- message
|
||||
h.broadcast <- broadcastMsg{room: room, data: message}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Hub) ClientCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.clients)
|
||||
return len(h.connRoom)
|
||||
}
|
||||
|
||||
// RoomCount returns the number of active rooms.
|
||||
func (h *Hub) RoomCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.rooms)
|
||||
}
|
||||
|
||||
@ -17,8 +17,11 @@ func TestNew(t *testing.T) {
|
||||
if h == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
if h.clients == nil {
|
||||
t.Error("clients map not initialized")
|
||||
if h.rooms == nil {
|
||||
t.Error("rooms map not initialized")
|
||||
}
|
||||
if h.connRoom == nil {
|
||||
t.Error("connRoom map not initialized")
|
||||
}
|
||||
if h.broadcast == nil {
|
||||
t.Error("broadcast channel not initialized")
|
||||
@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomCount(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected 0 rooms, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastChannel(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
select {
|
||||
case h.broadcast <- []byte("test"):
|
||||
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
|
||||
// Channel is working
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("broadcast channel blocked")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user