From 5fffbba3b851f6cebfd0e616bef2ff6f0c520c3d Mon Sep 17 00:00:00 2001 From: Michael Hunteman Date: Sun, 2 Feb 2025 18:57:30 -0600 Subject: Fix error handling --- server/middleware/cors.go | 36 ++++++++++++++++++++++++++++ server/middleware/log.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ server/middleware/logging.go | 50 --------------------------------------- 3 files changed, 92 insertions(+), 50 deletions(-) create mode 100644 server/middleware/cors.go create mode 100644 server/middleware/log.go delete mode 100644 server/middleware/logging.go (limited to 'server/middleware') diff --git a/server/middleware/cors.go b/server/middleware/cors.go new file mode 100644 index 0000000..641113a --- /dev/null +++ b/server/middleware/cors.go @@ -0,0 +1,36 @@ +package middleware + +import "net/http" + +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeMethods(w, r) + writeOrigins(w, r) + writeHeaders(w) + next.ServeHTTP(w, r) + }) +} + +func writeMethods(responseWriter http.ResponseWriter, request *http.Request) { + method := request.Header.Get("Access-Control-Request-Method") + if isPreflight(request) { + responseWriter.Header().Add("Access-Control-Allow-Methods", method) + } +} + +func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) { + origin := request.Header.Get("Origin") + if origin == "http://localhost:5173" { + responseWriter.Header().Add("Access-Control-Allow-Origin", origin) + } +} + +func writeHeaders(responseWriter http.ResponseWriter) { + responseWriter.Header().Add("Access-Control-Allow-Headers", "*") +} + +func isPreflight(request *http.Request) bool { + return request.Method == "OPTIONS" && + request.Header.Get("Origin") != "" && + request.Header.Get("Access-Control-Request-Method") != "" +} diff --git a/server/middleware/log.go b/server/middleware/log.go new file mode 100644 index 0000000..a872daa --- /dev/null +++ b/server/middleware/log.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "bytes" + "fmt" + "log/slog" + "net/http" + "os" + "time" +) + +type LoggingResponseWriter struct { + http.ResponseWriter + statusCode int + responseBody *bytes.Buffer +} + +func (w *LoggingResponseWriter) WriteHeader(code int) { + w.statusCode = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *LoggingResponseWriter) Write(b []byte) (int, error) { + w.responseBody.Write(b) + return w.ResponseWriter.Write(b) +} + +func (w *LoggingResponseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func Log(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + rw := &LoggingResponseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + responseBody: &bytes.Buffer{}, + } + + jsonHandler := slog.NewJSONHandler(os.Stderr, nil) + myslog := slog.New(jsonHandler) + next.ServeHTTP(rw, r) + + if rw.statusCode >= 400 { + myslog.Error("Request", "IP", r.RemoteAddr, "Method", r.Method, "Path", r.URL.Path, "Status", + rw.statusCode, "Duration", fmt.Sprint(time.Since(start)), "Response", rw.responseBody.String()) + } else { + myslog.Info("Request", "IP", r.RemoteAddr, "Method", r.Method, "Path", r.URL.Path, "Status", + rw.statusCode, "Duration", fmt.Sprint(time.Since(start))) + } + }) +} diff --git a/server/middleware/logging.go b/server/middleware/logging.go deleted file mode 100644 index d91a5c8..0000000 --- a/server/middleware/logging.go +++ /dev/null @@ -1,50 +0,0 @@ -package middleware - -import ( - "bytes" - "fmt" - "log/slog" - "net/http" - "os" - "time" -) - -type LoggingResponseWriter struct { - http.ResponseWriter - statusCode int - responseBody *bytes.Buffer -} - -func (w *LoggingResponseWriter) WriteHeader(code int) { - w.statusCode = code - w.ResponseWriter.WriteHeader(code) -} - -func (w *LoggingResponseWriter) Write(b []byte) (int, error) { - w.responseBody.Write(b) - return w.ResponseWriter.Write(b) -} - -func LoggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - rw := &LoggingResponseWriter{ - ResponseWriter: w, - statusCode: http.StatusOK, - responseBody: &bytes.Buffer{}, - } - - jsonHandler := slog.NewJSONHandler(os.Stderr, nil) - myslog := slog.New(jsonHandler) - next.ServeHTTP(rw, r) - - if rw.statusCode >= 400 { - myslog.Error("Request", "IP", r.RemoteAddr, "Method", r.Method, "Path", r.URL.Path, "Status", - rw.statusCode, "Duration", fmt.Sprint(time.Since(start)), "Response", rw.responseBody.String()) - } else { - myslog.Info("Request", "IP", r.RemoteAddr, "Method", r.Method, "Path", r.URL.Path, "Status", - rw.statusCode, "Duration", fmt.Sprint(time.Since(start))) - } - }) -} -- cgit v1.2.3