Compare commits
8 Commits
a40d16dae0
...
5288282ae7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5288282ae7 | ||
|
|
18063fb3ef | ||
|
|
516f8c5008 | ||
|
|
c226bdab7f | ||
|
|
5bd08409dc | ||
|
|
48d47dfc92 | ||
|
|
8eaba398dc | ||
|
|
03f379c73c |
@ -9,14 +9,68 @@
|
||||
|
||||
<body>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 20px;
|
||||
}
|
||||
|
||||
.room-controls {
|
||||
margin-bottom: 15px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.room-controls label {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.room-controls input {
|
||||
padding: 6px 10px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
width: 200px;
|
||||
}
|
||||
|
||||
.room-controls button {
|
||||
padding: 6px 14px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
background-color: #4a90d9;
|
||||
color: white;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.room-controls button:hover {
|
||||
background-color: #357abd;
|
||||
}
|
||||
|
||||
.room-status {
|
||||
font-size: 13px;
|
||||
color: #555;
|
||||
padding: 4px 10px;
|
||||
background-color: #e8f4e8;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.room-status.disconnected {
|
||||
background-color: #f4e8e8;
|
||||
}
|
||||
|
||||
.chat {
|
||||
width: 600px;
|
||||
height: 500px;
|
||||
border: 1px solid black;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 6px;
|
||||
overflow-y: scroll;
|
||||
background-color: beige;
|
||||
background-color: #fafaf5;
|
||||
padding: 15px;
|
||||
box-sizing: content-box;
|
||||
}
|
||||
|
||||
.chat .bubble {
|
||||
@ -28,58 +82,223 @@
|
||||
padding: 5px 10px;
|
||||
}
|
||||
|
||||
.chat .system {
|
||||
width: calc(100% - 30px);
|
||||
font-size: 11px;
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
margin-bottom: 8px;
|
||||
padding: 3px 10px;
|
||||
}
|
||||
|
||||
.chat .error {
|
||||
color: red;
|
||||
margin: 10px 0;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.message-controls {
|
||||
margin-top: 10px;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
width: 600px;
|
||||
}
|
||||
|
||||
.message-controls textarea {
|
||||
flex: 1;
|
||||
padding: 8px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
resize: vertical;
|
||||
min-height: 40px;
|
||||
}
|
||||
|
||||
.message-controls button {
|
||||
padding: 8px 18px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
background-color: #4a90d9;
|
||||
color: white;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
align-self: flex-end;
|
||||
}
|
||||
|
||||
.message-controls button:hover {
|
||||
background-color: #357abd;
|
||||
}
|
||||
|
||||
.message-controls button:disabled {
|
||||
background-color: #aaa;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
</style>
|
||||
<div>
|
||||
<h1>P2P Chat</h1>
|
||||
<div class="room-controls">
|
||||
<label for="room-input">Room:</label>
|
||||
<input type="text" id="room-input" placeholder="/" value="/">
|
||||
<button id="join-btn">Join Room</button>
|
||||
<span id="room-status" class="room-status disconnected">Disconnected</span>
|
||||
</div>
|
||||
<div class="chat" id="box"></div>
|
||||
<textarea id="message"></textarea>
|
||||
<br />
|
||||
<button id="send" disabled>Send</button>
|
||||
<div class="message-controls">
|
||||
<textarea id="message" placeholder="Type a message..."></textarea>
|
||||
<button id="send" disabled>Send</button>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
const chat = document.getElementById("box");
|
||||
const message = document.getElementById("message");
|
||||
const btn = document.getElementById("send");
|
||||
const sendBtn = document.getElementById("send");
|
||||
const roomInput = document.getElementById("room-input");
|
||||
const joinBtn = document.getElementById("join-btn");
|
||||
const roomStatus = document.getElementById("room-status");
|
||||
const name = Date.now().toString(36);
|
||||
|
||||
let retry = 1000;
|
||||
let ws;
|
||||
let currentRoom = "/";
|
||||
|
||||
function connect() {
|
||||
ws = new WebSocket('ws://localhost:8443/');
|
||||
// Derive initial room from URL hash (e.g., index.html#/chat → /chat)
|
||||
function getRoomFromHash() {
|
||||
const hash = window.location.hash.slice(1); // remove the '#'
|
||||
if (hash && hash.startsWith("/")) {
|
||||
return hash;
|
||||
}
|
||||
return "/";
|
||||
}
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
console.log('Received:', event.data);
|
||||
chat.innerHTML += `<div class="bubble">${event.data}</div>`;
|
||||
function updateStatus(connected, room) {
|
||||
if (connected) {
|
||||
roomStatus.textContent = "Connected to " + room;
|
||||
roomStatus.classList.remove("disconnected");
|
||||
} else {
|
||||
roomStatus.textContent = "Disconnected";
|
||||
roomStatus.classList.add("disconnected");
|
||||
}
|
||||
}
|
||||
|
||||
function addSystemMessage(text) {
|
||||
chat.innerHTML += '<div class="system">' + text + '</div>';
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
}
|
||||
|
||||
function connect(room) {
|
||||
// Close existing connection if any
|
||||
if (ws) {
|
||||
ws.onmessage = null;
|
||||
ws.onopen = null;
|
||||
ws.onerror = null;
|
||||
ws.onclose = null;
|
||||
ws.close();
|
||||
ws = null;
|
||||
}
|
||||
|
||||
// Normalize room path
|
||||
if (!room.startsWith("/")) {
|
||||
room = "/" + room;
|
||||
}
|
||||
currentRoom = room;
|
||||
roomInput.value = room;
|
||||
|
||||
// Update URL hash for shareable links
|
||||
window.location.hash = room;
|
||||
|
||||
// Build WebSocket URL
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const host = window.location.hostname || "localhost";
|
||||
const port = "8443";
|
||||
const wsUrl = protocol + "//" + host + ":" + port + room;
|
||||
|
||||
addSystemMessage("Connecting to room <b>" + room + "</b>...");
|
||||
sendBtn.setAttribute("disabled", "disabled");
|
||||
updateStatus(false, room);
|
||||
|
||||
ws = new WebSocket(wsUrl);
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
chat.innerHTML += '<div class="bubble">' + event.data + '</div>';
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
};
|
||||
ws.onopen = () => {
|
||||
btn.removeAttribute("disabled");
|
||||
ws.send(`${name} joined the chat`);
|
||||
|
||||
ws.onopen = function() {
|
||||
retry = 1000;
|
||||
sendBtn.removeAttribute("disabled");
|
||||
updateStatus(true, room);
|
||||
addSystemMessage("Joined room <b>" + room + "</b>");
|
||||
ws.send(name + " joined the chat");
|
||||
};
|
||||
ws.onerror = (ev) => {
|
||||
btn.setAttribute("disabled", "disabled");
|
||||
chat.innerHTML += `<div class="error">Failed to connect to websocket, retrying...</div>`;
|
||||
console.error(ev);
|
||||
delete ws.onmessage;
|
||||
delete ws.onopen;
|
||||
setTimeout(connect, retry);
|
||||
retry *= 2;
|
||||
|
||||
ws.onerror = function(ev) {
|
||||
sendBtn.setAttribute("disabled", "disabled");
|
||||
updateStatus(false, room);
|
||||
chat.innerHTML += '<div class="error">Connection error, retrying in ' + (retry / 1000) + 's...</div>';
|
||||
chat.scrollTop = chat.scrollHeight;
|
||||
console.error("WebSocket error:", ev);
|
||||
};
|
||||
|
||||
ws.onclose = function(ev) {
|
||||
sendBtn.setAttribute("disabled", "disabled");
|
||||
updateStatus(false, currentRoom);
|
||||
if (!ev.wasClean) {
|
||||
setTimeout(function() { connect(currentRoom); }, retry);
|
||||
retry = Math.min(retry * 2, 30000);
|
||||
}
|
||||
};
|
||||
}
|
||||
btn.addEventListener("click", (ev) => {
|
||||
|
||||
// Join Room button handler
|
||||
joinBtn.addEventListener("click", function() {
|
||||
var room = roomInput.value.trim() || "/";
|
||||
retry = 1000;
|
||||
chat.innerHTML = "";
|
||||
connect(room);
|
||||
});
|
||||
|
||||
// Allow Enter key in room input to join
|
||||
roomInput.addEventListener("keydown", function(ev) {
|
||||
if (ev.key === "Enter") {
|
||||
ev.preventDefault();
|
||||
joinBtn.click();
|
||||
}
|
||||
});
|
||||
|
||||
// Send message handler
|
||||
sendBtn.addEventListener("click", function(ev) {
|
||||
ev.preventDefault();
|
||||
const msg = message.value.trim();
|
||||
var msg = message.value.trim();
|
||||
if (!msg) {
|
||||
return;
|
||||
}
|
||||
const data = `${name}<br>${msg}`;
|
||||
var data = name + "<br>" + msg;
|
||||
ws.send(data);
|
||||
message.value = "";
|
||||
});
|
||||
connect();
|
||||
|
||||
// Allow Ctrl+Enter to send
|
||||
message.addEventListener("keydown", function(ev) {
|
||||
if (ev.key === "Enter" && (ev.ctrlKey || ev.metaKey)) {
|
||||
ev.preventDefault();
|
||||
sendBtn.click();
|
||||
}
|
||||
});
|
||||
|
||||
// Listen for hash changes (e.g., user edits URL)
|
||||
window.addEventListener("hashchange", function() {
|
||||
var room = getRoomFromHash();
|
||||
if (room !== currentRoom) {
|
||||
retry = 1000;
|
||||
chat.innerHTML = "";
|
||||
connect(room);
|
||||
}
|
||||
});
|
||||
|
||||
// Initial connection
|
||||
var initialRoom = getRoomFromHash();
|
||||
roomInput.value = initialRoom;
|
||||
connect(initialRoom);
|
||||
</script>
|
||||
</body>
|
||||
|
||||
|
||||
@ -13,10 +13,23 @@ var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// client represents a WebSocket connection associated with a room.
|
||||
type client struct {
|
||||
conn *websocket.Conn
|
||||
room string
|
||||
}
|
||||
|
||||
// broadcastMsg represents a message to be broadcast to all clients in a room.
|
||||
type broadcastMsg struct {
|
||||
room string
|
||||
data []byte
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
clients map[*websocket.Conn]bool
|
||||
broadcast chan []byte
|
||||
register chan *websocket.Conn
|
||||
rooms map[string]map[*websocket.Conn]bool
|
||||
connRoom map[*websocket.Conn]string
|
||||
broadcast chan broadcastMsg
|
||||
register chan client
|
||||
unregister chan *websocket.Conn
|
||||
stop chan struct{}
|
||||
mu sync.RWMutex
|
||||
@ -25,9 +38,10 @@ type Hub struct {
|
||||
|
||||
func New(logger *logging.Logger) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcast: make(chan []byte),
|
||||
register: make(chan *websocket.Conn),
|
||||
rooms: make(map[string]map[*websocket.Conn]bool),
|
||||
connRoom: make(map[*websocket.Conn]string),
|
||||
broadcast: make(chan broadcastMsg),
|
||||
register: make(chan client),
|
||||
unregister: make(chan *websocket.Conn),
|
||||
stop: make(chan struct{}),
|
||||
logger: logger,
|
||||
@ -39,43 +53,63 @@ func (h *Hub) Run() {
|
||||
select {
|
||||
case <-h.stop:
|
||||
h.mu.Lock()
|
||||
for conn := range h.clients {
|
||||
conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
||||
conn.Close()
|
||||
delete(h.clients, conn)
|
||||
for _, clients := range h.rooms {
|
||||
for conn := range clients {
|
||||
conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"))
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
h.rooms = make(map[string]map[*websocket.Conn]bool)
|
||||
h.connRoom = make(map[*websocket.Conn]string)
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(0)
|
||||
h.logger.Info("Hub stopped, all clients disconnected")
|
||||
return
|
||||
|
||||
case conn := <-h.register:
|
||||
case c := <-h.register:
|
||||
h.mu.Lock()
|
||||
h.clients[conn] = true
|
||||
if h.rooms[c.room] == nil {
|
||||
h.rooms[c.room] = make(map[*websocket.Conn]bool)
|
||||
}
|
||||
h.rooms[c.room][c.conn] = true
|
||||
h.connRoom[c.conn] = c.room
|
||||
count := len(h.connRoom)
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectedClients.Inc()
|
||||
metrics.ConnectionsTotal.Inc()
|
||||
h.logger.Infof("Client connected. Total: %d", len(h.clients))
|
||||
h.logger.Infof("Client connected (room=%q). Total: %d", c.room, count)
|
||||
|
||||
case conn := <-h.unregister:
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[conn]; ok {
|
||||
delete(h.clients, conn)
|
||||
if room, ok := h.connRoom[conn]; ok {
|
||||
if clients, ok := h.rooms[room]; ok {
|
||||
delete(clients, conn)
|
||||
if len(clients) == 0 {
|
||||
delete(h.rooms, room)
|
||||
}
|
||||
}
|
||||
delete(h.connRoom, conn)
|
||||
count := len(h.connRoom)
|
||||
h.mu.Unlock()
|
||||
conn.Close()
|
||||
metrics.ConnectedClients.Dec()
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
h.logger.Infof("Client disconnected (room=%q). Total: %d", room, count)
|
||||
} else {
|
||||
h.mu.Unlock()
|
||||
conn.Close()
|
||||
}
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
|
||||
|
||||
case message := <-h.broadcast:
|
||||
metrics.MessagesTotal.Inc()
|
||||
h.mu.RLock()
|
||||
var failed []*websocket.Conn
|
||||
for conn := range h.clients {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
failed = append(failed, conn)
|
||||
if clients, ok := h.rooms[message.room]; ok {
|
||||
for conn := range clients {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, message.data); err != nil {
|
||||
failed = append(failed, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
@ -83,14 +117,23 @@ func (h *Hub) Run() {
|
||||
// Remove failed clients properly so metrics stay consistent
|
||||
for _, conn := range failed {
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[conn]; ok {
|
||||
delete(h.clients, conn)
|
||||
if room, ok := h.connRoom[conn]; ok {
|
||||
if clients, ok := h.rooms[room]; ok {
|
||||
delete(clients, conn)
|
||||
if len(clients) == 0 {
|
||||
delete(h.rooms, room)
|
||||
}
|
||||
}
|
||||
delete(h.connRoom, conn)
|
||||
count := len(h.connRoom)
|
||||
h.mu.Unlock()
|
||||
conn.Close()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectedClients.Dec()
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
|
||||
h.logger.Warnf("Client disconnected (write error, room=%q). Total: %d", room, count)
|
||||
} else {
|
||||
h.mu.Unlock()
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -108,7 +151,9 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
h.register <- conn
|
||||
room := r.URL.Path
|
||||
|
||||
h.register <- client{conn: conn, room: room}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
@ -120,7 +165,7 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
h.broadcast <- message
|
||||
h.broadcast <- broadcastMsg{room: room, data: message}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -128,5 +173,12 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Hub) ClientCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.clients)
|
||||
return len(h.connRoom)
|
||||
}
|
||||
|
||||
// RoomCount returns the number of active rooms.
|
||||
func (h *Hub) RoomCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.rooms)
|
||||
}
|
||||
|
||||
276
internal/hub/hub_edge_cases_test.go
Normal file
276
internal/hub/hub_edge_cases_test.go
Normal file
@ -0,0 +1,276 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// TestIntegration_PathWithSlashes verifies that clients connected to a nested
|
||||
// path (e.g., /a/b/c) can communicate within the same room.
|
||||
func TestIntegration_PathWithSlashes(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect two clients to a nested path
|
||||
conn1 := dialWSPath(t, server, "/a/b/c")
|
||||
defer conn1.Close()
|
||||
conn2 := dialWSPath(t, server, "/a/b/c")
|
||||
defer conn2.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Verify they are in the same room
|
||||
testMsg := "nested path message"
|
||||
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
// Both clients should receive it
|
||||
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify a different nested path is a separate room
|
||||
conn3 := dialWSPath(t, server, "/a/b/d")
|
||||
defer conn3.Close()
|
||||
|
||||
waitForClients(t, h, 3, time.Second)
|
||||
|
||||
// Send from /a/b/d
|
||||
otherMsg := "different path message"
|
||||
if err := conn3.WriteMessage(websocket.TextMessage, []byte(otherMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message from /a/b/d: %v", err)
|
||||
}
|
||||
|
||||
// conn3 should get its own message
|
||||
conn3.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn3.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client 3 failed to read own message: %v", err)
|
||||
}
|
||||
if string(msg) != otherMsg {
|
||||
t.Errorf("Client 3 expected %q, got %q", otherMsg, string(msg))
|
||||
}
|
||||
|
||||
// conn1 and conn2 should NOT receive it
|
||||
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
||||
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Errorf("Client %d in /a/b/c should NOT have received message from /a/b/d", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_QueryStringIgnored verifies that clients connecting to the
|
||||
// same path with different query strings are placed in the same room.
|
||||
func TestIntegration_QueryStringIgnored(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect client 1 to /room?token=abc
|
||||
conn1 := dialWSPath(t, server, "/room?token=abc")
|
||||
defer conn1.Close()
|
||||
|
||||
// Connect client 2 to /room?token=xyz
|
||||
conn2 := dialWSPath(t, server, "/room?token=xyz")
|
||||
defer conn2.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// They should be in the same room (/room) since query strings are stripped
|
||||
testMsg := "query string test"
|
||||
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
// Both clients should receive the message
|
||||
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify they are actually in the same room (only 1 room exists)
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected 1 room (query strings ignored), got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_DefaultRoom verifies that clients connecting to the bare
|
||||
// root path (/) are placed in the default room and can communicate.
|
||||
func TestIntegration_DefaultRoom(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect clients to "/" (bare root)
|
||||
conn1 := dialWSPath(t, server, "/")
|
||||
defer conn1.Close()
|
||||
conn2 := dialWSPath(t, server, "/")
|
||||
defer conn2.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Verify broadcast works
|
||||
testMsg := "default room message"
|
||||
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify they are in the same room
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected 1 room for default root path, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ClientDisconnectFromRoom verifies that when one client
|
||||
// disconnects from a room, the remaining clients can still communicate.
|
||||
func TestIntegration_ClientDisconnectFromRoom(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect 3 clients to the same room
|
||||
conn1 := dialWSPath(t, server, "/game")
|
||||
defer conn1.Close()
|
||||
conn2 := dialWSPath(t, server, "/game")
|
||||
defer conn2.Close()
|
||||
conn3 := dialWSPath(t, server, "/game")
|
||||
defer conn3.Close()
|
||||
|
||||
waitForClients(t, h, 3, time.Second)
|
||||
|
||||
// Disconnect client 1
|
||||
conn1.Close()
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Remaining 2 clients should still communicate
|
||||
testMsg := "still alive"
|
||||
if err := conn2.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message from client 2: %v", err)
|
||||
}
|
||||
|
||||
// conn2 and conn3 should receive the message
|
||||
for i, conn := range []*websocket.Conn{conn2, conn3} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read message after disconnect: %v", i+2, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+2, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the room still exists
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected room to still exist with 2 clients, got %d rooms", count)
|
||||
}
|
||||
|
||||
// Send from client 3 to confirm bidirectional communication
|
||||
replyMsg := "reply from client 3"
|
||||
if err := conn3.WriteMessage(websocket.TextMessage, []byte(replyMsg)); err != nil {
|
||||
t.Fatalf("Failed to send reply from client 3: %v", err)
|
||||
}
|
||||
|
||||
for i, conn := range []*websocket.Conn{conn2, conn3} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read reply: %v", i+2, err)
|
||||
}
|
||||
if string(msg) != replyMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+2, replyMsg, string(msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ConcurrentRoomOperations verifies that rapidly connecting and
|
||||
// disconnecting clients across multiple rooms concurrently causes no data races.
|
||||
func TestIntegration_ConcurrentRoomOperations(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
const numRooms = 10
|
||||
const clientsPerRoom = 5
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Rapidly connect and disconnect clients across many rooms concurrently
|
||||
for i := 0; i < numRooms; i++ {
|
||||
wg.Add(1)
|
||||
go func(roomIdx int) {
|
||||
defer wg.Done()
|
||||
room := fmt.Sprintf("/concurrent-room-%d", roomIdx)
|
||||
|
||||
conns := make([]*websocket.Conn, 0, clientsPerRoom)
|
||||
|
||||
// Connect all clients in this room
|
||||
for j := 0; j < clientsPerRoom; j++ {
|
||||
conn := dialWSPath(t, server, room)
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Small delay to allow registration to process
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Send a message from the first client
|
||||
if len(conns) > 0 {
|
||||
msg := fmt.Sprintf("msg from room %d", roomIdx)
|
||||
conns[0].WriteMessage(websocket.TextMessage, []byte(msg))
|
||||
}
|
||||
|
||||
// Small delay then disconnect all
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Wait for all cleanups to complete
|
||||
waitForClients(t, h, 0, 5*time.Second)
|
||||
waitForRooms(t, h, 0, 5*time.Second)
|
||||
|
||||
// Verify clean state: no rooms, no clients
|
||||
if count := h.ClientCount(); count != 0 {
|
||||
t.Errorf("Expected 0 clients after concurrent operations, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected 0 rooms after concurrent operations, got %d", count)
|
||||
}
|
||||
}
|
||||
@ -24,7 +24,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
|
||||
return server, h
|
||||
}
|
||||
|
||||
// helper: dial a WebSocket connection to the test server
|
||||
// helper: dial a WebSocket connection to the test server (default room)
|
||||
func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
|
||||
t.Helper()
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
@ -35,6 +35,17 @@ func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
// helper: dial a WebSocket connection to a specific room
|
||||
func dialWSWithRoom(t *testing.T, server *httptest.Server, room string) *websocket.Conn {
|
||||
t.Helper()
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/" + room
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial WebSocket (room=%q): %v", room, err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// helper: wait until hub reaches expected client count or timeout
|
||||
func waitForClients(t *testing.T, h *Hub, expected int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
@ -356,3 +367,206 @@ func TestIntegration_LargeMessage(t *testing.T) {
|
||||
t.Errorf("Expected message length %d, got %d", 64*1024, len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RoomIsolation_MessagesOnlyGoToSameRoom(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect 2 clients to room-a, 1 client to room-b
|
||||
connA1 := dialWSWithRoom(t, server, "room-a")
|
||||
defer connA1.Close()
|
||||
connA2 := dialWSWithRoom(t, server, "room-a")
|
||||
defer connA2.Close()
|
||||
connB1 := dialWSWithRoom(t, server, "room-b")
|
||||
defer connB1.Close()
|
||||
|
||||
waitForClients(t, h, 3, time.Second)
|
||||
|
||||
// Send a message from client A1
|
||||
testMsg := "hello room-a"
|
||||
if err := connA1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
// Both room-a clients should receive the message
|
||||
for i, conn := range []*websocket.Conn{connA1, connA2} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Room-a client %d failed to read message: %v", i+1, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Room-a client %d expected %q, got %q", i+1, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Room-b client should NOT receive the message
|
||||
connB1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err := connB1.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatal("Room-b client should NOT have received a message from room-a")
|
||||
}
|
||||
// Timeout error is expected — message was correctly not delivered
|
||||
}
|
||||
|
||||
func TestIntegration_RoomIsolation_MultipleRoomsIndependent(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect clients to two different rooms
|
||||
connA := dialWSWithRoom(t, server, "alpha")
|
||||
defer connA.Close()
|
||||
connB := dialWSWithRoom(t, server, "beta")
|
||||
defer connB.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Send message from room alpha
|
||||
msgAlpha := "alpha message"
|
||||
if err := connA.WriteMessage(websocket.TextMessage, []byte(msgAlpha)); err != nil {
|
||||
t.Fatalf("Failed to send alpha message: %v", err)
|
||||
}
|
||||
|
||||
// Room alpha client receives its own message
|
||||
connA.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := connA.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Alpha client failed to read: %v", err)
|
||||
}
|
||||
if string(msg) != msgAlpha {
|
||||
t.Errorf("Alpha client expected %q, got %q", msgAlpha, string(msg))
|
||||
}
|
||||
|
||||
// Send message from room beta
|
||||
msgBeta := "beta message"
|
||||
if err := connB.WriteMessage(websocket.TextMessage, []byte(msgBeta)); err != nil {
|
||||
t.Fatalf("Failed to send beta message: %v", err)
|
||||
}
|
||||
|
||||
// Room beta client receives its own message
|
||||
connB.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err = connB.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Beta client failed to read: %v", err)
|
||||
}
|
||||
if string(msg) != msgBeta {
|
||||
t.Errorf("Beta client expected %q, got %q", msgBeta, string(msg))
|
||||
}
|
||||
|
||||
// Verify alpha did NOT receive beta's message
|
||||
connA.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err = connA.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatal("Alpha client should NOT have received beta's message")
|
||||
}
|
||||
|
||||
// Verify beta did NOT receive alpha's message
|
||||
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err = connB.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatal("Beta client should NOT have received alpha's message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BroadcastToEmptyRoom(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect a client to room-a and a separate client to verify hub is functional
|
||||
connA := dialWSWithRoom(t, server, "room-a")
|
||||
defer connA.Close()
|
||||
|
||||
waitForClients(t, h, 1, time.Second)
|
||||
|
||||
// Send directly to broadcast channel targeting a non-existent room.
|
||||
// This should be handled gracefully (no panic, no delivery).
|
||||
h.broadcast <- broadcastMsg{room: "/non-existent", data: []byte("ghost message")}
|
||||
|
||||
// Give the hub time to process
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Now send a real message to room-a to confirm hub is still functional
|
||||
h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("real message")}
|
||||
|
||||
connA.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := connA.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Room-a client failed to read real message after empty-room broadcast: %v", err)
|
||||
}
|
||||
if string(msg) != "real message" {
|
||||
t.Errorf("Expected %q, got %q", "real message", string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_GracefulShutdownMultiRoom(t *testing.T) {
|
||||
logger := logging.NewLogger("debug", &bytes.Buffer{})
|
||||
h := New(logger)
|
||||
go h.Run()
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
||||
defer s.Close()
|
||||
|
||||
// Connect clients to different rooms
|
||||
rooms := []string{"room-a", "room-b", "room-c"}
|
||||
conns := make([]*websocket.Conn, 0, len(rooms))
|
||||
for _, room := range rooms {
|
||||
ws := dialWSWithRoom(t, s, room)
|
||||
conns = append(conns, ws)
|
||||
}
|
||||
|
||||
waitForClients(t, h, 3, time.Second)
|
||||
|
||||
// Verify all clients are connected to separate rooms
|
||||
h.mu.RLock()
|
||||
roomCount := len(h.rooms)
|
||||
connCount := len(h.connRoom)
|
||||
h.mu.RUnlock()
|
||||
if roomCount != 3 {
|
||||
t.Fatalf("expected 3 rooms, got %d", roomCount)
|
||||
}
|
||||
if connCount != 3 {
|
||||
t.Fatalf("expected 3 connections, got %d", connCount)
|
||||
}
|
||||
|
||||
// Shutdown the hub — all clients should receive close frame
|
||||
h.Shutdown()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, ws := range conns {
|
||||
wg.Add(1)
|
||||
go func(idx int, c *websocket.Conn) {
|
||||
defer wg.Done()
|
||||
c.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, _, err := c.ReadMessage()
|
||||
if err == nil {
|
||||
t.Errorf("client %d: expected error (close frame), got nil", idx)
|
||||
return
|
||||
}
|
||||
closeErr, ok := err.(*websocket.CloseError)
|
||||
if !ok {
|
||||
t.Errorf("client %d: expected CloseError, got %T: %v", idx, err, err)
|
||||
return
|
||||
}
|
||||
if closeErr.Code != websocket.CloseGoingAway {
|
||||
t.Errorf("client %d: expected CloseGoingAway (%d), got %d",
|
||||
idx, websocket.CloseGoingAway, closeErr.Code)
|
||||
}
|
||||
}(i, ws)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify maps are cleared
|
||||
h.mu.RLock()
|
||||
roomsAfter := len(h.rooms)
|
||||
connsAfter := len(h.connRoom)
|
||||
h.mu.RUnlock()
|
||||
if roomsAfter != 0 {
|
||||
t.Errorf("expected 0 rooms after shutdown, got %d", roomsAfter)
|
||||
}
|
||||
if connsAfter != 0 {
|
||||
t.Errorf("expected 0 connections after shutdown, got %d", connsAfter)
|
||||
}
|
||||
}
|
||||
|
||||
293
internal/hub/hub_room_isolation_test.go
Normal file
293
internal/hub/hub_room_isolation_test.go
Normal file
@ -0,0 +1,293 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"websocket-relay/internal/logging"
|
||||
)
|
||||
|
||||
// dialWSPath dials a WebSocket connection to the test server at the given path.
|
||||
// The path should include a leading slash (e.g., "/chat", "/room-a").
|
||||
func dialWSPath(t *testing.T, server *httptest.Server, path string) *websocket.Conn {
|
||||
t.Helper()
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + path
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial WebSocket (path=%q): %v", path, err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// helper: wait until hub reaches expected room count or timeout
|
||||
func waitForRooms(t *testing.T, h *Hub, expected int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if h.RoomCount() == expected {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("Timed out waiting for %d rooms, got %d", expected, h.RoomCount())
|
||||
}
|
||||
|
||||
func TestIntegration_SameRoomBroadcast(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect 2 clients to /chat
|
||||
conn1 := dialWSPath(t, server, "/chat")
|
||||
defer conn1.Close()
|
||||
conn2 := dialWSPath(t, server, "/chat")
|
||||
defer conn2.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Client 1 sends message
|
||||
testMsg := "hello chat room"
|
||||
if err := conn1.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
// Both clients should receive it
|
||||
for i, conn := range []*websocket.Conn{conn1, conn2} {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to read message: %v", i+1, err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client %d expected %q, got %q", i+1, testMsg, string(msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_CrossRoomIsolation(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect client A to /room-a, client B to /room-b
|
||||
connA := dialWSPath(t, server, "/room-a")
|
||||
defer connA.Close()
|
||||
connB := dialWSPath(t, server, "/room-b")
|
||||
defer connB.Close()
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
|
||||
// Client A sends message
|
||||
testMsg := "message from room-a"
|
||||
if err := connA.WriteMessage(websocket.TextMessage, []byte(testMsg)); err != nil {
|
||||
t.Fatalf("Failed to send message from client A: %v", err)
|
||||
}
|
||||
|
||||
// Client A receives it (echo to self within room)
|
||||
connA.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := connA.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Client A failed to read own message: %v", err)
|
||||
}
|
||||
if string(msg) != testMsg {
|
||||
t.Errorf("Client A expected %q, got %q", testMsg, string(msg))
|
||||
}
|
||||
|
||||
// Client B does NOT receive it (verify with read deadline timeout)
|
||||
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err = connB.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatal("Client B should NOT have received a message from room-a")
|
||||
}
|
||||
// Timeout error is expected — message was correctly isolated
|
||||
}
|
||||
|
||||
func TestIntegration_MultipleRoomsSimultaneous(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// 3 rooms with 2 clients each
|
||||
type roomClients struct {
|
||||
path string
|
||||
conns []*websocket.Conn
|
||||
}
|
||||
|
||||
rooms := []roomClients{
|
||||
{path: "/room-1"},
|
||||
{path: "/room-2"},
|
||||
{path: "/room-3"},
|
||||
}
|
||||
|
||||
for i := range rooms {
|
||||
for j := 0; j < 2; j++ {
|
||||
conn := dialWSPath(t, server, rooms[i].path)
|
||||
defer conn.Close()
|
||||
rooms[i].conns = append(rooms[i].conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
waitForClients(t, h, 6, time.Second)
|
||||
waitForRooms(t, h, 3, time.Second)
|
||||
|
||||
// Send a message in each room from the first client
|
||||
for i, room := range rooms {
|
||||
msg := fmt.Sprintf("message for %s", room.path)
|
||||
if err := room.conns[0].WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
|
||||
t.Fatalf("Room %d: failed to send message: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify each room's clients receive only their room's message
|
||||
for i, room := range rooms {
|
||||
expectedMsg := fmt.Sprintf("message for %s", room.path)
|
||||
for j, conn := range room.conns {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Room %d, client %d: failed to read message: %v", i+1, j+1, err)
|
||||
}
|
||||
if string(msg) != expectedMsg {
|
||||
t.Errorf("Room %d, client %d: expected %q, got %q", i+1, j+1, expectedMsg, string(msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cross-room isolation: send another message from room-1 and confirm
|
||||
// room-2 and room-3 clients don't receive it
|
||||
isolationMsg := "room-1 only"
|
||||
if err := rooms[0].conns[0].WriteMessage(websocket.TextMessage, []byte(isolationMsg)); err != nil {
|
||||
t.Fatalf("Failed to send isolation test message: %v", err)
|
||||
}
|
||||
|
||||
// Room-1 clients should get it
|
||||
for j, conn := range rooms[0].conns {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Room-1, client %d: failed to read isolation message: %v", j+1, err)
|
||||
}
|
||||
if string(msg) != isolationMsg {
|
||||
t.Errorf("Room-1, client %d: expected %q, got %q", j+1, isolationMsg, string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// Room-2 and room-3 clients should NOT get it
|
||||
for i := 1; i < len(rooms); i++ {
|
||||
for j, conn := range rooms[i].conns {
|
||||
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Errorf("Room %d, client %d: should NOT have received room-1's message", i+1, j+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RoomCleanup(t *testing.T) {
|
||||
logger := logging.NewLogger("debug", &bytes.Buffer{})
|
||||
h := New(logger)
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
||||
defer server.Close()
|
||||
|
||||
// Record initial room count
|
||||
initialRooms := h.RoomCount()
|
||||
|
||||
// Connect client to /temp
|
||||
conn := dialWSPath(t, server, "/temp")
|
||||
|
||||
waitForClients(t, h, 1, time.Second)
|
||||
waitForRooms(t, h, initialRooms+1, time.Second)
|
||||
|
||||
// Verify RoomCount increased
|
||||
if count := h.RoomCount(); count != initialRooms+1 {
|
||||
t.Errorf("Expected room count to be %d, got %d", initialRooms+1, count)
|
||||
}
|
||||
|
||||
// Disconnect the client
|
||||
conn.Close()
|
||||
|
||||
// Wait for cleanup
|
||||
waitForClients(t, h, 0, time.Second)
|
||||
waitForRooms(t, h, initialRooms, time.Second)
|
||||
|
||||
// Verify RoomCount decreased (room removed)
|
||||
if count := h.RoomCount(); count != initialRooms {
|
||||
t.Errorf("Expected room count to return to %d after disconnect, got %d", initialRooms, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RoomCleanup_MultipleClients(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Connect two clients to the same room
|
||||
conn1 := dialWSPath(t, server, "/shared")
|
||||
conn2 := dialWSPath(t, server, "/shared")
|
||||
|
||||
waitForClients(t, h, 2, time.Second)
|
||||
waitForRooms(t, h, 1, time.Second)
|
||||
|
||||
// Disconnect first client — room should still exist
|
||||
conn1.Close()
|
||||
waitForClients(t, h, 1, time.Second)
|
||||
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected room to still exist with 1 client remaining, got %d rooms", count)
|
||||
}
|
||||
|
||||
// Disconnect second client — room should be cleaned up
|
||||
conn2.Close()
|
||||
waitForClients(t, h, 0, time.Second)
|
||||
waitForRooms(t, h, 0, time.Second)
|
||||
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected room to be removed after all clients disconnect, got %d rooms", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RoomCleanup_ConcurrentDisconnects(t *testing.T) {
|
||||
server, h := setupTestServer(t)
|
||||
defer server.Close()
|
||||
defer h.Shutdown()
|
||||
|
||||
const numRooms = 5
|
||||
conns := make([]*websocket.Conn, numRooms)
|
||||
for i := 0; i < numRooms; i++ {
|
||||
path := fmt.Sprintf("/room-%d", i)
|
||||
conns[i] = dialWSPath(t, server, path)
|
||||
}
|
||||
|
||||
waitForClients(t, h, numRooms, time.Second)
|
||||
waitForRooms(t, h, numRooms, time.Second)
|
||||
|
||||
// Disconnect all clients concurrently
|
||||
var wg sync.WaitGroup
|
||||
for _, conn := range conns {
|
||||
wg.Add(1)
|
||||
go func(c *websocket.Conn) {
|
||||
defer wg.Done()
|
||||
c.Close()
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Wait for all cleanup to complete
|
||||
waitForClients(t, h, 0, 2*time.Second)
|
||||
waitForRooms(t, h, 0, 2*time.Second)
|
||||
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected 0 rooms after all disconnects, got %d", count)
|
||||
}
|
||||
}
|
||||
@ -2,9 +2,13 @@ package hub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"websocket-relay/internal/logging"
|
||||
)
|
||||
|
||||
@ -12,13 +16,32 @@ func newTestLogger() *logging.Logger {
|
||||
return logging.NewLogger("debug", &bytes.Buffer{})
|
||||
}
|
||||
|
||||
// dialTestHub starts an httptest server for the given hub and dials a
|
||||
// WebSocket connection to it with the given room path.
|
||||
// Returns the client-side connection and a cleanup function.
|
||||
func dialTestHub(t *testing.T, h *Hub, room string) *websocket.Conn {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/" + room
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial WebSocket: %v", err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
if h == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
if h.clients == nil {
|
||||
t.Error("clients map not initialized")
|
||||
if h.rooms == nil {
|
||||
t.Error("rooms map not initialized")
|
||||
}
|
||||
if h.connRoom == nil {
|
||||
t.Error("connRoom map not initialized")
|
||||
}
|
||||
if h.broadcast == nil {
|
||||
t.Error("broadcast channel not initialized")
|
||||
@ -38,13 +61,23 @@ func TestClientCount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoomCount(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected 0 rooms, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastChannel(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
select {
|
||||
case h.broadcast <- []byte("test"):
|
||||
case h.broadcast <- broadcastMsg{room: "", data: []byte("test")}:
|
||||
// Channel is working
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("broadcast channel blocked")
|
||||
@ -72,3 +105,178 @@ func TestShutdown(t *testing.T) {
|
||||
t.Fatal("Hub.Run() did not return after Shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterClient(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
conn := dialTestHub(t, h, "test-room")
|
||||
defer conn.Close()
|
||||
|
||||
// Allow register to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 1 {
|
||||
t.Errorf("Expected 1 client, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected 1 room, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregisterClient(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
conn := dialTestHub(t, h, "test-room")
|
||||
|
||||
// Allow register to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 1 {
|
||||
t.Errorf("Expected 1 client after register, got %d", count)
|
||||
}
|
||||
|
||||
// Close the client-side connection to trigger unregister
|
||||
conn.Close()
|
||||
|
||||
// Allow unregister to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 0 {
|
||||
t.Errorf("Expected 0 clients after unregister, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected 0 rooms after last client leaves, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMultipleRooms(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
conn1 := dialTestHub(t, h, "room-a")
|
||||
defer conn1.Close()
|
||||
conn2 := dialTestHub(t, h, "room-a")
|
||||
defer conn2.Close()
|
||||
conn3 := dialTestHub(t, h, "room-b")
|
||||
defer conn3.Close()
|
||||
|
||||
// Allow all registers to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 3 {
|
||||
t.Errorf("Expected 3 clients, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 2 {
|
||||
t.Errorf("Expected 2 rooms, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregisterCleansUpEmptyRoom(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
conn1 := dialTestHub(t, h, "shared-room")
|
||||
conn2 := dialTestHub(t, h, "shared-room")
|
||||
|
||||
// Allow registers to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected 1 room, got %d", count)
|
||||
}
|
||||
|
||||
// Remove first client — room should still exist
|
||||
conn1.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 1 {
|
||||
t.Errorf("Expected 1 client after first disconnect, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 1 {
|
||||
t.Errorf("Expected room to still exist with 1 client, got %d rooms", count)
|
||||
}
|
||||
|
||||
// Remove second client — room should be cleaned up
|
||||
conn2.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 0 {
|
||||
t.Errorf("Expected 0 clients after both disconnect, got %d", count)
|
||||
}
|
||||
if count := h.RoomCount(); count != 0 {
|
||||
t.Errorf("Expected room to be cleaned up, got %d rooms", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregisterUnknownConnNoPanic(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
// Create a raw WebSocket connection that is NOT registered with the hub
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Send directly to unregister without ever registering
|
||||
h.unregister <- conn
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Allow unregister to be processed — should not panic
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count := h.ClientCount(); count != 0 {
|
||||
t.Errorf("Expected 0 clients, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastRoomIsolation(t *testing.T) {
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
connA := dialTestHub(t, h, "room-a")
|
||||
defer connA.Close()
|
||||
connB := dialTestHub(t, h, "room-b")
|
||||
defer connB.Close()
|
||||
|
||||
// Allow registers to be processed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Send message to room-a via broadcast channel (room includes leading slash from URL path)
|
||||
h.broadcast <- broadcastMsg{room: "/room-a", data: []byte("for-a-only")}
|
||||
|
||||
// Room-a client should receive the message
|
||||
connA.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, msg, err := connA.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Room-a client failed to read: %v", err)
|
||||
}
|
||||
if string(msg) != "for-a-only" {
|
||||
t.Errorf("Expected %q, got %q", "for-a-only", string(msg))
|
||||
}
|
||||
|
||||
// Room-b client should NOT receive the message (short timeout)
|
||||
connB.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
_, _, err = connB.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatal("Room-b client should NOT have received room-a's message")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user