From 3d14b7fcb8e53ccc6ac1c324cdd447ae5d6ff6ac Mon Sep 17 00:00:00 2001 From: savinmax Date: Thu, 11 Jun 2026 19:21:20 +0200 Subject: [PATCH] feat(logging): add configurable log output and log level support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a 'logging' section to config.yaml supporting: - output: stderr (default), stdout, or a file path - level: debug, info, warn, error (default: info) Implementation: - New internal/logging package with Setup() for output destination and Logger struct with level-aware Debug/Info/Warn/Error methods - Config struct extended with Logging section (output + level fields) - Hub refactored to accept *logging.Logger via constructor injection - main.go initializes logging early after config load The leveled logger suppresses messages below the configured threshold while maintaining the stdlib log format. File output uses append mode with 0644 permissions for safe log rotation. 🤖 Assisted by the code-assist SOP --- config.example.yaml | 6 + internal/config/config.go | 6 +- internal/config/config_test.go | 92 ++++++++- internal/hub/hub.go | 16 +- internal/hub/hub_integration_test.go | 5 +- internal/hub/hub_test.go | 15 +- internal/logging/logging.go | 126 ++++++++++++ internal/logging/logging_test.go | 278 +++++++++++++++++++++++++++ main.go | 37 ++-- 9 files changed, 556 insertions(+), 25 deletions(-) create mode 100644 internal/logging/logging.go create mode 100644 internal/logging/logging_test.go diff --git a/config.example.yaml b/config.example.yaml index 47f9527..092f210 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -8,3 +8,9 @@ server: metrics: enabled: true port: 9090 + +logging: + # output: stdout, stderr, or a file path (default: stderr) + output: stderr + # level: debug, info, warn, error (default: info) + level: info diff --git a/internal/config/config.go b/internal/config/config.go index 20eff8f..f51def2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,10 @@ type Config struct { Enabled bool `yaml:"enabled"` Port int `yaml:"port"` } `yaml:"metrics"` + Logging struct { + Output string `yaml:"output"` + Level string `yaml:"level"` + } `yaml:"logging"` } func Load(filename string) (*Config, error) { @@ -29,4 +33,4 @@ func Load(filename string) (*Config, error) { var config Config err = yaml.Unmarshal(data, &config) return &config, err -} \ No newline at end of file +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 0212ba3..c523e08 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -42,4 +42,94 @@ func TestLoadFileNotFound(t *testing.T) { if err == nil { t.Error("Expected error for nonexistent file") } -} \ No newline at end of file +} + +func TestLoadLoggingOutput(t *testing.T) { + testConfig := `server: + port: 8443 +logging: + output: /var/log/relay.log + level: debug` + + tmpFile, err := os.CreateTemp("", "config_logging_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "/var/log/relay.log" { + t.Errorf("Expected output '/var/log/relay.log', got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "debug" { + t.Errorf("Expected level 'debug', got '%s'", cfg.Logging.Level) + } +} + +func TestLoadLoggingDefaults(t *testing.T) { + testConfig := `server: + port: 8443` + + tmpFile, err := os.CreateTemp("", "config_defaults_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "" { + t.Errorf("Expected empty output default, got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "" { + t.Errorf("Expected empty level default, got '%s'", cfg.Logging.Level) + } +} + +func TestLoadLoggingStdout(t *testing.T) { + testConfig := `server: + port: 8443 +logging: + output: stdout + level: warn` + + tmpFile, err := os.CreateTemp("", "config_stdout_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "stdout" { + t.Errorf("Expected output 'stdout', got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "warn" { + t.Errorf("Expected level 'warn', got '%s'", cfg.Logging.Level) + } +} diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 23993a6..0bf8f5c 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -1,11 +1,11 @@ package hub import ( - "log" "net/http" "sync" "github.com/gorilla/websocket" + "websocket-relay/internal/logging" "websocket-relay/internal/metrics" ) @@ -20,15 +20,17 @@ type Hub struct { unregister chan *websocket.Conn stop chan struct{} mu sync.RWMutex + logger *logging.Logger } -func New() *Hub { +func New(logger *logging.Logger) *Hub { return &Hub{ clients: make(map[*websocket.Conn]bool), broadcast: make(chan []byte), register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), stop: make(chan struct{}), + logger: logger, } } @@ -45,7 +47,7 @@ func (h *Hub) Run() { } h.mu.Unlock() metrics.ConnectedClients.Set(0) - log.Printf("Hub stopped, all clients disconnected") + h.logger.Info("Hub stopped, all clients disconnected") return case conn := <-h.register: @@ -54,7 +56,7 @@ func (h *Hub) Run() { h.mu.Unlock() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.ConnectionsTotal.Inc() - log.Printf("Client connected. Total: %d", len(h.clients)) + h.logger.Infof("Client connected. Total: %d", len(h.clients)) case conn := <-h.unregister: h.mu.Lock() @@ -65,7 +67,7 @@ func (h *Hub) Run() { h.mu.Unlock() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.DisconnectionsTotal.Inc() - log.Printf("Client disconnected. Total: %d", len(h.clients)) + h.logger.Infof("Client disconnected. Total: %d", len(h.clients)) case message := <-h.broadcast: metrics.MessagesTotal.Inc() @@ -86,7 +88,7 @@ func (h *Hub) Run() { conn.Close() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.DisconnectionsTotal.Inc() - log.Printf("Client disconnected (write error). Total: %d", len(h.clients)) + h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients)) } h.mu.Unlock() } @@ -102,7 +104,7 @@ func (h *Hub) Shutdown() { func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - log.Printf("WebSocket upgrade error: %v", err) + h.logger.Errorf("WebSocket upgrade error: %v", err) return } diff --git a/internal/hub/hub_integration_test.go b/internal/hub/hub_integration_test.go index d466689..25c6bae 100644 --- a/internal/hub/hub_integration_test.go +++ b/internal/hub/hub_integration_test.go @@ -1,6 +1,7 @@ package hub import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -9,12 +10,14 @@ import ( "time" "github.com/gorilla/websocket" + "websocket-relay/internal/logging" ) // helper: start a test server with a running Hub, return the server and hub func setupTestServer(t *testing.T) (*httptest.Server, *Hub) { t.Helper() - h := New() + logger := logging.NewLogger("debug", &bytes.Buffer{}) + h := New(logger) go h.Run() server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket)) diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index 59e721f..c40a705 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -1,12 +1,19 @@ package hub import ( + "bytes" "testing" "time" + + "websocket-relay/internal/logging" ) +func newTestLogger() *logging.Logger { + return logging.NewLogger("debug", &bytes.Buffer{}) +} + func TestNew(t *testing.T) { - h := New() + h := New(newTestLogger()) if h == nil { t.Fatal("New returned nil") } @@ -22,7 +29,7 @@ func TestNew(t *testing.T) { } func TestClientCount(t *testing.T) { - h := New() + h := New(newTestLogger()) go h.Run() defer h.Shutdown() @@ -32,7 +39,7 @@ func TestClientCount(t *testing.T) { } func TestBroadcastChannel(t *testing.T) { - h := New() + h := New(newTestLogger()) go h.Run() defer h.Shutdown() @@ -45,7 +52,7 @@ func TestBroadcastChannel(t *testing.T) { } func TestShutdown(t *testing.T) { - h := New() + h := New(newTestLogger()) done := make(chan struct{}) go func() { diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..ec2f526 --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,126 @@ +package logging + +import ( + "fmt" + "io" + "log" + "os" + "strings" +) + +// Level represents the severity of a log message. +type Level int + +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +// Logger provides level-aware logging. +type Logger struct { + level Level + logger *log.Logger +} + +// Setup configures the global log output destination. +// Returns the opened file (if a file path was given) so the caller can defer Close. +// For "stdout", "stderr", or empty string, returns nil (no file to close). +func Setup(output string) (*os.File, error) { + switch strings.ToLower(output) { + case "", "stderr": + log.SetOutput(os.Stderr) + return nil, nil + case "stdout": + log.SetOutput(os.Stdout) + return nil, nil + default: + file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open log file %s: %w", output, err) + } + log.SetOutput(file) + return file, nil + } +} + +// NewLogger creates a new Logger with the given level and output writer. +// If level is empty or invalid, defaults to LevelInfo. +func NewLogger(level string, output io.Writer) *Logger { + return &Logger{ + level: parseLevel(level), + logger: log.New(output, "", log.LstdFlags), + } +} + +// Debug logs a message at debug level. +func (l *Logger) Debug(msg string) { + if l.level <= LevelDebug { + l.logger.Printf("[DEBUG] %s", msg) + } +} + +// Debugf logs a formatted message at debug level. +func (l *Logger) Debugf(format string, args ...interface{}) { + if l.level <= LevelDebug { + l.logger.Printf("[DEBUG] "+format, args...) + } +} + +// Info logs a message at info level. +func (l *Logger) Info(msg string) { + if l.level <= LevelInfo { + l.logger.Printf("[INFO] %s", msg) + } +} + +// Infof logs a formatted message at info level. +func (l *Logger) Infof(format string, args ...interface{}) { + if l.level <= LevelInfo { + l.logger.Printf("[INFO] "+format, args...) + } +} + +// Warn logs a message at warn level. +func (l *Logger) Warn(msg string) { + if l.level <= LevelWarn { + l.logger.Printf("[WARN] %s", msg) + } +} + +// Warnf logs a formatted message at warn level. +func (l *Logger) Warnf(format string, args ...interface{}) { + if l.level <= LevelWarn { + l.logger.Printf("[WARN] "+format, args...) + } +} + +// Error logs a message at error level. +func (l *Logger) Error(msg string) { + if l.level <= LevelError { + l.logger.Printf("[ERROR] %s", msg) + } +} + +// Errorf logs a formatted message at error level. +func (l *Logger) Errorf(format string, args ...interface{}) { + if l.level <= LevelError { + l.logger.Printf("[ERROR] "+format, args...) + } +} + +func parseLevel(level string) Level { + switch strings.ToLower(level) { + case "debug": + return LevelDebug + case "info": + return LevelInfo + case "warn": + return LevelWarn + case "error": + return LevelError + default: + return LevelInfo + } +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go new file mode 100644 index 0000000..6303dc6 --- /dev/null +++ b/internal/logging/logging_test.go @@ -0,0 +1,278 @@ +package logging + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSetupStdout(t *testing.T) { + file, err := Setup("stdout") + if err != nil { + t.Fatalf("Setup(stdout) failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for stdout") + } +} + +func TestSetupStderr(t *testing.T) { + file, err := Setup("stderr") + if err != nil { + t.Fatalf("Setup(stderr) failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for stderr") + } +} + +func TestSetupEmpty(t *testing.T) { + file, err := Setup("") + if err != nil { + t.Fatalf("Setup('') failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for empty string") + } +} + +func TestSetupFilePath(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "test.log") + + file, err := Setup(logPath) + if err != nil { + t.Fatalf("Setup(file) failed: %v", err) + } + if file == nil { + t.Fatal("Expected non-nil file for file path") + } + defer file.Close() + + // Verify file was created + if _, err := os.Stat(logPath); os.IsNotExist(err) { + t.Error("Expected log file to be created") + } +} + +func TestSetupInvalidPath(t *testing.T) { + _, err := Setup("/nonexistent/directory/path/test.log") + if err == nil { + t.Error("Expected error for invalid path") + } +} + +func TestNewLoggerDefaultLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("", &buf) + + if logger.level != LevelInfo { + t.Errorf("Expected default level Info, got %d", logger.level) + } +} + +func TestNewLoggerInvalidLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("invalid", &buf) + + if logger.level != LevelInfo { + t.Errorf("Expected default level Info for invalid input, got %d", logger.level) + } +} + +func TestNewLoggerDebugLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + if logger.level != LevelDebug { + t.Errorf("Expected level Debug, got %d", logger.level) + } +} + +func TestNewLoggerWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + if logger.level != LevelWarn { + t.Errorf("Expected level Warn, got %d", logger.level) + } +} + +func TestNewLoggerErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + if logger.level != LevelError { + t.Errorf("Expected level Error, got %d", logger.level) + } +} + +func TestLoggerDebugOutputAtDebugLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + logger.Debug("test message") + output := buf.String() + + if !strings.Contains(output, "[DEBUG]") { + t.Errorf("Expected [DEBUG] prefix, got: %s", output) + } + if !strings.Contains(output, "test message") { + t.Errorf("Expected 'test message' in output, got: %s", output) + } +} + +func TestLoggerDebugSuppressedAtInfoLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Debug("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for debug at info level, got: %s", output) + } +} + +func TestLoggerInfoOutputAtInfoLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Info("info message") + output := buf.String() + + if !strings.Contains(output, "[INFO]") { + t.Errorf("Expected [INFO] prefix, got: %s", output) + } + if !strings.Contains(output, "info message") { + t.Errorf("Expected 'info message' in output, got: %s", output) + } +} + +func TestLoggerInfoSuppressedAtWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + logger.Info("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for info at warn level, got: %s", output) + } +} + +func TestLoggerWarnOutputAtWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + logger.Warn("warn message") + output := buf.String() + + if !strings.Contains(output, "[WARN]") { + t.Errorf("Expected [WARN] prefix, got: %s", output) + } + if !strings.Contains(output, "warn message") { + t.Errorf("Expected 'warn message' in output, got: %s", output) + } +} + +func TestLoggerWarnSuppressedAtErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Warn("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for warn at error level, got: %s", output) + } +} + +func TestLoggerErrorOutputAtErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Error("error message") + output := buf.String() + + if !strings.Contains(output, "[ERROR]") { + t.Errorf("Expected [ERROR] prefix, got: %s", output) + } + if !strings.Contains(output, "error message") { + t.Errorf("Expected 'error message' in output, got: %s", output) + } +} + +func TestLoggerErrorAlwaysOutputs(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Error("critical failure") + output := buf.String() + + if !strings.Contains(output, "critical failure") { + t.Errorf("Expected error message in output, got: %s", output) + } +} + +func TestLoggerDebugLevelAllMessages(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + logger.Debug("d") + logger.Info("i") + logger.Warn("w") + logger.Error("e") + + output := buf.String() + if !strings.Contains(output, "[DEBUG]") { + t.Error("Expected [DEBUG] in output") + } + if !strings.Contains(output, "[INFO]") { + t.Error("Expected [INFO] in output") + } + if !strings.Contains(output, "[WARN]") { + t.Error("Expected [WARN] in output") + } + if !strings.Contains(output, "[ERROR]") { + t.Error("Expected [ERROR] in output") + } +} + +func TestLoggerFormatf(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Infof("count: %d", 42) + output := buf.String() + + if !strings.Contains(output, "count: 42") { + t.Errorf("Expected formatted output 'count: 42', got: %s", output) + } +} + +func TestLoggerLevelCaseInsensitive(t *testing.T) { + var buf bytes.Buffer + + logger := NewLogger("DEBUG", &buf) + if logger.level != LevelDebug { + t.Errorf("Expected level Debug for 'DEBUG', got %d", logger.level) + } + + logger = NewLogger("Info", &buf) + if logger.level != LevelInfo { + t.Errorf("Expected level Info for 'Info', got %d", logger.level) + } + + logger = NewLogger("WARN", &buf) + if logger.level != LevelWarn { + t.Errorf("Expected level Warn for 'WARN', got %d", logger.level) + } + + logger = NewLogger("ERROR", &buf) + if logger.level != LevelError { + t.Errorf("Expected level Error for 'ERROR', got %d", logger.level) + } +} diff --git a/main.go b/main.go index e0526d6..e7774be 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "websocket-relay/internal/config" "websocket-relay/internal/hub" + "websocket-relay/internal/logging" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -26,7 +27,19 @@ func main() { log.Fatal("Failed to load config:", err) } - h := hub.New() + // Setup log output destination + logFile, err := logging.Setup(cfg.Logging.Output) + if err != nil { + log.Fatal("Failed to setup logging:", err) + } + if logFile != nil { + defer logFile.Close() + } + + // Create leveled logger + logger := logging.NewLogger(cfg.Logging.Level, log.Writer()) + + h := hub.New(logger) go h.Run() // Start metrics server if enabled @@ -40,9 +53,9 @@ func main() { Handler: metricsMux, } go func() { - log.Printf("Metrics server starting on %s", metricsAddr) + logger.Infof("Metrics server starting on %s", metricsAddr) if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Printf("Metrics server error: %v", err) + logger.Errorf("Metrics server error: %v", err) } }() } @@ -59,14 +72,16 @@ func main() { // Start the main server in a goroutine go func() { if cfg.Server.TLS.Enabled { - log.Printf("WebSocket relay server starting on %s (TLS)", addr) + logger.Infof("WebSocket relay server starting on %s (TLS)", addr) if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Errorf("Server error: %v", err) + os.Exit(1) } } else { - log.Printf("WebSocket relay server starting on %s (HTTP)", addr) + logger.Infof("WebSocket relay server starting on %s (HTTP)", addr) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Errorf("Server error: %v", err) + os.Exit(1) } } }() @@ -75,7 +90,7 @@ func main() { quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) sig := <-quit - log.Printf("Received signal %v, shutting down gracefully...", sig) + logger.Infof("Received signal %v, shutting down gracefully...", sig) // Create a deadline for the shutdown ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -83,18 +98,18 @@ func main() { // Shut down the main HTTP server (stops accepting new connections) if err := server.Shutdown(ctx); err != nil { - log.Printf("HTTP server shutdown error: %v", err) + logger.Errorf("HTTP server shutdown error: %v", err) } // Shut down the metrics server if metricsServer != nil { if err := metricsServer.Shutdown(ctx); err != nil { - log.Printf("Metrics server shutdown error: %v", err) + logger.Errorf("Metrics server shutdown error: %v", err) } } // Stop the hub and close all WebSocket connections h.Shutdown() - log.Printf("Server stopped") + logger.Info("Server stopped") }