// Package hub implements a WebSocket relay: every message read from a
// connected client is re-sent, as a text frame, to all registered clients
// (including the sender).
package hub

import (
	"log"
	"net/http"
	"sync"

	"github.com/gorilla/websocket"

	"websocket-relay/internal/metrics"
)

// upgrader turns incoming HTTP requests into WebSocket connections.
//
// NOTE(security): CheckOrigin accepts every origin, so any web page can open
// a connection to this relay. Restrict it if the relay must not be reachable
// cross-origin.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

// Hub tracks the set of connected clients and fans messages out to them.
//
// All mutations of clients happen on the goroutine executing Run; mu exists
// so that ClientCount can read the map safely from other goroutines.
type Hub struct {
	clients    map[*websocket.Conn]bool
	broadcast  chan []byte
	register   chan *websocket.Conn
	unregister chan *websocket.Conn
	mu         sync.RWMutex
}

// New returns an empty Hub. Its channels are unbuffered, so Run must be
// executing before HandleWebSocket can make progress.
func New() *Hub {
	return &Hub{
		clients:    make(map[*websocket.Conn]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *websocket.Conn),
		unregister: make(chan *websocket.Conn),
	}
}

// Run is the hub's event loop. It processes registrations, unregistrations
// and broadcasts one at a time and never returns, so callers normally start
// it in its own goroutine.
func (h *Hub) Run() {
	for {
		select {
		case conn := <-h.register:
			h.mu.Lock()
			h.clients[conn] = true
			total := len(h.clients)
			h.mu.Unlock()

			metrics.ConnectedClients.Set(float64(total))
			metrics.ConnectionsTotal.Inc()
			log.Printf("Client connected. Total: %d", total)

		case conn := <-h.unregister:
			total := h.remove(conn)

			// Each connection's reader goroutine sends exactly one
			// unregister, so this is counted even when the client was
			// already dropped by a failed broadcast write.
			metrics.ConnectedClients.Set(float64(total))
			metrics.DisconnectionsTotal.Inc()
			log.Printf("Client disconnected. Total: %d", total)

		case message := <-h.broadcast:
			metrics.MessagesTotal.Inc()
			h.deliver(message)
		}
	}
}

// remove drops conn from the client set, closing it if it was still
// registered, and returns the number of clients remaining.
func (h *Hub) remove(conn *websocket.Conn) int {
	h.mu.Lock()
	_, present := h.clients[conn]
	if present {
		delete(h.clients, conn)
	}
	total := len(h.clients)
	h.mu.Unlock()

	if present {
		// Best-effort close, performed outside the lock; the connection is
		// being discarded either way.
		conn.Close()
	}
	return total
}

// deliver writes message as a text frame to every registered client. Clients
// whose write fails are removed from the set and closed.
func (h *Hub) deliver(message []byte) {
	var failed []*websocket.Conn

	// A read lock is sufficient for the fan-out because the map is only
	// read here; failures are recorded and removed afterwards under the
	// write lock rather than deleting while holding RLock.
	h.mu.RLock()
	for conn := range h.clients {
		if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
			failed = append(failed, conn)
		}
	}
	h.mu.RUnlock()

	if len(failed) == 0 {
		return
	}

	h.mu.Lock()
	for _, conn := range failed {
		delete(h.clients, conn)
	}
	total := len(h.clients)
	h.mu.Unlock()

	for _, conn := range failed {
		// Best-effort close; the write already failed on this connection.
		conn.Close()
	}
	metrics.ConnectedClients.Set(float64(total))
}

// HandleWebSocket upgrades the HTTP request to a WebSocket connection,
// registers it with the hub, and starts a goroutine that forwards every
// message read from the connection to the broadcast channel. When a read
// fails the goroutine unregisters the connection and exits.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade error: %v", err)
		return
	}

	h.register <- conn

	go func() {
		defer func() { h.unregister <- conn }()
		for {
			_, message, err := conn.ReadMessage()
			if err != nil {
				break
			}
			h.broadcast <- message
		}
	}()
}

// ClientCount returns the number of currently registered clients. It is safe
// to call from any goroutine.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return len(h.clients)
}