savinmax c226bdab7f feat(hub): update graceful shutdown to iterate rooms for multi-room cleanup
Refactor the stop case in Hub.Run() to iterate h.rooms directly
instead of h.connRoom. For each room, iterate all connections and
send CloseGoingAway frame before closing. After the loop, reset both
maps (h.rooms, h.connRoom) in one shot rather than deleting entries
incrementally. This is cleaner and avoids modifying a map during
iteration.

Add TestIntegration_GracefulShutdownMultiRoom to verify clients in
separate rooms all receive close frames during shutdown.

🤖 Assisted by the code-assist SOP
2026-06-13 13:26:03 +02:00

185 lines
4.2 KiB
Go

package hub
import (
"net/http"
"sync"
"github.com/gorilla/websocket"
"websocket-relay/internal/logging"
"websocket-relay/internal/metrics"
)
// upgrader turns incoming HTTP requests into WebSocket connections. Apart from
// CheckOrigin, every Upgrader option is left at its zero value.
//
// NOTE(review): CheckOrigin accepts any Origin header, so pages served from any
// site can open a connection to this relay. Confirm this is intended before
// exposing the endpoint publicly.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
type (
	// client pairs a WebSocket connection with the room it joined. Values of
	// this type are sent on Hub.register.
	client struct {
		conn *websocket.Conn
		room string
	}

	// broadcastMsg is one message read from a client that should be relayed to
	// every connection in the same room. Values of this type are sent on
	// Hub.broadcast.
	broadcastMsg struct {
		room string
		data []byte
	}
)
// Hub tracks WebSocket connections grouped into rooms and relays messages
// between connections that share a room. The rooms and connRoom maps are
// modified by Run while holding mu; the count accessors read them under a
// read lock.
type Hub struct {
	// rooms maps a room name to the set of connections currently in it.
	// A room's entry is deleted when its last connection leaves.
	rooms map[string]map[*websocket.Conn]bool

	// connRoom maps each registered connection to its room; its length is
	// the total number of connected clients.
	connRoom map[*websocket.Conn]string

	// broadcast carries messages to relay to every connection in a room.
	broadcast chan broadcastMsg

	// register carries newly upgraded connections to add to a room.
	register chan client

	// unregister carries connections to remove from their room and close.
	unregister chan *websocket.Conn

	// stop is closed by Shutdown; Run then closes every client and returns.
	stop chan struct{}

	// mu guards rooms and connRoom.
	mu sync.RWMutex

	logger *logging.Logger
}
// New creates a Hub with empty room and connection indexes, unbuffered event
// channels, and the given logger. Nothing is processed until Run is started;
// Run blocks until Shutdown is called.
func New(logger *logging.Logger) *Hub {
	h := &Hub{logger: logger}
	h.rooms = make(map[string]map[*websocket.Conn]bool)
	h.connRoom = make(map[*websocket.Conn]string)
	h.broadcast = make(chan broadcastMsg)
	h.register = make(chan client)
	h.unregister = make(chan *websocket.Conn)
	h.stop = make(chan struct{})
	return h
}
// Run is the hub's event loop. It applies registrations, unregistrations and
// room broadcasts one at a time until h.stop is closed (see Shutdown). It then
// sends every remaining client a CloseGoingAway frame, closes all connections,
// resets the client gauge to zero and returns.
//
// NOTE(review): no write deadline is set on any connection in this file, so a
// client that stops reading can block a WriteMessage call here (and therefore
// the whole loop, including shutdown). Consider WriteControl / SetWriteDeadline.
func (h *Hub) Run() {
	for {
		select {
		case <-h.stop:
			h.closeAllClients()
			metrics.ConnectedClients.Set(0)
			h.logger.Info("Hub stopped, all clients disconnected")
			return
		case c := <-h.register:
			h.mu.Lock()
			if h.rooms[c.room] == nil {
				h.rooms[c.room] = make(map[*websocket.Conn]bool)
			}
			h.rooms[c.room][c.conn] = true
			h.connRoom[c.conn] = c.room
			count := len(h.connRoom)
			h.mu.Unlock()
			metrics.ConnectedClients.Inc()
			metrics.ConnectionsTotal.Inc()
			h.logger.Infof("Client connected (room=%q). Total: %d", c.room, count)
		case conn := <-h.unregister:
			room, count, ok := h.detachConn(conn)
			// Close even when conn is no longer indexed: e.g. deliver may
			// already have detached it after a write error.
			conn.Close()
			if ok {
				metrics.ConnectedClients.Dec()
				metrics.DisconnectionsTotal.Inc()
				h.logger.Infof("Client disconnected (room=%q). Total: %d", room, count)
			}
		case message := <-h.broadcast:
			h.deliver(message)
		}
	}
}

// detachConn removes conn from its room (dropping the room if it becomes
// empty) and from the connection index, under h.mu. It returns the room conn
// belonged to, the number of clients still registered, and whether conn was
// registered at all. It does not close the connection.
func (h *Hub) detachConn(conn *websocket.Conn) (room string, remaining int, ok bool) {
	h.mu.Lock()
	defer h.mu.Unlock()

	room, ok = h.connRoom[conn]
	if !ok {
		return "", 0, false
	}
	if clients, found := h.rooms[room]; found {
		delete(clients, conn)
		if len(clients) == 0 {
			delete(h.rooms, room)
		}
	}
	delete(h.connRoom, conn)
	return room, len(h.connRoom), true
}

// closeAllClients sends a CloseGoingAway frame to every registered connection,
// closes it, and resets both indexes in one step. Write errors are ignored
// because each connection is closed regardless.
func (h *Hub) closeAllClients() {
	h.mu.Lock()
	defer h.mu.Unlock()

	goingAway := websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")
	for _, clients := range h.rooms {
		for conn := range clients {
			_ = conn.WriteMessage(websocket.CloseMessage, goingAway)
			conn.Close()
		}
	}
	h.rooms = make(map[string]map[*websocket.Conn]bool)
	h.connRoom = make(map[*websocket.Conn]string)
}

// deliver writes message.data as a text frame to every connection in
// message.room. Connections whose write fails are detached, closed and
// counted as disconnections so the metrics stay consistent.
func (h *Hub) deliver(message broadcastMsg) {
	metrics.MessagesTotal.Inc()

	var failed []*websocket.Conn
	h.mu.RLock()
	// Ranging over a missing room's nil map is a no-op.
	for conn := range h.rooms[message.room] {
		if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
			failed = append(failed, conn)
		}
	}
	h.mu.RUnlock()

	for _, conn := range failed {
		room, count, ok := h.detachConn(conn)
		if !ok {
			continue
		}
		conn.Close()
		metrics.ConnectedClients.Dec()
		metrics.DisconnectionsTotal.Inc()
		h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, count)
	}
}
// Shutdown signals Run to stop. Run then sends every connected client a
// CloseGoingAway frame, closes the connections and returns. Shutdown itself
// only delivers the signal; it does not wait for Run to finish.
//
// Shutdown may be called more than once, including concurrently; every call
// after the first is a no-op instead of panicking on a double channel close.
func (h *Hub) Shutdown() {
	// Holding h.mu makes the check-then-close below atomic with respect to
	// other Shutdown calls. Run only takes h.mu after observing the close, and
	// close never blocks, so this cannot deadlock with Run.
	h.mu.Lock()
	defer h.mu.Unlock()

	select {
	case <-h.stop:
		// Already signalled.
	default:
		close(h.stop)
	}
}
// HandleWebSocket upgrades the request to a WebSocket connection and registers
// it with the hub in the room named by the request's URL path. A goroutine
// then reads messages from the connection and forwards each one as a broadcast
// to that room until a read fails, after which the connection is unregistered.
//
// Every send to the hub also watches h.stop. Once Shutdown has been called,
// Run stops receiving on its channels. Without this, new handlers would block
// forever on register, and reader goroutines would block forever on broadcast
// or unregister.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Errorf("WebSocket upgrade error: %v", err)
		return
	}
	room := r.URL.Path

	select {
	case h.register <- client{conn: conn, room: room}:
	case <-h.stop:
		// The hub is shutting down and will never own this connection.
		conn.Close()
		return
	}

	go func() {
		defer func() {
			select {
			case h.unregister <- conn:
			case <-h.stop:
				// Run closes every still-registered connection on shutdown;
				// closing here could race ahead of its close frame.
			}
		}()
		for {
			_, message, err := conn.ReadMessage()
			if err != nil {
				return
			}
			select {
			case h.broadcast <- broadcastMsg{room: room, data: message}:
			case <-h.stop:
				return
			}
		}
	}()
}
// ClientCount reports how many client connections are currently registered
// across all rooms.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	n := len(h.connRoom)
	h.mu.RUnlock()
	return n
}
// RoomCount reports how many rooms currently have at least one registered
// connection. Run deletes a room as soon as its last connection leaves.
func (h *Hub) RoomCount() int {
	h.mu.RLock()
	n := len(h.rooms)
	h.mu.RUnlock()
	return n
}