diff options
Diffstat (limited to 'server/middleware')
-rw-r--r-- | server/middleware/cors.go | 36 | ||||
-rw-r--r-- | server/middleware/log.go (renamed from server/middleware/logging.go) | 8 |
2 files changed, 43 insertions, 1 deletions
diff --git a/server/middleware/cors.go b/server/middleware/cors.go new file mode 100644 index 0000000..641113a --- /dev/null +++ b/server/middleware/cors.go @@ -0,0 +1,36 @@ +package middleware + +import "net/http" + +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeMethods(w, r) + writeOrigins(w, r) + writeHeaders(w) + next.ServeHTTP(w, r) + }) +} + +func writeMethods(responseWriter http.ResponseWriter, request *http.Request) { + method := request.Header.Get("Access-Control-Request-Method") + if isPreflight(request) { + responseWriter.Header().Add("Access-Control-Allow-Methods", method) + } +} + +func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) { + origin := request.Header.Get("Origin") + if origin == "http://localhost:5173" { + responseWriter.Header().Add("Access-Control-Allow-Origin", origin) + } +} + +func writeHeaders(responseWriter http.ResponseWriter) { + responseWriter.Header().Add("Access-Control-Allow-Headers", "*") +} + +func isPreflight(request *http.Request) bool { + return request.Method == "OPTIONS" && + request.Header.Get("Origin") != "" && + request.Header.Get("Access-Control-Request-Method") != "" +} diff --git a/server/middleware/logging.go b/server/middleware/log.go index d91a5c8..a872daa 100644 --- a/server/middleware/logging.go +++ b/server/middleware/log.go @@ -25,7 +25,13 @@ func (w *LoggingResponseWriter) Write(b []byte) (int, error) { return w.ResponseWriter.Write(b) } -func LoggingMiddleware(next http.Handler) http.Handler { +func (w *LoggingResponseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func Log(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() |