- Fix metrics: change MessagesTotal, ConnectionsTotal, DisconnectionsTotal from Gauge to Counter with proper _total naming convention - Fix broadcast write-error handling: failed clients now get properly removed with accurate metrics updates - Add graceful shutdown: SIGINT/SIGTERM handling with 10s timeout, CloseGoingAway frame sent to clients before disconnect - Add integration tests: 11 tests using real WebSocket connections covering connect, broadcast, disconnect, concurrency, and shutdown - Fix example client port: changed from 8000 to 8443 to match config - Rewrite README.md to reflect current features and usage - Add AGENTS.md and .agents/summary/ documentation for AI assistants
356 lines
8.8 KiB
Go
356 lines
8.8 KiB
Go
package hub
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// helper: start a test server with a running Hub, return the server and hub
|
|
func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
|
|
t.Helper()
|
|
h := New()
|
|
go h.Run()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
|
return server, h
|
|
}
|
|
|
|
// helper: dial a WebSocket connection to the test server
|
|
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
|
|
t.Helper()
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to dial WebSocket: %v", err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
// helper: wait until hub reaches expected client count or timeout
|
|
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
|
|
t.Helper()
|
|
deadline := time.Now().Add(timeout)
|
|
for time.Now().Before(deadline) {
|
|
if h.ClientCount() == expected {
|
|
return
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
t.Fatalf("Timed out waiting for %d clients, got %d", expected, h.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_SingleClientConnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn := dialWS(t, server)
|
|
defer conn.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MultipleClientsConnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 5
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
if count := h.ClientCount(); count != numClients {
|
|
t.Errorf("Expected %d clients, got %d", numClients, count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_BroadcastMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
// Connect two clients
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send a message from client 1
|
|
testMsg := "hello from client 1"
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Both clients should receive the broadcast
|
|
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_BroadcastToManyClients(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 10
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
// Send from first client
|
|
testMsg := "broadcast to all"
|
|
if err := conns[0].WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// All clients should receive it
|
|
for i, conn := range conns {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d failed to read: %v", i, err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Client %d expected %q, got %q", i, testMsg, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_ClientDisconnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Disconnect client 1
|
|
conn1.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client after disconnect, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MessageAfterDisconnect(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Disconnect client 1
|
|
conn1.Close()
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
// Send a message from client 2 — should still work
|
|
testMsg := "after disconnect"
|
|
if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
|
t.Fatalf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Client 2 should receive its own message back
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != testMsg {
|
|
t.Errorf("Expected %q, got %q", testMsg, string(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_MultipleMessages(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
messages := []string{"first", "second", "third"}
|
|
|
|
for _, msg := range messages {
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
|
|
t.Fatalf("Failed to send %q: %v", msg, err)
|
|
}
|
|
}
|
|
|
|
// Client 2 should receive all messages in order
|
|
for _, expected := range messages {
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != expected {
|
|
t.Errorf("Expected %q, got %q", expected, string(msg))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_ConcurrentSenders(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
const numClients = 5
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := 0; i < numClients; i++ {
|
|
conns[i] = dialWS(t, server)
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
waitForClients(t, h, numClients, time.Second)
|
|
|
|
// Each client sends one message concurrently
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numClients; i++ {
|
|
wg.Add(1)
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
msg := []byte(strings.Repeat("x", idx+1)) // unique length per sender
|
|
conns[idx].WriteMessage(websocket.TextMessage, msg)
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
// Each client should receive exactly numClients messages (one from each sender)
|
|
for i, conn := range conns {
|
|
received := 0
|
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
for received < numClients {
|
|
_, _, err := conn.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Client %d: read error after %d messages: %v", i, received, err)
|
|
}
|
|
received++
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIntegration_GracefulShutdownClosesClients(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
|
|
conn := dialWS(t, server)
|
|
defer conn.Close()
|
|
|
|
waitForClients(t, h, 1, time.Second)
|
|
|
|
// Trigger shutdown
|
|
h.Shutdown()
|
|
|
|
// Client should receive a close frame
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, _, err := conn.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Expected error after shutdown, got nil")
|
|
}
|
|
|
|
// Verify it's a close error with GoingAway code
|
|
if closeErr, ok := err.(*websocket.CloseError); ok {
|
|
if closeErr.Code != websocket.CloseGoingAway {
|
|
t.Errorf("Expected CloseGoingAway (%d), got %d", websocket.CloseGoingAway, closeErr.Code)
|
|
}
|
|
}
|
|
// Any error is acceptable — the key is the connection is no longer usable
|
|
}
|
|
|
|
func TestIntegration_EmptyMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send an empty message
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte("")); err != nil {
|
|
t.Fatalf("Failed to send empty message: %v", err)
|
|
}
|
|
|
|
// Client 2 should receive the empty message
|
|
conn2.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read message: %v", err)
|
|
}
|
|
if string(msg) != "" {
|
|
t.Errorf("Expected empty message, got %q", string(msg))
|
|
}
|
|
}
|
|
|
|
func TestIntegration_LargeMessage(t *testing.T) {
|
|
server, h := setupTestServer(t)
|
|
defer server.Close()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialWS(t, server)
|
|
defer conn1.Close()
|
|
conn2 := dialWS(t, server)
|
|
defer conn2.Close()
|
|
|
|
waitForClients(t, h, 2, time.Second)
|
|
|
|
// Send a 64KB message
|
|
largeMsg := strings.Repeat("A", 64*1024)
|
|
if err := conn1.WriteMessage(websocket.TextMessage, []byte(largeMsg)); err != nil {
|
|
t.Fatalf("Failed to send large message: %v", err)
|
|
}
|
|
|
|
conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, msg, err := conn2.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Failed to read large message: %v", err)
|
|
}
|
|
if len(msg) != 64*1024 {
|
|
t.Errorf("Expected message length %d, got %d", 64*1024, len(msg))
|
|
}
|
|
}
|