diff --git a/config.example.yaml b/config.example.yaml index 47f9527..092f210 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -8,3 +8,9 @@ server: metrics: enabled: true port: 9090 + +logging: + # output: stdout, stderr, or a file path (default: stderr) + output: stderr + # level: debug, info, warn, error (default: info) + level: info diff --git a/internal/config/config.go b/internal/config/config.go index 20eff8f..f51def2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,10 @@ type Config struct { Enabled bool `yaml:"enabled"` Port int `yaml:"port"` } `yaml:"metrics"` + Logging struct { + Output string `yaml:"output"` + Level string `yaml:"level"` + } `yaml:"logging"` } func Load(filename string) (*Config, error) { @@ -29,4 +33,4 @@ func Load(filename string) (*Config, error) { var config Config err = yaml.Unmarshal(data, &config) return &config, err -} \ No newline at end of file +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 0212ba3..c523e08 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -42,4 +42,94 @@ func TestLoadFileNotFound(t *testing.T) { if err == nil { t.Error("Expected error for nonexistent file") } -} \ No newline at end of file +} + +func TestLoadLoggingOutput(t *testing.T) { + testConfig := `server: + port: 8443 +logging: + output: /var/log/relay.log + level: debug` + + tmpFile, err := os.CreateTemp("", "config_logging_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "/var/log/relay.log" { + t.Errorf("Expected output '/var/log/relay.log', got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "debug" { + t.Errorf("Expected level 'debug', got '%s'", cfg.Logging.Level) + } +} + +func TestLoadLoggingDefaults(t *testing.T) { + testConfig := `server: + port: 8443` + + tmpFile, err := os.CreateTemp("", "config_defaults_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "" { + t.Errorf("Expected empty output default, got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "" { + t.Errorf("Expected empty level default, got '%s'", cfg.Logging.Level) + } +} + +func TestLoadLoggingStdout(t *testing.T) { + testConfig := `server: + port: 8443 +logging: + output: stdout + level: warn` + + tmpFile, err := os.CreateTemp("", "config_stdout_*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.WriteString(testConfig); err != nil { + t.Fatal(err) + } + tmpFile.Close() + + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.Logging.Output != "stdout" { + t.Errorf("Expected output 'stdout', got '%s'", cfg.Logging.Output) + } + if cfg.Logging.Level != "warn" { + t.Errorf("Expected level 'warn', got '%s'", cfg.Logging.Level) + } +} diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 23993a6..0bf8f5c 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -1,11 +1,11 @@ package hub import ( - "log" "net/http" "sync" "github.com/gorilla/websocket" + "websocket-relay/internal/logging" "websocket-relay/internal/metrics" ) @@ -20,15 +20,17 @@ type Hub struct { unregister chan *websocket.Conn stop chan struct{} mu sync.RWMutex + logger *logging.Logger } -func New() *Hub { +func New(logger *logging.Logger) *Hub { return &Hub{ clients: make(map[*websocket.Conn]bool), broadcast: make(chan []byte), register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), stop: make(chan struct{}), + logger: logger, } } @@ -45,7 +47,7 @@ func (h *Hub) Run() { } h.mu.Unlock() metrics.ConnectedClients.Set(0) - log.Printf("Hub stopped, all clients disconnected") + h.logger.Info("Hub stopped, all clients disconnected") return case conn := <-h.register: @@ -54,7 +56,7 @@ func (h *Hub) Run() { h.mu.Unlock() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.ConnectionsTotal.Inc() - log.Printf("Client connected. Total: %d", len(h.clients)) + h.logger.Infof("Client connected. Total: %d", len(h.clients)) case conn := <-h.unregister: h.mu.Lock() @@ -65,7 +67,7 @@ func (h *Hub) Run() { h.mu.Unlock() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.DisconnectionsTotal.Inc() - log.Printf("Client disconnected. Total: %d", len(h.clients)) + h.logger.Infof("Client disconnected. Total: %d", len(h.clients)) case message := <-h.broadcast: metrics.MessagesTotal.Inc() @@ -86,7 +88,7 @@ func (h *Hub) Run() { conn.Close() metrics.ConnectedClients.Set(float64(len(h.clients))) metrics.DisconnectionsTotal.Inc() - log.Printf("Client disconnected (write error). Total: %d", len(h.clients)) + h.logger.Warnf("Client disconnected (write error). Total: %d", len(h.clients)) } h.mu.Unlock() } @@ -102,7 +104,7 @@ func (h *Hub) Shutdown() { func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - log.Printf("WebSocket upgrade error: %v", err) + h.logger.Errorf("WebSocket upgrade error: %v", err) return } diff --git a/internal/hub/hub_integration_test.go b/internal/hub/hub_integration_test.go index d466689..25c6bae 100644 --- a/internal/hub/hub_integration_test.go +++ b/internal/hub/hub_integration_test.go @@ -1,6 +1,7 @@ package hub import ( + "bytes" "net/http" "net/http/httptest" "strings" @@ -9,12 +10,14 @@ import ( "time" "github.com/gorilla/websocket" + "websocket-relay/internal/logging" ) // helper: start a test server with a running Hub, return the server and hub func setupTestServer(t *testing.T) (*httptest.Server, *Hub) { t.Helper() - h := New() + logger := logging.NewLogger("debug", &bytes.Buffer{}) + h := New(logger) go h.Run() server := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket)) diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index 59e721f..c40a705 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -1,12 +1,19 @@ package hub import ( + "bytes" "testing" "time" + + "websocket-relay/internal/logging" ) +func newTestLogger() *logging.Logger { + return logging.NewLogger("debug", &bytes.Buffer{}) +} + func TestNew(t *testing.T) { - h := New() + h := New(newTestLogger()) if h == nil { t.Fatal("New returned nil") } @@ -22,7 +29,7 @@ func TestNew(t *testing.T) { } func TestClientCount(t *testing.T) { - h := New() + h := New(newTestLogger()) go h.Run() defer h.Shutdown() @@ -32,7 +39,7 @@ func TestClientCount(t *testing.T) { } func TestBroadcastChannel(t *testing.T) { - h := New() + h := New(newTestLogger()) go h.Run() defer h.Shutdown() @@ -45,7 +52,7 @@ func TestBroadcastChannel(t *testing.T) { } func TestShutdown(t *testing.T) { - h := New() + h := New(newTestLogger()) done := make(chan struct{}) go func() { diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..ec2f526 --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,126 @@ +package logging + +import ( + "fmt" + "io" + "log" + "os" + "strings" +) + +// Level represents the severity of a log message. +type Level int + +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +// Logger provides level-aware logging. +type Logger struct { + level Level + logger *log.Logger +} + +// Setup configures the global log output destination. +// Returns the opened file (if a file path was given) so the caller can defer Close. +// For "stdout", "stderr", or empty string, returns nil (no file to close). +func Setup(output string) (*os.File, error) { + switch strings.ToLower(output) { + case "", "stderr": + log.SetOutput(os.Stderr) + return nil, nil + case "stdout": + log.SetOutput(os.Stdout) + return nil, nil + default: + file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open log file %s: %w", output, err) + } + log.SetOutput(file) + return file, nil + } +} + +// NewLogger creates a new Logger with the given level and output writer. +// If level is empty or invalid, defaults to LevelInfo. +func NewLogger(level string, output io.Writer) *Logger { + return &Logger{ + level: parseLevel(level), + logger: log.New(output, "", log.LstdFlags), + } +} + +// Debug logs a message at debug level. +func (l *Logger) Debug(msg string) { + if l.level <= LevelDebug { + l.logger.Printf("[DEBUG] %s", msg) + } +} + +// Debugf logs a formatted message at debug level. +func (l *Logger) Debugf(format string, args ...interface{}) { + if l.level <= LevelDebug { + l.logger.Printf("[DEBUG] "+format, args...) + } +} + +// Info logs a message at info level. +func (l *Logger) Info(msg string) { + if l.level <= LevelInfo { + l.logger.Printf("[INFO] %s", msg) + } +} + +// Infof logs a formatted message at info level. +func (l *Logger) Infof(format string, args ...interface{}) { + if l.level <= LevelInfo { + l.logger.Printf("[INFO] "+format, args...) + } +} + +// Warn logs a message at warn level. +func (l *Logger) Warn(msg string) { + if l.level <= LevelWarn { + l.logger.Printf("[WARN] %s", msg) + } +} + +// Warnf logs a formatted message at warn level. +func (l *Logger) Warnf(format string, args ...interface{}) { + if l.level <= LevelWarn { + l.logger.Printf("[WARN] "+format, args...) + } +} + +// Error logs a message at error level. +func (l *Logger) Error(msg string) { + if l.level <= LevelError { + l.logger.Printf("[ERROR] %s", msg) + } +} + +// Errorf logs a formatted message at error level. +func (l *Logger) Errorf(format string, args ...interface{}) { + if l.level <= LevelError { + l.logger.Printf("[ERROR] "+format, args...) + } +} + +func parseLevel(level string) Level { + switch strings.ToLower(level) { + case "debug": + return LevelDebug + case "info": + return LevelInfo + case "warn": + return LevelWarn + case "error": + return LevelError + default: + return LevelInfo + } +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go new file mode 100644 index 0000000..6303dc6 --- /dev/null +++ b/internal/logging/logging_test.go @@ -0,0 +1,278 @@ +package logging + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSetupStdout(t *testing.T) { + file, err := Setup("stdout") + if err != nil { + t.Fatalf("Setup(stdout) failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for stdout") + } +} + +func TestSetupStderr(t *testing.T) { + file, err := Setup("stderr") + if err != nil { + t.Fatalf("Setup(stderr) failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for stderr") + } +} + +func TestSetupEmpty(t *testing.T) { + file, err := Setup("") + if err != nil { + t.Fatalf("Setup('') failed: %v", err) + } + if file != nil { + t.Error("Expected nil file for empty string") + } +} + +func TestSetupFilePath(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "test.log") + + file, err := Setup(logPath) + if err != nil { + t.Fatalf("Setup(file) failed: %v", err) + } + if file == nil { + t.Fatal("Expected non-nil file for file path") + } + defer file.Close() + + // Verify file was created + if _, err := os.Stat(logPath); os.IsNotExist(err) { + t.Error("Expected log file to be created") + } +} + +func TestSetupInvalidPath(t *testing.T) { + _, err := Setup("/nonexistent/directory/path/test.log") + if err == nil { + t.Error("Expected error for invalid path") + } +} + +func TestNewLoggerDefaultLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("", &buf) + + if logger.level != LevelInfo { + t.Errorf("Expected default level Info, got %d", logger.level) + } +} + +func TestNewLoggerInvalidLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("invalid", &buf) + + if logger.level != LevelInfo { + t.Errorf("Expected default level Info for invalid input, got %d", logger.level) + } +} + +func TestNewLoggerDebugLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + if logger.level != LevelDebug { + t.Errorf("Expected level Debug, got %d", logger.level) + } +} + +func TestNewLoggerWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + if logger.level != LevelWarn { + t.Errorf("Expected level Warn, got %d", logger.level) + } +} + +func TestNewLoggerErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + if logger.level != LevelError { + t.Errorf("Expected level Error, got %d", logger.level) + } +} + +func TestLoggerDebugOutputAtDebugLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + logger.Debug("test message") + output := buf.String() + + if !strings.Contains(output, "[DEBUG]") { + t.Errorf("Expected [DEBUG] prefix, got: %s", output) + } + if !strings.Contains(output, "test message") { + t.Errorf("Expected 'test message' in output, got: %s", output) + } +} + +func TestLoggerDebugSuppressedAtInfoLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Debug("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for debug at info level, got: %s", output) + } +} + +func TestLoggerInfoOutputAtInfoLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Info("info message") + output := buf.String() + + if !strings.Contains(output, "[INFO]") { + t.Errorf("Expected [INFO] prefix, got: %s", output) + } + if !strings.Contains(output, "info message") { + t.Errorf("Expected 'info message' in output, got: %s", output) + } +} + +func TestLoggerInfoSuppressedAtWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + logger.Info("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for info at warn level, got: %s", output) + } +} + +func TestLoggerWarnOutputAtWarnLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn", &buf) + + logger.Warn("warn message") + output := buf.String() + + if !strings.Contains(output, "[WARN]") { + t.Errorf("Expected [WARN] prefix, got: %s", output) + } + if !strings.Contains(output, "warn message") { + t.Errorf("Expected 'warn message' in output, got: %s", output) + } +} + +func TestLoggerWarnSuppressedAtErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Warn("should not appear") + output := buf.String() + + if output != "" { + t.Errorf("Expected no output for warn at error level, got: %s", output) + } +} + +func TestLoggerErrorOutputAtErrorLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Error("error message") + output := buf.String() + + if !strings.Contains(output, "[ERROR]") { + t.Errorf("Expected [ERROR] prefix, got: %s", output) + } + if !strings.Contains(output, "error message") { + t.Errorf("Expected 'error message' in output, got: %s", output) + } +} + +func TestLoggerErrorAlwaysOutputs(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("error", &buf) + + logger.Error("critical failure") + output := buf.String() + + if !strings.Contains(output, "critical failure") { + t.Errorf("Expected error message in output, got: %s", output) + } +} + +func TestLoggerDebugLevelAllMessages(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("debug", &buf) + + logger.Debug("d") + logger.Info("i") + logger.Warn("w") + logger.Error("e") + + output := buf.String() + if !strings.Contains(output, "[DEBUG]") { + t.Error("Expected [DEBUG] in output") + } + if !strings.Contains(output, "[INFO]") { + t.Error("Expected [INFO] in output") + } + if !strings.Contains(output, "[WARN]") { + t.Error("Expected [WARN] in output") + } + if !strings.Contains(output, "[ERROR]") { + t.Error("Expected [ERROR] in output") + } +} + +func TestLoggerFormatf(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("info", &buf) + + logger.Infof("count: %d", 42) + output := buf.String() + + if !strings.Contains(output, "count: 42") { + t.Errorf("Expected formatted output 'count: 42', got: %s", output) + } +} + +func TestLoggerLevelCaseInsensitive(t *testing.T) { + var buf bytes.Buffer + + logger := NewLogger("DEBUG", &buf) + if logger.level != LevelDebug { + t.Errorf("Expected level Debug for 'DEBUG', got %d", logger.level) + } + + logger = NewLogger("Info", &buf) + if logger.level != LevelInfo { + t.Errorf("Expected level Info for 'Info', got %d", logger.level) + } + + logger = NewLogger("WARN", &buf) + if logger.level != LevelWarn { + t.Errorf("Expected level Warn for 'WARN', got %d", logger.level) + } + + logger = NewLogger("ERROR", &buf) + if logger.level != LevelError { + t.Errorf("Expected level Error for 'ERROR', got %d", logger.level) + } +} diff --git a/main.go b/main.go index e0526d6..e7774be 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "websocket-relay/internal/config" "websocket-relay/internal/hub" + "websocket-relay/internal/logging" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -26,7 +27,19 @@ func main() { log.Fatal("Failed to load config:", err) } - h := hub.New() + // Setup log output destination + logFile, err := logging.Setup(cfg.Logging.Output) + if err != nil { + log.Fatal("Failed to setup logging:", err) + } + if logFile != nil { + defer logFile.Close() + } + + // Create leveled logger + logger := logging.NewLogger(cfg.Logging.Level, log.Writer()) + + h := hub.New(logger) go h.Run() // Start metrics server if enabled @@ -40,9 +53,9 @@ func main() { Handler: metricsMux, } go func() { - log.Printf("Metrics server starting on %s", metricsAddr) + logger.Infof("Metrics server starting on %s", metricsAddr) if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Printf("Metrics server error: %v", err) + logger.Errorf("Metrics server error: %v", err) } }() } @@ -59,14 +72,16 @@ func main() { // Start the main server in a goroutine go func() { if cfg.Server.TLS.Enabled { - log.Printf("WebSocket relay server starting on %s (TLS)", addr) + logger.Infof("WebSocket relay server starting on %s (TLS)", addr) if err := server.ListenAndServeTLS(cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Errorf("Server error: %v", err) + os.Exit(1) } } else { - log.Printf("WebSocket relay server starting on %s (HTTP)", addr) + logger.Infof("WebSocket relay server starting on %s (HTTP)", addr) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Errorf("Server error: %v", err) + os.Exit(1) } } }() @@ -75,7 +90,7 @@ func main() { quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) sig := <-quit - log.Printf("Received signal %v, shutting down gracefully...", sig) + logger.Infof("Received signal %v, shutting down gracefully...", sig) // Create a deadline for the shutdown ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -83,18 +98,18 @@ func main() { // Shut down the main HTTP server (stops accepting new connections) if err := server.Shutdown(ctx); err != nil { - log.Printf("HTTP server shutdown error: %v", err) + logger.Errorf("HTTP server shutdown error: %v", err) } // Shut down the metrics server if metricsServer != nil { if err := metricsServer.Shutdown(ctx); err != nil { - log.Printf("Metrics server shutdown error: %v", err) + logger.Errorf("Metrics server shutdown error: %v", err) } } // Stop the hub and close all WebSocket connections h.Shutdown() - log.Printf("Server stopped") + logger.Info("Server stopped") }