Compare commits
2 Commits
905c241daa
...
a40d16dae0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a40d16dae0 | ||
|
|
3d14b7fcb8 |
37
README.md
37
README.md
@ -1,12 +1,13 @@
|
||||
# WebSocket Relay Server
|
||||
|
||||
A minimal Go WebSocket relay server that broadcasts every incoming message to all connected clients. Supports TLS, Prometheus metrics, and graceful shutdown.
|
||||
A minimal Go WebSocket relay server that broadcasts every incoming message to all connected clients. Supports TLS, Prometheus metrics, configurable logging, and graceful shutdown.
|
||||
|
||||
## Features
|
||||
|
||||
- **Fan-out broadcasting** — every message is relayed to all connected clients
|
||||
- **TLS support** — optional `wss://` via cert/key PEM files
|
||||
- **Prometheus metrics** — connection counts, message totals, disconnections
|
||||
- **Configurable logging** — output to stdout, stderr, or file with level filtering
|
||||
- **Graceful shutdown** — clean exit on SIGINT/SIGTERM with client notification
|
||||
- **Zero dependencies at runtime** — single static binary
|
||||
|
||||
@ -40,6 +41,10 @@ server:
|
||||
metrics:
|
||||
enabled: true
|
||||
port: 9090 # Prometheus metrics at :9090/metrics
|
||||
|
||||
logging:
|
||||
output: stderr # stdout, stderr, or a file path
|
||||
level: info # debug, info, warn, error
|
||||
```
|
||||
|
||||
Override the config file path with `--config-file`:
|
||||
@ -48,6 +53,33 @@ Override the config file path with `--config-file`:
|
||||
./websocket-relay --config-file=/etc/relay/config.yaml
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
||||
The `logging` section controls where and what the server logs:
|
||||
|
||||
| Field | Values | Default | Description |
|
||||
|-------|--------|---------|-------------|
|
||||
| `output` | `stdout`, `stderr`, or a file path | `stderr` | Log output destination |
|
||||
| `level` | `debug`, `info`, `warn`, `error` | `info` | Minimum log level to output |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```yaml
|
||||
# Log everything to a file
|
||||
logging:
|
||||
output: /var/log/websocket-relay.log
|
||||
level: debug
|
||||
|
||||
# Quiet mode — only warnings and errors to stderr
|
||||
logging:
|
||||
output: stderr
|
||||
level: warn
|
||||
```
|
||||
|
||||
Log messages are prefixed with the level: `[DEBUG]`, `[INFO]`, `[WARN]`, `[ERROR]`.
|
||||
|
||||
File output uses append mode (`O_APPEND`) so logs are preserved across restarts and safe for external log rotation tools.
|
||||
|
||||
## Usage
|
||||
|
||||
Connect any WebSocket client to the server:
|
||||
@ -110,10 +142,11 @@ websocket-relay/
|
||||
├── internal/
|
||||
│ ├── config/config.go # YAML config loader
|
||||
│ ├── hub/hub.go # WebSocket hub, connection management, broadcast
|
||||
│ ├── logging/logging.go # Log output setup and leveled logger
|
||||
│ └── metrics/metrics.go # Prometheus metric definitions
|
||||
├── example/index.html # Browser P2P chat demo
|
||||
├── config.yaml # Runtime configuration
|
||||
├── config.example.yaml # Example config with TLS enabled
|
||||
├── config.example.yaml # Example config with TLS and logging
|
||||
└── Makefile # Build, test, release commands
|
||||
```
|
||||
|
||||
|
||||
@ -8,3 +8,9 @@ server:
|
||||
metrics:
|
||||
enabled: true
|
||||
port: 9090
|
||||
|
||||
logging:
|
||||
# output: stdout, stderr, or a file path (default: stderr)
|
||||
output: stderr
|
||||
# level: debug, info, warn, error (default: info)
|
||||
level: info
|
||||
|
||||
@ -19,6 +19,10 @@ type Config struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Port int `yaml:"port"`
|
||||
} `yaml:"metrics"`
|
||||
Logging struct {
|
||||
Output string `yaml:"output"`
|
||||
Level string `yaml:"level"`
|
||||
} `yaml:"logging"`
|
||||
}
|
||||
|
||||
func Load(filename string) (*Config, error) {
|
||||
@ -29,4 +33,4 @@ func Load(filename string) (*Config, error) {
|
||||
var config Config
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
return &config, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,4 +42,94 @@ func TestLoadFileNotFound(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadLoggingOutput(t *testing.T) {
|
||||
testConfig := `server:
|
||||
port: 8443
|
||||
logging:
|
||||
output: /var/log/relay.log
|
||||
level: debug`
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "config_logging_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.WriteString(testConfig); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Logging.Output != "/var/log/relay.log" {
|
||||
t.Errorf("Expected output '/var/log/relay.log', got '%s'", cfg.Logging.Output)
|
||||
}
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("Expected level 'debug', got '%s'", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadLoggingDefaults(t *testing.T) {
|
||||
testConfig := `server:
|
||||
port: 8443`
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "config_defaults_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.WriteString(testConfig); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Logging.Output != "" {
|
||||
t.Errorf("Expected empty output default, got '%s'", cfg.Logging.Output)
|
||||
}
|
||||
if cfg.Logging.Level != "" {
|
||||
t.Errorf("Expected empty level default, got '%s'", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadLoggingStdout(t *testing.T) {
|
||||
testConfig := `server:
|
||||
port: 8443
|
||||
logging:
|
||||
output: stdout
|
||||
level: warn`
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "config_stdout_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.WriteString(testConfig); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Logging.Output != "stdout" {
|
||||
t.Errorf("Expected output 'stdout', got '%s'", cfg.Logging.Output)
|
||||
}
|
||||
if cfg.Logging.Level != "warn" {
|
||||
t.Errorf("Expected level 'warn', got '%s'", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"websocket-relay/internal/logging"
|
||||
"websocket-relay/internal/metrics"
|
||||
)
|
||||
|
||||
@ -20,15 +20,17 @@ type Hub struct {
|
||||
unregister chan *websocket.Conn
|
||||
stop chan struct{}
|
||||
mu sync.RWMutex
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
func New() *Hub {
|
||||
func New(logger *logging.Logger) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcast: make(chan []byte),
|
||||
register: make(chan *websocket.Conn),
|
||||
unregister: make(chan *websocket.Conn),
|
||||
stop: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,7 +47,7 @@ func (h *Hub) Run() {
|
||||
}
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(0)
|
||||
log.Printf("Hub stopped, all clients disconnected")
|
||||
h.logger.Info("Hub stopped, all clients disconnected")
|
||||
return
|
||||
|
||||
case conn := <-h.register:
|
||||
@ -54,7 +56,7 @@ func (h *Hub) Run() {
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.ConnectionsTotal.Inc()
|
||||
log.Printf("Client connected. Total: %d", len(h.clients))
|
||||
h.logger.Infof("Client connected. Total: %d", len(h.clients))
|
||||
|
||||
case conn := <-h.unregister:
|
||||
h.mu.Lock()
|
||||
@ -65,7 +67,7 @@ func (h *Hub) Run() {
|
||||
h.mu.Unlock()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
log.Printf("Client disconnected. Total: %d", len(h.clients))
|
||||
h.logger.Infof("Client disconnected. Total: %d", len(h.clients))
|
||||
|
||||
case message := <-h.broadcast:
|
||||
metrics.MessagesTotal.Inc()
|
||||
@ -86,7 +88,7 @@ func (h *Hub) Run() {
|
||||
conn.Close()
|
||||
metrics.ConnectedClients.Set(float64(len(h.clients)))
|
||||
metrics.DisconnectionsTotal.Inc()
|
||||
log.Printf("Client disconnected (write error). Total: %d", len(h.clients))
|
||||
h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients))
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
@ -102,7 +104,7 @@ func (h *Hub) Shutdown() {
|
||||
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("WebSocket upgrade error: %v", err)
|
||||
h.logger.Errorf("WebSocket upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -9,12 +10,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"websocket-relay/internal/logging"
|
||||
)
|
||||
|
||||
// helper: start a test server with a running Hub, return the server and hub
|
||||
func setupTestServer(t *testing.T) (*httptest.Server, *Hub) {
|
||||
t.Helper()
|
||||
h := New()
|
||||
logger := logging.NewLogger("debug", &bytes.Buffer{})
|
||||
h := New(logger)
|
||||
go h.Run()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
|
||||
|
||||
@ -1,12 +1,19 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"websocket-relay/internal/logging"
|
||||
)
|
||||
|
||||
func newTestLogger() *logging.Logger {
|
||||
return logging.NewLogger("debug", &bytes.Buffer{})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
h := New()
|
||||
h := New(newTestLogger())
|
||||
if h == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
@ -22,7 +29,7 @@ func TestNew(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientCount(t *testing.T) {
|
||||
h := New()
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
@ -32,7 +39,7 @@ func TestClientCount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBroadcastChannel(t *testing.T) {
|
||||
h := New()
|
||||
h := New(newTestLogger())
|
||||
go h.Run()
|
||||
defer h.Shutdown()
|
||||
|
||||
@ -45,7 +52,7 @@ func TestBroadcastChannel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
h := New()
|
||||
h := New(newTestLogger())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
126
internal/logging/logging.go
Normal file
126
internal/logging/logging.go
Normal file
@ -0,0 +1,126 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Level represents the severity of a log message.
|
||||
type Level int
|
||||
|
||||
const (
|
||||
LevelDebug Level = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
)
|
||||
|
||||
// Logger provides level-aware logging.
|
||||
type Logger struct {
|
||||
level Level
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// Setup configures the global log output destination.
|
||||
// Returns the opened file (if a file path was given) so the caller can defer Close.
|
||||
// For "stdout", "stderr", or empty string, returns nil (no file to close).
|
||||
func Setup(output string) (*os.File, error) {
|
||||
switch strings.ToLower(output) {
|
||||
case "", "stderr":
|
||||
log.SetOutput(os.Stderr)
|
||||
return nil, nil
|
||||
case "stdout":
|
||||
log.SetOutput(os.Stdout)
|
||||
return nil, nil
|
||||
default:
|
||||
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open log file %s: %w", output, err)
|
||||
}
|
||||
log.SetOutput(file)
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger with the given level and output writer.
|
||||
// If level is empty or invalid, defaults to LevelInfo.
|
||||
func NewLogger(level string, output io.Writer) *Logger {
|
||||
return &Logger{
|
||||
level: parseLevel(level),
|
||||
logger: log.New(output, "", log.LstdFlags),
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a message at debug level.
|
||||
func (l *Logger) Debug(msg string) {
|
||||
if l.level <= LevelDebug {
|
||||
l.logger.Printf("[DEBUG] %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf logs a formatted message at debug level.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
if l.level <= LevelDebug {
|
||||
l.logger.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs a message at info level.
|
||||
func (l *Logger) Info(msg string) {
|
||||
if l.level <= LevelInfo {
|
||||
l.logger.Printf("[INFO] %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Infof logs a formatted message at info level.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
if l.level <= LevelInfo {
|
||||
l.logger.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn logs a message at warn level.
|
||||
func (l *Logger) Warn(msg string) {
|
||||
if l.level <= LevelWarn {
|
||||
l.logger.Printf("[WARN] %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Warnf logs a formatted message at warn level.
|
||||
func (l *Logger) Warnf(format string, args ...interface{}) {
|
||||
if l.level <= LevelWarn {
|
||||
l.logger.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs a message at error level.
|
||||
func (l *Logger) Error(msg string) {
|
||||
if l.level <= LevelError {
|
||||
l.logger.Printf("[ERROR] %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Errorf logs a formatted message at error level.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
if l.level <= LevelError {
|
||||
l.logger.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func parseLevel(level string) Level {
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return LevelDebug
|
||||
case "info":
|
||||
return LevelInfo
|
||||
case "warn":
|
||||
return LevelWarn
|
||||
case "error":
|
||||
return LevelError
|
||||
default:
|
||||
return LevelInfo
|
||||
}
|
||||
}
|
||||
278
internal/logging/logging_test.go
Normal file
278
internal/logging/logging_test.go
Normal file
@ -0,0 +1,278 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetupStdout(t *testing.T) {
|
||||
file, err := Setup("stdout")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup(stdout) failed: %v", err)
|
||||
}
|
||||
if file != nil {
|
||||
t.Error("Expected nil file for stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupStderr(t *testing.T) {
|
||||
file, err := Setup("stderr")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup(stderr) failed: %v", err)
|
||||
}
|
||||
if file != nil {
|
||||
t.Error("Expected nil file for stderr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupEmpty(t *testing.T) {
|
||||
file, err := Setup("")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup('') failed: %v", err)
|
||||
}
|
||||
if file != nil {
|
||||
t.Error("Expected nil file for empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupFilePath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
file, err := Setup(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Setup(file) failed: %v", err)
|
||||
}
|
||||
if file == nil {
|
||||
t.Fatal("Expected non-nil file for file path")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Verify file was created
|
||||
if _, err := os.Stat(logPath); os.IsNotExist(err) {
|
||||
t.Error("Expected log file to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupInvalidPath(t *testing.T) {
|
||||
_, err := Setup("/nonexistent/directory/path/test.log")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoggerDefaultLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("", &buf)
|
||||
|
||||
if logger.level != LevelInfo {
|
||||
t.Errorf("Expected default level Info, got %d", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoggerInvalidLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("invalid", &buf)
|
||||
|
||||
if logger.level != LevelInfo {
|
||||
t.Errorf("Expected default level Info for invalid input, got %d", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoggerDebugLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("debug", &buf)
|
||||
|
||||
if logger.level != LevelDebug {
|
||||
t.Errorf("Expected level Debug, got %d", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoggerWarnLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("warn", &buf)
|
||||
|
||||
if logger.level != LevelWarn {
|
||||
t.Errorf("Expected level Warn, got %d", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoggerErrorLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("error", &buf)
|
||||
|
||||
if logger.level != LevelError {
|
||||
t.Errorf("Expected level Error, got %d", logger.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerDebugOutputAtDebugLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("debug", &buf)
|
||||
|
||||
logger.Debug("test message")
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "[DEBUG]") {
|
||||
t.Errorf("Expected [DEBUG] prefix, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "test message") {
|
||||
t.Errorf("Expected 'test message' in output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerDebugSuppressedAtInfoLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("info", &buf)
|
||||
|
||||
logger.Debug("should not appear")
|
||||
output := buf.String()
|
||||
|
||||
if output != "" {
|
||||
t.Errorf("Expected no output for debug at info level, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerInfoOutputAtInfoLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("info", &buf)
|
||||
|
||||
logger.Info("info message")
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "[INFO]") {
|
||||
t.Errorf("Expected [INFO] prefix, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "info message") {
|
||||
t.Errorf("Expected 'info message' in output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerInfoSuppressedAtWarnLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("warn", &buf)
|
||||
|
||||
logger.Info("should not appear")
|
||||
output := buf.String()
|
||||
|
||||
if output != "" {
|
||||
t.Errorf("Expected no output for info at warn level, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerWarnOutputAtWarnLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("warn", &buf)
|
||||
|
||||
logger.Warn("warn message")
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "[WARN]") {
|
||||
t.Errorf("Expected [WARN] prefix, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "warn message") {
|
||||
t.Errorf("Expected 'warn message' in output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerWarnSuppressedAtErrorLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("error", &buf)
|
||||
|
||||
logger.Warn("should not appear")
|
||||
output := buf.String()
|
||||
|
||||
if output != "" {
|
||||
t.Errorf("Expected no output for warn at error level, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerErrorOutputAtErrorLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("error", &buf)
|
||||
|
||||
logger.Error("error message")
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "[ERROR]") {
|
||||
t.Errorf("Expected [ERROR] prefix, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "error message") {
|
||||
t.Errorf("Expected 'error message' in output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerErrorAlwaysOutputs(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("error", &buf)
|
||||
|
||||
logger.Error("critical failure")
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "critical failure") {
|
||||
t.Errorf("Expected error message in output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerDebugLevelAllMessages(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("debug", &buf)
|
||||
|
||||
logger.Debug("d")
|
||||
logger.Info("i")
|
||||
logger.Warn("w")
|
||||
logger.Error("e")
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[DEBUG]") {
|
||||
t.Error("Expected [DEBUG] in output")
|
||||
}
|
||||
if !strings.Contains(output, "[INFO]") {
|
||||
t.Error("Expected [INFO] in output")
|
||||
}
|
||||
if !strings.Contains(output, "[WARN]") {
|
||||
t.Error("Expected [WARN] in output")
|
||||
}
|
||||
if !strings.Contains(output, "[ERROR]") {
|
||||
t.Error("Expected [ERROR] in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerFormatf(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := NewLogger("info", &buf)
|
||||
|
||||
logger.Infof("count: %d", 42)
|
||||
output := buf.String()
|
||||
|
||||
if !strings.Contains(output, "count: 42") {
|
||||
t.Errorf("Expected formatted output 'count: 42', got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerLevelCaseInsensitive(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
logger := NewLogger("DEBUG", &buf)
|
||||
if logger.level != LevelDebug {
|
||||
t.Errorf("Expected level Debug for 'DEBUG', got %d", logger.level)
|
||||
}
|
||||
|
||||
logger = NewLogger("Info", &buf)
|
||||
if logger.level != LevelInfo {
|
||||
t.Errorf("Expected level Info for 'Info', got %d", logger.level)
|
||||
}
|
||||
|
||||
logger = NewLogger("WARN", &buf)
|
||||
if logger.level != LevelWarn {
|
||||
t.Errorf("Expected level Warn for 'WARN', got %d", logger.level)
|
||||
}
|
||||
|
||||
logger = NewLogger("ERROR", &buf)
|
||||
if logger.level != LevelError {
|
||||
t.Errorf("Expected level Error for 'ERROR', got %d", logger.level)
|
||||
}
|
||||
}
|
||||
37
main.go
37
main.go
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"websocket-relay/internal/config"
|
||||
"websocket-relay/internal/hub"
|
||||
"websocket-relay/internal/logging"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
@ -26,7 +27,19 @@ func main() {
|
||||
log.Fatal("Failed to load config:", err)
|
||||
}
|
||||
|
||||
h := hub.New()
|
||||
// Setup log output destination
|
||||
logFile, err := logging.Setup(cfg.Logging.Output)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to setup logging:", err)
|
||||
}
|
||||
if logFile != nil {
|
||||
defer logFile.Close()
|
||||
}
|
||||
|
||||
// Create leveled logger
|
||||
logger := logging.NewLogger(cfg.Logging.Level, log.Writer())
|
||||
|
||||
h := hub.New(logger)
|
||||
go h.Run()
|
||||
|
||||
// Start metrics server if enabled
|
||||
@ -40,9 +53,9 @@ func main() {
|
||||
Handler: metricsMux,
|
||||
}
|
||||
go func() {
|
||||
log.Printf("Metrics server starting on %s", metricsAddr)
|
||||
logger.Infof("Metrics server starting on %s", metricsAddr)
|
||||
if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Printf("Metrics server error: %v", err)
|
||||
logger.Errorf("Metrics server error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -59,14 +72,16 @@ func main() {
|
||||
// Start the main server in a goroutine
|
||||
go func() {
|
||||
if cfg.Server.TLS.Enabled {
|
||||
log.Printf("WebSocket relay server starting on %s (TLS)", addr)
|
||||
logger.Infof("WebSocket relay server starting on %s (TLS)", addr)
|
||||
if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
logger.Errorf("Server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
log.Printf("WebSocket relay server starting on %s (HTTP)", addr)
|
||||
logger.Infof("WebSocket relay server starting on %s (HTTP)", addr)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
logger.Errorf("Server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -75,7 +90,7 @@ func main() {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
sig := <-quit
|
||||
log.Printf("Received signal %v, shutting down gracefully...", sig)
|
||||
logger.Infof("Received signal %v, shutting down gracefully...", sig)
|
||||
|
||||
// Create a deadline for the shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@ -83,18 +98,18 @@ func main() {
|
||||
|
||||
// Shut down the main HTTP server (stops accepting new connections)
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Printf("HTTP server shutdown error: %v", err)
|
||||
logger.Errorf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
|
||||
// Shut down the metrics server
|
||||
if metricsServer != nil {
|
||||
if err := metricsServer.Shutdown(ctx); err != nil {
|
||||
log.Printf("Metrics server shutdown error: %v", err)
|
||||
logger.Errorf("Metrics server shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop the hub and close all WebSocket connections
|
||||
h.Shutdown()
|
||||
|
||||
log.Printf("Server stopped")
|
||||
logger.Info("Server stopped")
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user