Add tests verifying that the broadcast case in Hub.Run() correctly
sends messages only to clients in the same room as the sender:
- TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom: verifies
messages from room-a are received by room-a clients and NOT by
room-b clients
- TestIntegration_RoomIsolation_MultipleRoomsIndependent: verifies
two rooms operate independently with no message leakage
- TestIntegration_BroadcastToEmptyRoom: verifies graceful handling
when broadcasting to a non-existent room (no panic, hub remains
functional)
- TestBroadcastRoomIsolation: unit-level room isolation test using
the broadcast channel directly
Also adds dialWSWithRoom helper for room-aware WebSocket connections
in integration tests.
🤖 Assisted by the code-assist SOP
283 lines
6.7 KiB
Go
283 lines
6.7 KiB
Go
package hub
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"websocket-relay/internal/logging"
|
|
)
|
|
|
|
func newTestLogger() *logging.Logger {
|
|
return logging.NewLogger("debug", &bytes.Buffer{})
|
|
}
|
|
|
|
// dialTestHub starts an httptest server for the given hub and dials a
|
|
// WebSocket connection to it with the given room query parameter.
|
|
// Returns the client-side connection and a cleanup function.
|
|
func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn {
|
|
t.Helper()
|
|
srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
|
t.Cleanup(srv.Close)
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "?room=" + room
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to dial WebSocket: %v", err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
func TestNew(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
if h == nil {
|
|
t.Fatal("New returned nil")
|
|
}
|
|
if h.rooms == nil {
|
|
t.Error("rooms map not initialized")
|
|
}
|
|
if h.connRoom == nil {
|
|
t.Error("connRoom map not initialized")
|
|
}
|
|
if h.broadcast == nil {
|
|
t.Error("broadcast channel not initialized")
|
|
}
|
|
if h.stop == nil {
|
|
t.Error("stop channel not initialized")
|
|
}
|
|
}
|
|
|
|
func TestClientCount(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
if count := h.ClientCount(); count != 0 {
|
|
t.Errorf("Expected 0 clients, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestRoomCount(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
if count := h.RoomCount(); count != 0 {
|
|
t.Errorf("Expected 0 rooms, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestBroadcastChannel(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
select {
|
|
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
|
|
// Channel is working
|
|
case <-time.After(100 * time.Millisecond):
|
|
t.Error("broadcast channel blocked")
|
|
}
|
|
}
|
|
|
|
func TestShutdown(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
h.Run()
|
|
close(done)
|
|
}()
|
|
|
|
// Ensure Run is processing before shutdown
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
h.Shutdown()
|
|
|
|
select {
|
|
case <-done:
|
|
// Hub.Run() returned successfully
|
|
case <-time.After(1 * time.Second):
|
|
t.Fatal("Hub.Run() did not return after Shutdown")
|
|
}
|
|
}
|
|
|
|
func TestRegisterClient(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
conn := dialTestHub(t, h, "test-room")
|
|
defer conn.Close()
|
|
|
|
// Allow register to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client, got %d", count)
|
|
}
|
|
if count := h.RoomCount(); count != 1 {
|
|
t.Errorf("Expected 1 room, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestUnregisterClient(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
conn := dialTestHub(t, h, "test-room")
|
|
|
|
// Allow register to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client after register, got %d", count)
|
|
}
|
|
|
|
// Close the client-side connection to trigger unregister
|
|
conn.Close()
|
|
|
|
// Allow unregister to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 0 {
|
|
t.Errorf("Expected 0 clients after unregister, got %d", count)
|
|
}
|
|
if count := h.RoomCount(); count != 0 {
|
|
t.Errorf("Expected 0 rooms after last client leaves, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestRegisterMultipleRooms(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialTestHub(t, h, "room-a")
|
|
defer conn1.Close()
|
|
conn2 := dialTestHub(t, h, "room-a")
|
|
defer conn2.Close()
|
|
conn3 := dialTestHub(t, h, "room-b")
|
|
defer conn3.Close()
|
|
|
|
// Allow all registers to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 3 {
|
|
t.Errorf("Expected 3 clients, got %d", count)
|
|
}
|
|
if count := h.RoomCount(); count != 2 {
|
|
t.Errorf("Expected 2 rooms, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestUnregisterCleansUpEmptyRoom(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
conn1 := dialTestHub(t, h, "shared-room")
|
|
conn2 := dialTestHub(t, h, "shared-room")
|
|
|
|
// Allow registers to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.RoomCount(); count != 1 {
|
|
t.Errorf("Expected 1 room, got %d", count)
|
|
}
|
|
|
|
// Remove first client — room should still exist
|
|
conn1.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 1 {
|
|
t.Errorf("Expected 1 client after first disconnect, got %d", count)
|
|
}
|
|
if count := h.RoomCount(); count != 1 {
|
|
t.Errorf("Expected room to still exist with 1 client, got %d rooms", count)
|
|
}
|
|
|
|
// Remove second client — room should be cleaned up
|
|
conn2.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 0 {
|
|
t.Errorf("Expected 0 clients after both disconnect, got %d", count)
|
|
}
|
|
if count := h.RoomCount(); count != 0 {
|
|
t.Errorf("Expected room to be cleaned up, got %d rooms", count)
|
|
}
|
|
}
|
|
|
|
func TestUnregisterUnknownConnNoPanic(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
// Create a raw WebSocket connection that is NOT registered with the hub
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
// Send directly to unregister without ever registering
|
|
h.unregister <- conn
|
|
}))
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to dial: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Allow unregister to be processed — should not panic
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count := h.ClientCount(); count != 0 {
|
|
t.Errorf("Expected 0 clients, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestBroadcastRoomIsolation(t *testing.T) {
|
|
h := New(newTestLogger())
|
|
go h.Run()
|
|
defer h.Shutdown()
|
|
|
|
connA := dialTestHub(t, h, "room-a")
|
|
defer connA.Close()
|
|
connB := dialTestHub(t, h, "room-b")
|
|
defer connB.Close()
|
|
|
|
// Allow registers to be processed
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send message to room-a via broadcast channel
|
|
h.broadcast <- broadcastMsg{room: "room-a", data: []byte("for-a-only")}
|
|
|
|
// Room-a client should receive the message
|
|
connA.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, msg, err := connA.ReadMessage()
|
|
if err != nil {
|
|
t.Fatalf("Room-a client failed to read: %v", err)
|
|
}
|
|
if string(msg) != "for-a-only" {
|
|
t.Errorf("Expected %q, got %q", "for-a-only", string(msg))
|
|
}
|
|
|
|
// Room-b client should NOT receive the message (short timeout)
|
|
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
|
_, _, err = connB.ReadMessage()
|
|
if err == nil {
|
|
t.Fatal("Room-b client should NOT have received room-a's message")
|
|
}
|
|
}
|
|
|