// Package hub relays WebSocket messages between clients that share a room.
package hub

import (
	"net/http"
	"sync"

	"github.com/gorilla/websocket"

	"websocket-relay/internal/logging"
	"websocket-relay/internal/metrics"
)

// upgrader turns an HTTP request into a WebSocket connection.
//
// NOTE(review): CheckOrigin accepts every origin, so any web page can open a
// connection to this relay — confirm this is intended for the deployment.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

// client represents a WebSocket connection associated with a room.
type client struct {
	conn *websocket.Conn
	room string
}

// broadcastMsg represents a message to be broadcast to all clients in a room.
type broadcastMsg struct {
	room string
	data []byte
}

// Hub tracks which connection belongs to which room and fans messages out to
// every connection in the sender's room. All mutations of rooms and connRoom
// happen on the goroutine executing Run; mu lets ClientCount and RoomCount
// read them from other goroutines.
type Hub struct {
	rooms      map[string]map[*websocket.Conn]bool // room -> set of connections
	connRoom   map[*websocket.Conn]string          // reverse index: connection -> room
	broadcast  chan broadcastMsg
	register   chan client
	unregister chan *websocket.Conn
	stop       chan struct{} // closed by Shutdown
	stopOnce   sync.Once     // makes Shutdown safe to call more than once
	mu         sync.RWMutex
	logger     *logging.Logger
}

// New returns a Hub with empty room tables and unbuffered event channels.
// The caller must start Run for the hub to process events.
func New(logger *logging.Logger) *Hub {
	return &Hub{
		rooms:      make(map[string]map[*websocket.Conn]bool),
		connRoom:   make(map[*websocket.Conn]string),
		broadcast:  make(chan broadcastMsg),
		register:   make(chan client),
		unregister: make(chan *websocket.Conn),
		stop:       make(chan struct{}),
		logger:     logger,
	}
}

// Run is the hub's event loop. It processes register, unregister and
// broadcast events one at a time until Shutdown is called, at which point it
// closes every registered connection and returns.
func (h *Hub) Run() {
	for {
		select {
		case <-h.stop:
			h.closeAll()
			return
		case c := <-h.register:
			h.add(c)
		case conn := <-h.unregister:
			h.drop(conn)
		case msg := <-h.broadcast:
			h.deliver(msg)
		}
	}
}

// closeAll sends a going-away close frame to every registered connection,
// closes it, and empties the room tables.
func (h *Hub) closeAll() {
	h.mu.Lock()
	for conn := range h.connRoom {
		// Best effort: the peer may already be gone, and the connection is
		// closed right after regardless of whether the frame was written.
		_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
		conn.Close()
	}
	h.rooms = make(map[string]map[*websocket.Conn]bool)
	h.connRoom = make(map[*websocket.Conn]string)
	h.mu.Unlock()

	metrics.ConnectedClients.Set(0)
	h.logger.Info("Hub stopped, all clients disconnected")
}

// add records a new connection under its room, creating the room on first use.
func (h *Hub) add(c client) {
	h.mu.Lock()
	if h.rooms[c.room] == nil {
		h.rooms[c.room] = make(map[*websocket.Conn]bool)
	}
	h.rooms[c.room][c.conn] = true
	h.connRoom[c.conn] = c.room
	total := len(h.connRoom)
	h.mu.Unlock()

	metrics.ConnectedClients.Inc()
	metrics.ConnectionsTotal.Inc()
	h.logger.Infof("Client connected (room=%q). Total: %d", c.room, total)
}

// drop handles an unregister event. The connection is always closed; metrics
// and the log line are emitted only if the connection was still registered
// (it may already have been removed after a failed write).
func (h *Hub) drop(conn *websocket.Conn) {
	room, total, ok := h.remove(conn)
	conn.Close()
	if !ok {
		return
	}
	metrics.ConnectedClients.Dec()
	metrics.DisconnectionsTotal.Inc()
	h.logger.Infof("Client disconnected (room=%q). Total: %d", room, total)
}

// deliver writes msg to every connection in its room and evicts the
// connections whose write failed.
//
// NOTE(review): payloads are always written as text frames; the frame type
// returned by ReadMessage is discarded by the reader. Confirm binary clients
// are not expected.
func (h *Hub) deliver(msg broadcastMsg) {
	metrics.MessagesTotal.Inc()

	h.mu.RLock()
	var failed []*websocket.Conn
	for conn := range h.rooms[msg.room] { // ranging over a missing room is a no-op
		if err := conn.WriteMessage(websocket.TextMessage, msg.data); err != nil {
			failed = append(failed, conn)
		}
	}
	h.mu.RUnlock()

	// Evict after releasing the read lock: removal needs the write lock.
	for _, conn := range failed {
		room, total, ok := h.remove(conn)
		if !ok {
			continue
		}
		conn.Close()
		metrics.ConnectedClients.Dec()
		metrics.DisconnectionsTotal.Inc()
		h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, total)
	}
}

// remove deletes conn from its room and from the reverse index, dropping the
// room once it is empty. It returns the room the connection was in, the
// number of connections still registered, and whether conn was registered.
func (h *Hub) remove(conn *websocket.Conn) (string, int, bool) {
	h.mu.Lock()
	defer h.mu.Unlock()

	room, ok := h.connRoom[conn]
	if !ok {
		return "", len(h.connRoom), false
	}
	if members, found := h.rooms[room]; found {
		delete(members, conn)
		if len(members) == 0 {
			delete(h.rooms, room)
		}
	}
	delete(h.connRoom, conn)
	return room, len(h.connRoom), true
}

// Shutdown gracefully stops the hub, closing all client connections.
// It is safe to call more than once; only the first call has an effect.
func (h *Hub) Shutdown() {
	h.stopOnce.Do(func() { close(h.stop) })
}

// HandleWebSocket upgrades the request to a WebSocket, registers the
// connection under the room named by the request path, and starts a goroutine
// that forwards every message it reads to that room. If the hub has already
// been shut down the new connection is closed immediately.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Errorf("WebSocket upgrade error: %v", err)
		return
	}

	room := r.URL.Path
	select {
	case h.register <- client{conn: conn, room: room}:
	case <-h.stop:
		// Once Run has returned nothing receives on register; without this
		// case the handler would block forever and leak the connection.
		conn.Close()
		return
	}

	go h.readLoop(conn, room)
}

// readLoop forwards messages read from conn to its room until a read fails or
// the hub is stopped, then asks the hub to unregister the connection.
func (h *Hub) readLoop(conn *websocket.Conn, room string) {
	defer func() {
		select {
		case h.unregister <- conn:
		case <-h.stop:
			// Run closes every registered connection on shutdown, so there is
			// nothing left to do; blocking on unregister here would leak this
			// goroutine because nobody receives after Run returns.
		}
	}()

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			return
		}
		select {
		case h.broadcast <- broadcastMsg{room: room, data: message}:
		case <-h.stop:
			return
		}
	}
}

// ClientCount returns the number of currently registered connections.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return len(h.connRoom)
}

// RoomCount returns the number of active rooms.
func (h *Hub) RoomCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.rooms) }