Compare commits

..

No commits in common. "5288282ae73cbb46600246df919126c9d4eaa7d4" and "a40d16dae0bb6256f40f2e04b9819b0c64432244" have entirely different histories.

6 changed files with 63 additions and 1325 deletions

View File

@ -9,68 +9,14 @@
<body> <body>
<style> <style>
* {
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 20px;
}
.room-controls {
margin-bottom: 15px;
display: flex;
align-items: center;
gap: 10px;
}
.room-controls label {
font-weight: bold;
}
.room-controls input {
padding: 6px 10px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 14px;
width: 200px;
}
.room-controls button {
padding: 6px 14px;
border: none;
border-radius: 4px;
background-color: #4a90d9;
color: white;
font-size: 14px;
cursor: pointer;
}
.room-controls button:hover {
background-color: #357abd;
}
.room-status {
font-size: 13px;
color: #555;
padding: 4px 10px;
background-color: #e8f4e8;
border-radius: 4px;
}
.room-status.disconnected {
background-color: #f4e8e8;
}
.chat { .chat {
width: 600px; width: 600px;
height: 500px; height: 500px;
border: 1px solid #ccc; border: 1px solid black;
border-radius: 6px;
overflow-y: scroll; overflow-y: scroll;
background-color: #fafaf5; background-color: beige;
padding: 15px; padding: 15px;
box-sizing: content-box;
} }
.chat .bubble { .chat .bubble {
@ -82,223 +28,58 @@
padding: 5px 10px; padding: 5px 10px;
} }
.chat .system {
width: calc(100% - 30px);
font-size: 11px;
color: #666;
font-style: italic;
margin-bottom: 8px;
padding: 3px 10px;
}
.chat .error { .chat .error {
color: red; color: red;
margin: 10px 0; margin: 10px 0;
font-size: 12px;
}
.message-controls {
margin-top: 10px;
display: flex;
gap: 10px;
width: 600px;
}
.message-controls textarea {
flex: 1;
padding: 8px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 14px;
resize: vertical;
min-height: 40px;
}
.message-controls button {
padding: 8px 18px;
border: none;
border-radius: 4px;
background-color: #4a90d9;
color: white;
font-size: 14px;
cursor: pointer;
align-self: flex-end;
}
.message-controls button:hover {
background-color: #357abd;
}
.message-controls button:disabled {
background-color: #aaa;
cursor: not-allowed;
} }
</style> </style>
<div> <div>
<h1>P2P Chat</h1> <h1>P2P Chat</h1>
<div class="room-controls">
<label for="room-input">Room:</label>
<input type="text" id="room-input" placeholder="/" value="/">
<button id="join-btn">Join Room</button>
<span id="room-status" class="room-status disconnected">Disconnected</span>
</div>
<div class="chat" id="box"></div> <div class="chat" id="box"></div>
<div class="message-controls"> <textarea id="message"></textarea>
<textarea id="message" placeholder="Type a message..."></textarea> <br />
<button id="send" disabled>Send</button> <button id="send" disabled>Send</button>
</div> </div>
</div>
<script> <script>
const chat = document.getElementById("box"); const chat = document.getElementById("box");
const message = document.getElementById("message"); const message = document.getElementById("message");
const sendBtn = document.getElementById("send"); const btn = document.getElementById("send");
const roomInput = document.getElementById("room-input");
const joinBtn = document.getElementById("join-btn");
const roomStatus = document.getElementById("room-status");
const name = Date.now().toString(36); const name = Date.now().toString(36);
let retry = 1000; let retry = 1000;
let ws; let ws;
let currentRoom = "/";
// Derive initial room from URL hash (e.g., index.html#/chat → /chat) function connect() {
function getRoomFromHash() { ws = new WebSocket('ws://localhost:8443/');
const hash = window.location.hash.slice(1); // remove the '#'
if (hash && hash.startsWith("/")) {
return hash;
}
return "/";
}
function updateStatus(connected, room) { ws.onmessage = (event) => {
if (connected) { console.log('Received:', event.data);
roomStatus.textContent = "Connected to " + room; chat.innerHTML += `<div class="bubble">${event.data}</div>`;
roomStatus.classList.remove("disconnected");
} else {
roomStatus.textContent = "Disconnected";
roomStatus.classList.add("disconnected");
}
}
function addSystemMessage(text) {
chat.innerHTML += '<div class="system">' + text + '</div>';
chat.scrollTop = chat.scrollHeight;
}
function connect(room) {
// Close existing connection if any
if (ws) {
ws.onmessage = null;
ws.onopen = null;
ws.onerror = null;
ws.onclose = null;
ws.close();
ws = null;
}
// Normalize room path
if (!room.startsWith("/")) {
room = "/" + room;
}
currentRoom = room;
roomInput.value = room;
// Update URL hash for shareable links
window.location.hash = room;
// Build WebSocket URL
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const host = window.location.hostname || "localhost";
const port = "8443";
const wsUrl = protocol + "//" + host + ":" + port + room;
addSystemMessage("Connecting to room <b>" + room + "</b>...");
sendBtn.setAttribute("disabled", "disabled");
updateStatus(false, room);
ws = new WebSocket(wsUrl);
ws.onmessage = function(event) {
chat.innerHTML += '<div class="bubble">' + event.data + '</div>';
chat.scrollTop = chat.scrollHeight;
}; };
ws.onopen = () => {
ws.onopen = function() { btn.removeAttribute("disabled");
retry = 1000; ws.send(`${name} joined the chat`);
sendBtn.removeAttribute("disabled");
updateStatus(true, room);
addSystemMessage("Joined room <b>" + room + "</b>");
ws.send(name + " joined the chat");
}; };
ws.onerror = (ev) => {
ws.onerror = function(ev) { btn.setAttribute("disabled", "disabled");
sendBtn.setAttribute("disabled", "disabled"); chat.innerHTML += `<div class="error">Failed to connect to websocket, retrying...</div>`;
updateStatus(false, room); console.error(ev);
chat.innerHTML += '<div class="error">Connection error, retrying in ' + (retry / 1000) + 's...</div>'; delete ws.onmessage;
chat.scrollTop = chat.scrollHeight; delete ws.onopen;
console.error("WebSocket error:", ev); setTimeout(connect, retry);
}; retry *= 2;
ws.onclose = function(ev) {
sendBtn.setAttribute("disabled", "disabled");
updateStatus(false, currentRoom);
if (!ev.wasClean) {
setTimeout(function() { connect(currentRoom); }, retry);
retry = Math.min(retry * 2, 30000);
}
}; };
} }
btn.addEventListener("click", (ev) => {
// Join Room button handler
joinBtn.addEventListener("click", function() {
var room = roomInput.value.trim() || "/";
retry = 1000;
chat.innerHTML = "";
connect(room);
});
// Allow Enter key in room input to join
roomInput.addEventListener("keydown", function(ev) {
if (ev.key === "Enter") {
ev.preventDefault(); ev.preventDefault();
joinBtn.click(); const msg = message.value.trim();
}
});
// Send message handler
sendBtn.addEventListener("click", function(ev) {
ev.preventDefault();
var msg = message.value.trim();
if (!msg) { if (!msg) {
return; return;
} }
var data = name + "<br>" + msg; const data = `${name}<br>${msg}`;
ws.send(data); ws.send(data);
message.value = ""; message.value = "";
}); });
connect();
// Allow Ctrl+Enter to send
message.addEventListener("keydown", function(ev) {
if (ev.key === "Enter" && (ev.ctrlKey || ev.metaKey)) {
ev.preventDefault();
sendBtn.click();
}
});
// Listen for hash changes (e.g., user edits URL)
window.addEventListener("hashchange", function() {
var room = getRoomFromHash();
if (room !== currentRoom) {
retry = 1000;
chat.innerHTML = "";
connect(room);
}
});
// Initial connection
var initialRoom = getRoomFromHash();
roomInput.value = initialRoom;
connect(initialRoom);
</script> </script>
</body> </body>

View File

@ -13,23 +13,10 @@ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
} }
// client represents a WebSocket connection associated with a room.
type client struct {
conn *websocket.Conn
room string
}
// broadcastMsg represents a message to be broadcast to all clients in a room.
type broadcastMsg struct {
room string
data []byte
}
type Hub struct { type Hub struct {
rooms map[string]map[*websocket.Conn]bool clients map[*websocket.Conn]bool
connRoom map[*websocket.Conn]string broadcast chan []byte
broadcast chan broadcastMsg register chan *websocket.Conn
register chan client
unregister chan *websocket.Conn unregister chan *websocket.Conn
stop chan struct{} stop chan struct{}
mu sync.RWMutex mu sync.RWMutex
@ -38,10 +25,9 @@ type Hub struct {
func New(logger *logging.Logger) *Hub { func New(logger *logging.Logger) *Hub {
return &Hub{ return &Hub{
rooms: make(map[string]map[*websocket.Conn]bool), clients: make(map[*websocket.Conn]bool),
connRoom: make(map[*websocket.Conn]string), broadcast: make(chan []byte),
broadcast: make(chan broadcastMsg), register: make(chan *websocket.Conn),
register: make(chan client),
unregister: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn),
stop: make(chan struct{}), stop: make(chan struct{}),
logger: logger, logger: logger,
@ -53,87 +39,58 @@ func (h *Hub) Run() {
select { select {
case <-h.stop: case <-h.stop:
h.mu.Lock() h.mu.Lock()
for _, clients := range h.rooms { for conn := range h.clients {
for conn := range clients {
conn.WriteMessage(websocket.CloseMessage, conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down")) websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
conn.Close() conn.Close()
delete(h.clients, conn)
} }
}
h.rooms = make(map[string]map[*websocket.Conn]bool)
h.connRoom = make(map[*websocket.Conn]string)
h.mu.Unlock() h.mu.Unlock()
metrics.ConnectedClients.Set(0) metrics.ConnectedClients.Set(0)
h.logger.Info("Hub stopped, all clients disconnected") h.logger.Info("Hub stopped, all clients disconnected")
return return
case c := <-h.register: case conn := <-h.register:
h.mu.Lock() h.mu.Lock()
if h.rooms[c.room] == nil { h.clients[conn] = true
h.rooms[c.room] = make(map[*websocket.Conn]bool)
}
h.rooms[c.room][c.conn] = true
h.connRoom[c.conn] = c.room
count := len(h.connRoom)
h.mu.Unlock() h.mu.Unlock()
metrics.ConnectedClients.Inc() metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.ConnectionsTotal.Inc() metrics.ConnectionsTotal.Inc()
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, count) h.logger.Infof("Client connected. Total: %d", len(h.clients))
case conn := <-h.unregister: case conn := <-h.unregister:
h.mu.Lock() h.mu.Lock()
if room, ok := h.connRoom[conn]; ok { if _, ok := h.clients[conn]; ok {
if clients, ok := h.rooms[room]; ok { delete(h.clients, conn)
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
count := len(h.connRoom)
h.mu.Unlock()
conn.Close() conn.Close()
metrics.ConnectedClients.Dec() }
h.mu.Unlock()
metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.DisconnectionsTotal.Inc() metrics.DisconnectionsTotal.Inc()
h.logger.Infof("Client disconnected (room=%q). Total: %d", room, count) h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
} else {
h.mu.Unlock()
conn.Close()
}
case message := <-h.broadcast: case message := <-h.broadcast:
metrics.MessagesTotal.Inc() metrics.MessagesTotal.Inc()
h.mu.RLock() h.mu.RLock()
var failed []*websocket.Conn var failed []*websocket.Conn
if clients, ok := h.rooms[message.room]; ok { for conn := range h.clients {
for conn := range clients { if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
failed = append(failed, conn) failed = append(failed, conn)
} }
} }
}
h.mu.RUnlock() h.mu.RUnlock()
// Remove failed clients properly so metrics stay consistent // Remove failed clients properly so metrics stay consistent
for _, conn := range failed { for _, conn := range failed {
h.mu.Lock() h.mu.Lock()
if room, ok := h.connRoom[conn]; ok { if _, ok := h.clients[conn]; ok {
if clients, ok := h.rooms[room]; ok { delete(h.clients, conn)
delete(clients, conn)
if len(clients) == 0 {
delete(h.rooms, room)
}
}
delete(h.connRoom, conn)
count := len(h.connRoom)
h.mu.Unlock()
conn.Close() conn.Close()
metrics.ConnectedClients.Dec() metrics.ConnectedClients.Set(float64(len(h.clients)))
metrics.DisconnectionsTotal.Inc() metrics.DisconnectionsTotal.Inc()
h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, count) h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
} else {
h.mu.Unlock()
} }
h.mu.Unlock()
} }
} }
} }
@ -151,9 +108,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
return return
} }
room := r.URL.Path h.register <- conn
h.register <- client{conn: conn, room: room}
go func() { go func() {
defer func() { defer func() {
@ -165,7 +120,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
break break
} }
h.broadcast <- broadcastMsg{room: room, data: message} h.broadcast <- message
} }
}() }()
} }
@ -173,12 +128,5 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
func (h *Hub) ClientCount() int { func (h *Hub) ClientCount() int {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
return len(h.connRoom) return len(h.clients)
}
// RoomCount returns the number of active rooms.
func (h *Hub) RoomCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.rooms)
} }

View File

@ -1,276 +0,0 @@
package hub
import (
"fmt"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
)
// TestIntegration_PathWithSlashes verifies that clients connected to a nested
// path (e.g., /a/b/c) can communicate within the same room.
func TestIntegration_PathWithSlashes(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect two clients to a nested path
conn1 := dialWSPath(t, server, "/a/b/c")
defer conn1.Close()
conn2 := dialWSPath(t, server, "/a/b/c")
defer conn2.Close()
waitForClients(t, h, 2, time.Second)
// Verify they are in the same room
testMsg := "nested path message"
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message: %v", err)
}
// Both clients should receive it
for i, conn := range []*websocket.Conn{conn1, conn2} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read message: %v", i+1, err)
}
if string(msg) != testMsg {
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
}
}
// Verify a different nested path is a separate room
conn3 := dialWSPath(t, server, "/a/b/d")
defer conn3.Close()
waitForClients(t, h, 3, time.Second)
// Send from /a/b/d
otherMsg := "different path message"
if err := conn3.WriteMessage(websocket.TextMessage, []byte(otherMsg)); err != nil {
t.Fatalf("Failed to send message from /a/b/d: %v", err)
}
// conn3 should get its own message
conn3.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn3.ReadMessage()
if err != nil {
t.Fatalf("Client 3 failed to read own message: %v", err)
}
if string(msg) != otherMsg {
t.Errorf("Client 3 expected %q, got %q", otherMsg, string(msg))
}
// conn1 and conn2 should NOT receive it
for i, conn := range []*websocket.Conn{conn1, conn2} {
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err := conn.ReadMessage()
if err == nil {
t.Errorf("Client %d in /a/b/c should NOT have received message from /a/b/d", i+1)
}
}
}
// TestIntegration_QueryStringIgnored verifies that clients connecting to the
// same path with different query strings are placed in the same room.
func TestIntegration_QueryStringIgnored(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect client 1 to /room?token=abc
conn1 := dialWSPath(t, server, "/room?token=abc")
defer conn1.Close()
// Connect client 2 to /room?token=xyz
conn2 := dialWSPath(t, server, "/room?token=xyz")
defer conn2.Close()
waitForClients(t, h, 2, time.Second)
// They should be in the same room (/room) since query strings are stripped
testMsg := "query string test"
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message: %v", err)
}
// Both clients should receive the message
for i, conn := range []*websocket.Conn{conn1, conn2} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read message: %v", i+1, err)
}
if string(msg) != testMsg {
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
}
}
// Verify they are actually in the same room (only 1 room exists)
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected 1 room (query strings ignored), got %d", count)
}
}
// TestIntegration_DefaultRoom verifies that clients connecting to the bare
// root path (/) are placed in the default room and can communicate.
func TestIntegration_DefaultRoom(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect clients to "/" (bare root)
conn1 := dialWSPath(t, server, "/")
defer conn1.Close()
conn2 := dialWSPath(t, server, "/")
defer conn2.Close()
waitForClients(t, h, 2, time.Second)
// Verify broadcast works
testMsg := "default room message"
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message: %v", err)
}
for i, conn := range []*websocket.Conn{conn1, conn2} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read message: %v", i+1, err)
}
if string(msg) != testMsg {
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
}
}
// Verify they are in the same room
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected 1 room for default root path, got %d", count)
}
}
// TestIntegration_ClientDisconnectFromRoom verifies that when one client
// disconnects from a room, the remaining clients can still communicate.
func TestIntegration_ClientDisconnectFromRoom(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect 3 clients to the same room
conn1 := dialWSPath(t, server, "/game")
defer conn1.Close()
conn2 := dialWSPath(t, server, "/game")
defer conn2.Close()
conn3 := dialWSPath(t, server, "/game")
defer conn3.Close()
waitForClients(t, h, 3, time.Second)
// Disconnect client 1
conn1.Close()
waitForClients(t, h, 2, time.Second)
// Remaining 2 clients should still communicate
testMsg := "still alive"
if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message from client 2: %v", err)
}
// conn2 and conn3 should receive the message
for i, conn := range []*websocket.Conn{conn2, conn3} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read message after disconnect: %v", i+2, err)
}
if string(msg) != testMsg {
t.Errorf("Client %d expected %q, got %q", i+2, testMsg, string(msg))
}
}
// Verify the room still exists
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected room to still exist with 2 clients, got %d rooms", count)
}
// Send from client 3 to confirm bidirectional communication
replyMsg := "reply from client 3"
if err := conn3.WriteMessage(websocket.TextMessage, []byte(replyMsg)); err != nil {
t.Fatalf("Failed to send reply from client 3: %v", err)
}
for i, conn := range []*websocket.Conn{conn2, conn3} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read reply: %v", i+2, err)
}
if string(msg) != replyMsg {
t.Errorf("Client %d expected %q, got %q", i+2, replyMsg, string(msg))
}
}
}
// TestIntegration_ConcurrentRoomOperations verifies that rapidly connecting and
// disconnecting clients across multiple rooms concurrently causes no data races.
func TestIntegration_ConcurrentRoomOperations(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
const numRooms = 10
const clientsPerRoom = 5
var wg sync.WaitGroup
// Rapidly connect and disconnect clients across many rooms concurrently
for i := 0; i < numRooms; i++ {
wg.Add(1)
go func(roomIdx int) {
defer wg.Done()
room := fmt.Sprintf("/concurrent-room-%d", roomIdx)
conns := make([]*websocket.Conn, 0, clientsPerRoom)
// Connect all clients in this room
for j := 0; j < clientsPerRoom; j++ {
conn := dialWSPath(t, server, room)
conns = append(conns, conn)
}
// Small delay to allow registration to process
time.Sleep(50 * time.Millisecond)
// Send a message from the first client
if len(conns) > 0 {
msg := fmt.Sprintf("msg from room %d", roomIdx)
conns[0].WriteMessage(websocket.TextMessage, []byte(msg))
}
// Small delay then disconnect all
time.Sleep(50 * time.Millisecond)
for _, conn := range conns {
conn.Close()
}
}(i)
}
wg.Wait()
// Wait for all cleanups to complete
waitForClients(t, h, 0, 5*time.Second)
waitForRooms(t, h, 0, 5*time.Second)
// Verify clean state: no rooms, no clients
if count := h.ClientCount(); count != 0 {
t.Errorf("Expected 0 clients after concurrent operations, got %d", count)
}
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms after concurrent operations, got %d", count)
}
}

View File

@ -24,7 +24,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
return server, h return server, h
} }
// helper: dial a WebSocket connection to the test server (default room) // helper: dial a WebSocket connection to the test server
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn { func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
t.Helper() t.Helper()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
@ -35,17 +35,6 @@ func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
return conn return conn
} }
// helper: dial a WebSocket connection to a specific room
func dialWSWithRoom(t *testing.T, server *httptest.Server, room string) *websocket.Conn {
t.Helper()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/" + room
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Failed to dial WebSocket (room=%q): %v", room, err)
}
return conn
}
// helper: wait until hub reaches expected client count or timeout // helper: wait until hub reaches expected client count or timeout
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) { func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
t.Helper() t.Helper()
@ -367,206 +356,3 @@ func TestIntegration_LargeMessage(t *testing.T) {
t.Errorf("Expected message length %d, got %d", 64*1024, len(msg)) t.Errorf("Expected message length %d, got %d", 64*1024, len(msg))
} }
} }
func TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect 2 clients to room-a, 1 client to room-b
connA1 := dialWSWithRoom(t, server, "room-a")
defer connA1.Close()
connA2 := dialWSWithRoom(t, server, "room-a")
defer connA2.Close()
connB1 := dialWSWithRoom(t, server, "room-b")
defer connB1.Close()
waitForClients(t, h, 3, time.Second)
// Send a message from client A1
testMsg := "hello room-a"
if err := connA1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message: %v", err)
}
// Both room-a clients should receive the message
for i, conn := range []*websocket.Conn{connA1, connA2} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Room-a client %d failed to read message: %v", i+1, err)
}
if string(msg) != testMsg {
t.Errorf("Room-a client %d expected %q, got %q", i+1, testMsg, string(msg))
}
}
// Room-b client should NOT receive the message
connB1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err := connB1.ReadMessage()
if err == nil {
t.Fatal("Room-b client should NOT have received a message from room-a")
}
// Timeout error is expected — message was correctly not delivered
}
func TestIntegration_RoomIsolation_MultipleRoomsIndependent(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect clients to two different rooms
connA := dialWSWithRoom(t, server, "alpha")
defer connA.Close()
connB := dialWSWithRoom(t, server, "beta")
defer connB.Close()
waitForClients(t, h, 2, time.Second)
// Send message from room alpha
msgAlpha := "alpha message"
if err := connA.WriteMessage(websocket.TextMessage, []byte(msgAlpha)); err != nil {
t.Fatalf("Failed to send alpha message: %v", err)
}
// Room alpha client receives its own message
connA.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := connA.ReadMessage()
if err != nil {
t.Fatalf("Alpha client failed to read: %v", err)
}
if string(msg) != msgAlpha {
t.Errorf("Alpha client expected %q, got %q", msgAlpha, string(msg))
}
// Send message from room beta
msgBeta := "beta message"
if err := connB.WriteMessage(websocket.TextMessage, []byte(msgBeta)); err != nil {
t.Fatalf("Failed to send beta message: %v", err)
}
// Room beta client receives its own message
connB.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err = connB.ReadMessage()
if err != nil {
t.Fatalf("Beta client failed to read: %v", err)
}
if string(msg) != msgBeta {
t.Errorf("Beta client expected %q, got %q", msgBeta, string(msg))
}
// Verify alpha did NOT receive beta's message
connA.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err = connA.ReadMessage()
if err == nil {
t.Fatal("Alpha client should NOT have received beta's message")
}
// Verify beta did NOT receive alpha's message
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err = connB.ReadMessage()
if err == nil {
t.Fatal("Beta client should NOT have received alpha's message")
}
}
func TestIntegration_BroadcastToEmptyRoom(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect a client to room-a and a separate client to verify hub is functional
connA := dialWSWithRoom(t, server, "room-a")
defer connA.Close()
waitForClients(t, h, 1, time.Second)
// Send directly to broadcast channel targeting a non-existent room.
// This should be handled gracefully (no panic, no delivery).
h.broadcast <- broadcastMsg{room: "/non-existent", data: []byte("ghost message")}
// Give the hub time to process
time.Sleep(50 * time.Millisecond)
// Now send a real message to room-a to confirm hub is still functional
h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("real message")}
connA.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := connA.ReadMessage()
if err != nil {
t.Fatalf("Room-a client failed to read real message after empty-room broadcast: %v", err)
}
if string(msg) != "real message" {
t.Errorf("Expected %q, got %q", "real message", string(msg))
}
}
func TestIntegration_GracefulShutdownMultiRoom(t *testing.T) {
logger := logging.NewLogger("debug", &bytes.Buffer{})
h := New(logger)
go h.Run()
s := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
defer s.Close()
// Connect clients to different rooms
rooms := []string{"room-a", "room-b", "room-c"}
conns := make([]*websocket.Conn, 0, len(rooms))
for _, room := range rooms {
ws := dialWSWithRoom(t, s, room)
conns = append(conns, ws)
}
waitForClients(t, h, 3, time.Second)
// Verify all clients are connected to separate rooms
h.mu.RLock()
roomCount := len(h.rooms)
connCount := len(h.connRoom)
h.mu.RUnlock()
if roomCount != 3 {
t.Fatalf("expected 3 rooms, got %d", roomCount)
}
if connCount != 3 {
t.Fatalf("expected 3 connections, got %d", connCount)
}
// Shutdown the hub — all clients should receive close frame
h.Shutdown()
var wg sync.WaitGroup
for i, ws := range conns {
wg.Add(1)
go func(idx int, c *websocket.Conn) {
defer wg.Done()
c.SetReadDeadline(time.Now().Add(2 * time.Second))
_, _, err := c.ReadMessage()
if err == nil {
t.Errorf("client %d: expected error (close frame), got nil", idx)
return
}
closeErr, ok := err.(*websocket.CloseError)
if !ok {
t.Errorf("client %d: expected CloseError, got %T: %v", idx, err, err)
return
}
if closeErr.Code != websocket.CloseGoingAway {
t.Errorf("client %d: expected CloseGoingAway (%d), got %d",
idx, websocket.CloseGoingAway, closeErr.Code)
}
}(i, ws)
}
wg.Wait()
// Verify maps are cleared
h.mu.RLock()
roomsAfter := len(h.rooms)
connsAfter := len(h.connRoom)
h.mu.RUnlock()
if roomsAfter != 0 {
t.Errorf("expected 0 rooms after shutdown, got %d", roomsAfter)
}
if connsAfter != 0 {
t.Errorf("expected 0 connections after shutdown, got %d", connsAfter)
}
}

View File

@ -1,293 +0,0 @@
package hub
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"websocket-relay/internal/logging"
)
// dialWSPath dials a WebSocket connection to the test server at the given path.
// The path should include a leading slash (e.g., "/chat", "/room-a").
func dialWSPath(t *testing.T, server *httptest.Server, path string) *websocket.Conn {
t.Helper()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + path
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Failed to dial WebSocket (path=%q): %v", path, err)
}
return conn
}
// helper: wait until hub reaches expected room count or timeout
func waitForRooms(t *testing.T, h *Hub, expected int, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if h.RoomCount() == expected {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("Timed out waiting for %d rooms, got %d", expected, h.RoomCount())
}
func TestIntegration_SameRoomBroadcast(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect 2 clients to /chat
conn1 := dialWSPath(t, server, "/chat")
defer conn1.Close()
conn2 := dialWSPath(t, server, "/chat")
defer conn2.Close()
waitForClients(t, h, 2, time.Second)
// Client 1 sends message
testMsg := "hello chat room"
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message: %v", err)
}
// Both clients should receive it
for i, conn := range []*websocket.Conn{conn1, conn2} {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Client %d failed to read message: %v", i+1, err)
}
if string(msg) != testMsg {
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
}
}
}
func TestIntegration_CrossRoomIsolation(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect client A to /room-a, client B to /room-b
connA := dialWSPath(t, server, "/room-a")
defer connA.Close()
connB := dialWSPath(t, server, "/room-b")
defer connB.Close()
waitForClients(t, h, 2, time.Second)
// Client A sends message
testMsg := "message from room-a"
if err := connA.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
t.Fatalf("Failed to send message from client A: %v", err)
}
// Client A receives it (echo to self within room)
connA.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := connA.ReadMessage()
if err != nil {
t.Fatalf("Client A failed to read own message: %v", err)
}
if string(msg) != testMsg {
t.Errorf("Client A expected %q, got %q", testMsg, string(msg))
}
// Client B does NOT receive it (verify with read deadline timeout)
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err = connB.ReadMessage()
if err == nil {
t.Fatal("Client B should NOT have received a message from room-a")
}
// Timeout error is expected — message was correctly isolated
}
func TestIntegration_MultipleRoomsSimultaneous(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// 3 rooms with 2 clients each
type roomClients struct {
path string
conns []*websocket.Conn
}
rooms := []roomClients{
{path: "/room-1"},
{path: "/room-2"},
{path: "/room-3"},
}
for i := range rooms {
for j := 0; j < 2; j++ {
conn := dialWSPath(t, server, rooms[i].path)
defer conn.Close()
rooms[i].conns = append(rooms[i].conns, conn)
}
}
waitForClients(t, h, 6, time.Second)
waitForRooms(t, h, 3, time.Second)
// Send a message in each room from the first client
for i, room := range rooms {
msg := fmt.Sprintf("message for %s", room.path)
if err := room.conns[0].WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
t.Fatalf("Room %d: failed to send message: %v", i+1, err)
}
}
// Verify each room's clients receive only their room's message
for i, room := range rooms {
expectedMsg := fmt.Sprintf("message for %s", room.path)
for j, conn := range room.conns {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Room %d, client %d: failed to read message: %v", i+1, j+1, err)
}
if string(msg) != expectedMsg {
t.Errorf("Room %d, client %d: expected %q, got %q", i+1, j+1, expectedMsg, string(msg))
}
}
}
// Verify cross-room isolation: send another message from room-1 and confirm
// room-2 and room-3 clients don't receive it
isolationMsg := "room-1 only"
if err := rooms[0].conns[0].WriteMessage(websocket.TextMessage, []byte(isolationMsg)); err != nil {
t.Fatalf("Failed to send isolation test message: %v", err)
}
// Room-1 clients should get it
for j, conn := range rooms[0].conns {
conn.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("Room-1, client %d: failed to read isolation message: %v", j+1, err)
}
if string(msg) != isolationMsg {
t.Errorf("Room-1, client %d: expected %q, got %q", j+1, isolationMsg, string(msg))
}
}
// Room-2 and room-3 clients should NOT get it
for i := 1; i < len(rooms); i++ {
for j, conn := range rooms[i].conns {
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err := conn.ReadMessage()
if err == nil {
t.Errorf("Room %d, client %d: should NOT have received room-1's message", i+1, j+1)
}
}
}
}
func TestIntegration_RoomCleanup(t *testing.T) {
logger := logging.NewLogger("debug", &bytes.Buffer{})
h := New(logger)
go h.Run()
defer h.Shutdown()
server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
defer server.Close()
// Record initial room count
initialRooms := h.RoomCount()
// Connect client to /temp
conn := dialWSPath(t, server, "/temp")
waitForClients(t, h, 1, time.Second)
waitForRooms(t, h, initialRooms+1, time.Second)
// Verify RoomCount increased
if count := h.RoomCount(); count != initialRooms+1 {
t.Errorf("Expected room count to be %d, got %d", initialRooms+1, count)
}
// Disconnect the client
conn.Close()
// Wait for cleanup
waitForClients(t, h, 0, time.Second)
waitForRooms(t, h, initialRooms, time.Second)
// Verify RoomCount decreased (room removed)
if count := h.RoomCount(); count != initialRooms {
t.Errorf("Expected room count to return to %d after disconnect, got %d", initialRooms, count)
}
}
func TestIntegration_RoomCleanup_MultipleClients(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
// Connect two clients to the same room
conn1 := dialWSPath(t, server, "/shared")
conn2 := dialWSPath(t, server, "/shared")
waitForClients(t, h, 2, time.Second)
waitForRooms(t, h, 1, time.Second)
// Disconnect first client — room should still exist
conn1.Close()
waitForClients(t, h, 1, time.Second)
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected room to still exist with 1 client remaining, got %d rooms", count)
}
// Disconnect second client — room should be cleaned up
conn2.Close()
waitForClients(t, h, 0, time.Second)
waitForRooms(t, h, 0, time.Second)
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected room to be removed after all clients disconnect, got %d rooms", count)
}
}
func TestIntegration_RoomCleanup_ConcurrentDisconnects(t *testing.T) {
server, h := setupTestServer(t)
defer server.Close()
defer h.Shutdown()
const numRooms = 5
conns := make([]*websocket.Conn, numRooms)
for i := 0; i < numRooms; i++ {
path := fmt.Sprintf("/room-%d", i)
conns[i] = dialWSPath(t, server, path)
}
waitForClients(t, h, numRooms, time.Second)
waitForRooms(t, h, numRooms, time.Second)
// Disconnect all clients concurrently
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(c *websocket.Conn) {
defer wg.Done()
c.Close()
}(conn)
}
wg.Wait()
// Wait for all cleanup to complete
waitForClients(t, h, 0, 2*time.Second)
waitForRooms(t, h, 0, 2*time.Second)
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms after all disconnects, got %d", count)
}
}

View File

@ -2,13 +2,9 @@ package hub
import ( import (
"bytes" "bytes"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/gorilla/websocket"
"websocket-relay/internal/logging" "websocket-relay/internal/logging"
) )
@ -16,32 +12,13 @@ func newTestLogger() *logging.Logger {
return logging.NewLogger("debug", &bytes.Buffer{}) return logging.NewLogger("debug", &bytes.Buffer{})
} }
// dialTestHub starts an httptest server for the given hub and dials a
// WebSocket connection to it with the given room path.
// Returns the client-side connection and a cleanup function.
func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
t.Cleanup(srv.Close)
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/" + room
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to dial WebSocket: %v", err)
}
return conn
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
h := New(newTestLogger()) h := New(newTestLogger())
if h == nil { if h == nil {
t.Fatal("New returned nil") t.Fatal("New returned nil")
} }
if h.rooms == nil { if h.clients == nil {
t.Error("rooms map not initialized") t.Error("clients map not initialized")
}
if h.connRoom == nil {
t.Error("connRoom map not initialized")
} }
if h.broadcast == nil { if h.broadcast == nil {
t.Error("broadcast channel not initialized") t.Error("broadcast channel not initialized")
@ -61,23 +38,13 @@ func TestClientCount(t *testing.T) {
} }
} }
func TestRoomCount(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms, got %d", count)
}
}
func TestBroadcastChannel(t *testing.T) { func TestBroadcastChannel(t *testing.T) {
h := New(newTestLogger()) h := New(newTestLogger())
go h.Run() go h.Run()
defer h.Shutdown() defer h.Shutdown()
select { select {
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}: case h.broadcast <- []byte("test"):
// Channel is working // Channel is working
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
t.Error("broadcast channel blocked") t.Error("broadcast channel blocked")
@ -105,178 +72,3 @@ func TestShutdown(t *testing.T) {
t.Fatal("Hub.Run() did not return after Shutdown") t.Fatal("Hub.Run() did not return after Shutdown")
} }
} }
func TestRegisterClient(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
conn := dialTestHub(t, h, "test-room")
defer conn.Close()
// Allow register to be processed
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 1 {
t.Errorf("Expected 1 client, got %d", count)
}
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected 1 room, got %d", count)
}
}
func TestUnregisterClient(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
conn := dialTestHub(t, h, "test-room")
// Allow register to be processed
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 1 {
t.Errorf("Expected 1 client after register, got %d", count)
}
// Close the client-side connection to trigger unregister
conn.Close()
// Allow unregister to be processed
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 0 {
t.Errorf("Expected 0 clients after unregister, got %d", count)
}
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected 0 rooms after last client leaves, got %d", count)
}
}
func TestRegisterMultipleRooms(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
conn1 := dialTestHub(t, h, "room-a")
defer conn1.Close()
conn2 := dialTestHub(t, h, "room-a")
defer conn2.Close()
conn3 := dialTestHub(t, h, "room-b")
defer conn3.Close()
// Allow all registers to be processed
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 3 {
t.Errorf("Expected 3 clients, got %d", count)
}
if count := h.RoomCount(); count != 2 {
t.Errorf("Expected 2 rooms, got %d", count)
}
}
func TestUnregisterCleansUpEmptyRoom(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
conn1 := dialTestHub(t, h, "shared-room")
conn2 := dialTestHub(t, h, "shared-room")
// Allow registers to be processed
time.Sleep(50 * time.Millisecond)
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected 1 room, got %d", count)
}
// Remove first client — room should still exist
conn1.Close()
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 1 {
t.Errorf("Expected 1 client after first disconnect, got %d", count)
}
if count := h.RoomCount(); count != 1 {
t.Errorf("Expected room to still exist with 1 client, got %d rooms", count)
}
// Remove second client — room should be cleaned up
conn2.Close()
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 0 {
t.Errorf("Expected 0 clients after both disconnect, got %d", count)
}
if count := h.RoomCount(); count != 0 {
t.Errorf("Expected room to be cleaned up, got %d rooms", count)
}
}
func TestUnregisterUnknownConnNoPanic(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
// Create a raw WebSocket connection that is NOT registered with the hub
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
// Send directly to unregister without ever registering
h.unregister <- conn
}))
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer conn.Close()
// Allow unregister to be processed — should not panic
time.Sleep(50 * time.Millisecond)
if count := h.ClientCount(); count != 0 {
t.Errorf("Expected 0 clients, got %d", count)
}
}
func TestBroadcastRoomIsolation(t *testing.T) {
h := New(newTestLogger())
go h.Run()
defer h.Shutdown()
connA := dialTestHub(t, h, "room-a")
defer connA.Close()
connB := dialTestHub(t, h, "room-b")
defer connB.Close()
// Allow registers to be processed
time.Sleep(50 * time.Millisecond)
// Send message to room-a via broadcast channel (room includes leading slash from URL path)
h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("for-a-only")}
// Room-a client should receive the message
connA.SetReadDeadline(time.Now().Add(time.Second))
_, msg, err := connA.ReadMessage()
if err != nil {
t.Fatalf("Room-a client failed to read: %v", err)
}
if string(msg) != "for-a-only" {
t.Errorf("Expected %q, got %q", "for-a-only", string(msg))
}
// Room-b client should NOT receive the message (short timeout)
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_, _, err = connB.ReadMessage()
if err == nil {
t.Fatal("Room-b client should NOT have received room-a's message")
}
}