Refactor the stop case in Hub.Run() to iterate over h.rooms directly
instead of h.connRoom. For each room, iterate over all of its connections
and send a CloseGoingAway frame before closing each one. After the loop,
reset both maps (h.rooms and h.connRoom) at once rather than deleting
entries one by one. This is simpler and avoids modifying a map while
iterating over it.
Add TestIntegration_GracefulShutdownMultiRoom to verify clients in
separate rooms all receive close frames during shutdown.
🤖 Assisted by the code-assist SOP
185 lines · 4.2 KiB · Go
package hub
|
|
|
|
import (
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"websocket-relay/internal/logging"
	"websocket-relay/internal/metrics"
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
|
|
// client represents a WebSocket connection associated with a room.
|
|
type client struct {
|
|
conn *websocket.Conn
|
|
room string
|
|
}
|
|
|
|
// broadcastMsg is a payload destined for every connection currently
// registered in one room.
type broadcastMsg struct {
	room string // target room name
	data []byte // payload; Run sends it to each member as a text frame
}
|
|
|
|
type Hub struct {
|
|
rooms map[string]map[*websocket.Conn]bool
|
|
connRoom map[*websocket.Conn]string
|
|
broadcast chan broadcastMsg
|
|
register chan client
|
|
unregister chan *websocket.Conn
|
|
stop chan struct{}
|
|
mu sync.RWMutex
|
|
logger *logging.Logger
|
|
}
|
|
|
|
func New(logger *logging.Logger) *Hub {
|
|
return &Hub{
|
|
rooms: make(map[string]map[*websocket.Conn]bool),
|
|
connRoom: make(map[*websocket.Conn]string),
|
|
broadcast: make(chan broadcastMsg),
|
|
register: make(chan client),
|
|
unregister: make(chan *websocket.Conn),
|
|
stop: make(chan struct{}),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Run() {
|
|
for {
|
|
select {
|
|
case <-h.stop:
|
|
h.mu.Lock()
|
|
for _, clients := range h.rooms {
|
|
for conn := range clients {
|
|
conn.WriteMessage(websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
|
conn.Close()
|
|
}
|
|
}
|
|
h.rooms = make(map[string]map[*websocket.Conn]bool)
|
|
h.connRoom = make(map[*websocket.Conn]string)
|
|
h.mu.Unlock()
|
|
metrics.ConnectedClients.Set(0)
|
|
h.logger.Info("Hub stopped, all clients disconnected")
|
|
return
|
|
|
|
case c := <-h.register:
|
|
h.mu.Lock()
|
|
if h.rooms[c.room] == nil {
|
|
h.rooms[c.room] = make(map[*websocket.Conn]bool)
|
|
}
|
|
h.rooms[c.room][c.conn] = true
|
|
h.connRoom[c.conn] = c.room
|
|
count := len(h.connRoom)
|
|
h.mu.Unlock()
|
|
metrics.ConnectedClients.Inc()
|
|
metrics.ConnectionsTotal.Inc()
|
|
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, count)
|
|
|
|
case conn := <-h.unregister:
|
|
h.mu.Lock()
|
|
if room, ok := h.connRoom[conn]; ok {
|
|
if clients, ok := h.rooms[room]; ok {
|
|
delete(clients, conn)
|
|
if len(clients) == 0 {
|
|
delete(h.rooms, room)
|
|
}
|
|
}
|
|
delete(h.connRoom, conn)
|
|
count := len(h.connRoom)
|
|
h.mu.Unlock()
|
|
conn.Close()
|
|
metrics.ConnectedClients.Dec()
|
|
metrics.DisconnectionsTotal.Inc()
|
|
h.logger.Infof("Client disconnected (room=%q). Total: %d", room, count)
|
|
} else {
|
|
h.mu.Unlock()
|
|
conn.Close()
|
|
}
|
|
|
|
case message := <-h.broadcast:
|
|
metrics.MessagesTotal.Inc()
|
|
h.mu.RLock()
|
|
var failed []*websocket.Conn
|
|
if clients, ok := h.rooms[message.room]; ok {
|
|
for conn := range clients {
|
|
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
|
|
failed = append(failed, conn)
|
|
}
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
|
|
// Remove failed clients properly so metrics stay consistent
|
|
for _, conn := range failed {
|
|
h.mu.Lock()
|
|
if room, ok := h.connRoom[conn]; ok {
|
|
if clients, ok := h.rooms[room]; ok {
|
|
delete(clients, conn)
|
|
if len(clients) == 0 {
|
|
delete(h.rooms, room)
|
|
}
|
|
}
|
|
delete(h.connRoom, conn)
|
|
count := len(h.connRoom)
|
|
h.mu.Unlock()
|
|
conn.Close()
|
|
metrics.ConnectedClients.Dec()
|
|
metrics.DisconnectionsTotal.Inc()
|
|
h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, count)
|
|
} else {
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Shutdown gracefully stops the hub, closing all client connections.
|
|
func (h *Hub) Shutdown() {
|
|
close(h.stop)
|
|
}
|
|
|
|
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Errorf("WebSocket upgrade error: %v", err)
|
|
return
|
|
}
|
|
|
|
room := r.URL.Path
|
|
|
|
h.register <- client{conn: conn, room: room}
|
|
|
|
go func() {
|
|
defer func() {
|
|
h.unregister <- conn
|
|
}()
|
|
|
|
for {
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
h.broadcast <- broadcastMsg{room: room, data: message}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (h *Hub) ClientCount() int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
return len(h.connRoom)
|
|
}
|
|
|
|
// RoomCount returns the number of active rooms.
|
|
func (h *Hub) RoomCount() int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
return len(h.rooms)
|
|
}
|