refactor(hub): introduce room types and update Hub struct
- Add client struct with conn and room fields
- Add broadcastMsg struct with room and data fields
- Change Hub.clients to Hub.rooms map[string]map[*websocket.Conn]bool
- Add Hub.connRoom map[*websocket.Conn]string for reverse lookup
- Change broadcast channel type to chan broadcastMsg
- Change register channel type to chan client
- Update New() to initialize rooms and connRoom maps
- Update ClientCount() to use len(h.connRoom)
- Add RoomCount() method
- Update Run() loop for room-segmented register/unregister/broadcast
- Update HandleWebSocket to extract room from query param
- Backward compatible: clients without ?room use default empty room
- Update TestNew to verify rooms and connRoom maps initialized
- Add TestRoomCount to verify initial room count is 0
- Fix TestBroadcastChannel to use broadcastMsg type
All existing unit and integration tests pass (16 hub tests + 21 other).
🤖 Assisted by the code-assist SOP
This commit is contained in:
parent
a40d16dae0
commit
03f379c73c
@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{
|
|||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// client represents a WebSocket connection associated with a room.
|
||||||
|
type client struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
room string
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcastMsg represents a message to be broadcast to all clients in a room.
|
||||||
|
type broadcastMsg struct {
|
||||||
|
room string
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
type Hub struct {
|
type Hub struct {
|
||||||
clients map[*websocket.Conn]bool
|
rooms map[string]map[*websocket.Conn]bool
|
||||||
broadcast chan []byte
|
connRoom map[*websocket.Conn]string
|
||||||
register chan *websocket.Conn
|
broadcast chan broadcastMsg
|
||||||
|
register chan client
|
||||||
unregister chan *websocket.Conn
|
unregister chan *websocket.Conn
|
||||||
stop chan struct{}
|
stop chan struct{}
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@ -25,9 +38,10 @@ type Hub struct {
|
|||||||
|
|
||||||
func New(logger *logging.Logger) *Hub {
|
func New(logger *logging.Logger) *Hub {
|
||||||
return &Hub{
|
return &Hub{
|
||||||
clients: make(map[*websocket.Conn]bool),
|
rooms: make(map[string]map[*websocket.Conn]bool),
|
||||||
broadcast: make(chan []byte),
|
connRoom: make(map[*websocket.Conn]string),
|
||||||
register: make(chan *websocket.Conn),
|
broadcast: make(chan broadcastMsg),
|
||||||
|
register: make(chan client),
|
||||||
unregister: make(chan *websocket.Conn),
|
unregister: make(chan *websocket.Conn),
|
||||||
stop: make(chan struct{}),
|
stop: make(chan struct{}),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@ -39,43 +53,61 @@ func (h *Hub) Run() {
|
|||||||
select {
|
select {
|
||||||
case <-h.stop:
|
case <-h.stop:
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
for conn := range h.clients {
|
for conn, room := range h.connRoom {
|
||||||
conn.WriteMessage(websocket.CloseMessage,
|
conn.WriteMessage(websocket.CloseMessage,
|
||||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
||||||
conn.Close()
|
conn.Close()
|
||||||
delete(h.clients, conn)
|
if clients, ok := h.rooms[room]; ok {
|
||||||
|
delete(clients, conn)
|
||||||
|
if len(clients) == 0 {
|
||||||
|
delete(h.rooms, room)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(h.connRoom, conn)
|
||||||
}
|
}
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
metrics.ConnectedClients.Set(0)
|
metrics.ConnectedClients.Set(0)
|
||||||
h.logger.Info("Hub stopped, all clients disconnected")
|
h.logger.Info("Hub stopped, all clients disconnected")
|
||||||
return
|
return
|
||||||
|
|
||||||
case conn := <-h.register:
|
case c := <-h.register:
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
h.clients[conn] = true
|
if h.rooms[c.room] == nil {
|
||||||
|
h.rooms[c.room] = make(map[*websocket.Conn]bool)
|
||||||
|
}
|
||||||
|
h.rooms[c.room][c.conn] = true
|
||||||
|
h.connRoom[c.conn] = c.room
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||||
metrics.ConnectionsTotal.Inc()
|
metrics.ConnectionsTotal.Inc()
|
||||||
h.logger.Infof("Client connected. Total: %d", len(h.clients))
|
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, len(h.connRoom))
|
||||||
|
|
||||||
case conn := <-h.unregister:
|
case conn := <-h.unregister:
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
if _, ok := h.clients[conn]; ok {
|
if room, ok := h.connRoom[conn]; ok {
|
||||||
delete(h.clients, conn)
|
if clients, ok := h.rooms[room]; ok {
|
||||||
|
delete(clients, conn)
|
||||||
|
if len(clients) == 0 {
|
||||||
|
delete(h.rooms, room)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(h.connRoom, conn)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||||
metrics.DisconnectionsTotal.Inc()
|
metrics.DisconnectionsTotal.Inc()
|
||||||
h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
|
h.logger.Infof("Client disconnected. Total: %d", len(h.connRoom))
|
||||||
|
|
||||||
case message := <-h.broadcast:
|
case message := <-h.broadcast:
|
||||||
metrics.MessagesTotal.Inc()
|
metrics.MessagesTotal.Inc()
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
var failed []*websocket.Conn
|
var failed []*websocket.Conn
|
||||||
for conn := range h.clients {
|
if clients, ok := h.rooms[message.room]; ok {
|
||||||
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
for conn := range clients {
|
||||||
failed = append(failed, conn)
|
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
|
||||||
|
failed = append(failed, conn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
h.mu.RUnlock()
|
h.mu.RUnlock()
|
||||||
@ -83,12 +115,18 @@ func (h *Hub) Run() {
|
|||||||
// Remove failed clients properly so metrics stay consistent
|
// Remove failed clients properly so metrics stay consistent
|
||||||
for _, conn := range failed {
|
for _, conn := range failed {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
if _, ok := h.clients[conn]; ok {
|
if room, ok := h.connRoom[conn]; ok {
|
||||||
delete(h.clients, conn)
|
if clients, ok := h.rooms[room]; ok {
|
||||||
|
delete(clients, conn)
|
||||||
|
if len(clients) == 0 {
|
||||||
|
delete(h.rooms, room)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(h.connRoom, conn)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
metrics.ConnectedClients.Set(float64(len(h.connRoom)))
|
||||||
metrics.DisconnectionsTotal.Inc()
|
metrics.DisconnectionsTotal.Inc()
|
||||||
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
|
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.connRoom))
|
||||||
}
|
}
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
}
|
}
|
||||||
@ -108,7 +146,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.register <- conn
|
room := r.URL.Query().Get("room")
|
||||||
|
|
||||||
|
h.register <- client{conn: conn, room: room}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -120,7 +160,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
h.broadcast <- message
|
h.broadcast <- broadcastMsg{room: room, data: message}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@ -128,5 +168,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (h *Hub) ClientCount() int {
|
func (h *Hub) ClientCount() int {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
return len(h.clients)
|
return len(h.connRoom)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomCount returns the number of active rooms.
|
||||||
|
func (h *Hub) RoomCount() int {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return len(h.rooms)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,8 +17,11 @@ func TestNew(t *testing.T) {
|
|||||||
if h == nil {
|
if h == nil {
|
||||||
t.Fatal("New returned nil")
|
t.Fatal("New returned nil")
|
||||||
}
|
}
|
||||||
if h.clients == nil {
|
if h.rooms == nil {
|
||||||
t.Error("clients map not initialized")
|
t.Error("rooms map not initialized")
|
||||||
|
}
|
||||||
|
if h.connRoom == nil {
|
||||||
|
t.Error("connRoom map not initialized")
|
||||||
}
|
}
|
||||||
if h.broadcast == nil {
|
if h.broadcast == nil {
|
||||||
t.Error("broadcast channel not initialized")
|
t.Error("broadcast channel not initialized")
|
||||||
@ -38,13 +41,23 @@ func TestClientCount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoomCount(t *testing.T) {
|
||||||
|
h := New(newTestLogger())
|
||||||
|
go h.Run()
|
||||||
|
defer h.Shutdown()
|
||||||
|
|
||||||
|
if count := h.RoomCount(); count != 0 {
|
||||||
|
t.Errorf("Expected 0 rooms, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBroadcastChannel(t *testing.T) {
|
func TestBroadcastChannel(t *testing.T) {
|
||||||
h := New(newTestLogger())
|
h := New(newTestLogger())
|
||||||
go h.Run()
|
go h.Run()
|
||||||
defer h.Shutdown()
|
defer h.Shutdown()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case h.broadcast <- []byte("test"):
|
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
|
||||||
// Channel is working
|
// Channel is working
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Error("broadcast channel blocked")
|
t.Error("broadcast channel blocked")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user