diff options
Diffstat (limited to 'server/cmd')
-rw-r--r-- | server/cmd/main.go | 293 |
1 files changed, 17 insertions, 276 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go index 66f31b2..439ed0b 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -2,48 +2,23 @@ package main import ( "context" - "encoding/json" - "errors" "fmt" "log" "net/http" "os" - "regexp" "slices" - "time" - "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5/pgxpool" "git.huntm.net/wedding/server/guest" ) -type guestHandler struct { - store guestStore -} - -type guestStore interface { - FindGuest(credentials guest.Credentials) (guest.Guest, error) - Get() ([]guest.Guest, error) - Add(guest guest.Guest) error - Update(guest guest.Guest) error -} - -type appError struct { - Error error - Message string - Code int -} - var ( user = os.Getenv("USER") pass = os.Getenv("PASS") host = "localhost" port = "5432" database = "postgres" - - guestRe = regexp.MustCompile(`^/guest/*$`) - guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) ) func main() { @@ -54,272 +29,38 @@ func main() { defer db.Close() store := guest.NewMemStore(db) - guestHandler := createGuestHandler(store) + guestHandler := guest.NewGuestHandler(store) mux := http.NewServeMux() mux.Handle("/guest/", guestHandler) - log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) + log.Fatal(http.ListenAndServe(":8080", enableCORS(mux))) } -func createGuestHandler(s guestStore) *guestHandler { - return &guestHandler{ - store: s, - } -} - -func enableCors(next http.Handler) http.Handler { +func enableCORS(handler http.Handler) http.Handler { allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} allowedMethods := []string{"OPTIONS", "POST", "PUT"} - return serveHttp(next, allowedOrigins, allowedMethods) + return serveHTTP(handler, allowedOrigins, allowedMethods) } -func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - method := r.Header.Get("Access-Control-Request-Method") - if isPreflight(r) && slices.Contains(allowedMethods, method) { - w.Header().Add("Access-Control-Allow-Methods", method) +func serveHTTP(handler http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + method := request.Header.Get("Access-Control-Request-Method") + if isPreflight(request) && slices.Contains(allowedMethods, method) { + responseWriter.Header().Add("Access-Control-Allow-Methods", method) } - origin := r.Header.Get("Origin") + origin := request.Header.Get("Origin") if slices.Contains(allowedOrigins, origin) { - w.Header().Add("Access-Control-Allow-Origin", origin) + responseWriter.Header().Add("Access-Control-Allow-Origin", origin) } - w.Header().Add("Access-Control-Allow-Headers", "*") - next.ServeHTTP(w, r) + responseWriter.Header().Add("Access-Control-Allow-Headers", "*") + handler.ServeHTTP(responseWriter, request) }) } -func isPreflight(r *http.Request) bool { - return r.Method == "OPTIONS" && - r.Header.Get("Origin") != "" && - r.Header.Get("Access-Control-Request-Method") != "" -} - -func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { - switch { - case request.Method == http.MethodOptions: - responseWriter.WriteHeader(http.StatusOK) - case request.Method == http.MethodPost && request.URL.Path == "/guest/login": - guestHandler.handleLogIn(responseWriter, request) - case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): - guestHandler.handlePut(responseWriter, request) - case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): - guestHandler.handleGet(responseWriter, request) - case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): - guestHandler.handlePost(responseWriter, request) - default: - responseWriter.WriteHeader(http.StatusNotFound) - } -} - -func (guestHandler *guestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { - token, err := guestHandler.logInGuest(request) - if err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.Write(token) - } -} - -func (guestHandler *guestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.putGuest(request); err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.WriteHeader(http.StatusOK) - } -} - -func (guestHandler *guestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { - guests, err := guestHandler.getGuests(request) - if err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.Write(guests) - } -} - -func (guestHandler *guestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.postGuest(request); err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.WriteHeader(http.StatusOK) - } -} - -func (guestHandler *guestHandler) logInGuest(request *http.Request) ([]byte, *appError) { - credentials, err := guestHandler.decodeCredentials(request) - if err != nil { - return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} - } - guest, err := guestHandler.store.FindGuest(credentials) - if err != nil { - return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} - } - expirationTime := guestHandler.setExpirationTime() - claims := guestHandler.createClaims(credentials, expirationTime) - key, err := guestHandler.readKey() - if err != nil { - return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} - } - token, err := guestHandler.createToken(claims, key) - if err != nil { - return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} - } - jsonBytes, err := guestHandler.marshalResponse(guest, token) - if err != nil { - return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} - } - return jsonBytes, nil -} - -func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) { - var credentials guest.Credentials - err := json.NewDecoder(request.Body).Decode(&credentials) - defer request.Body.Close() - return credentials, err -} - -func (guestHandler *guestHandler) setExpirationTime() time.Time { - return time.Now().Add(15 * time.Minute) -} - -func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims { - return &guest.Claims{ - Credentials: credentials, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expirationTime), - }, - } -} - -func (guestHandler *guestHandler) readKey() ([]byte, error) { - // TODO: use properties file - return os.ReadFile("C:\\Users\\mhunt\\skey.pem") -} - -func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(key) -} - -func (guestHandler *guestHandler) marshalResponse(guest guest.Guest, token string) ([]byte, error) { - loginResponse := guestHandler.createLoginResponse(guest, token) - return json.Marshal(loginResponse) -} - -func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse { - return &guest.LoginResponse{ - Guest: weddingGuest, - Token: token, - } -} - -func (guestHandler *guestHandler) putGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { - return err - } - if guestHandler.findId(request) { - return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound} - } - guest, err := guestHandler.decodeGuest(request) - if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} - } - if err := guestHandler.store.Update(guest); err != nil { - return &appError{err, "Failed to update guest", http.StatusInternalServerError} - } - return nil -} - -func (guestHandler *guestHandler) validateToken(request *http.Request) *appError { - authorizationHeader := guestHandler.getToken(request) - claims := guestHandler.initializeClaims() - key, err := guestHandler.readKey() - if err != nil { - return &appError{err, "Failed to read secret key", http.StatusInternalServerError} - } - token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) - if err != nil { - if err == jwt.ErrSignatureInvalid { - return &appError{err, "Invalid signature", http.StatusUnauthorized} - } - return &appError{err, "Failed to parse claims", http.StatusBadRequest} - } - if !token.Valid { - return &appError{err, "Invalid token", http.StatusUnauthorized} - } - return nil -} - -func (guestHandler *guestHandler) getToken(request *http.Request) string { - return request.Header.Get("Authorization") -} - -func (guestHandler *guestHandler) initializeClaims() *guest.Claims { - return &guest.Claims{} -} - -func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) { - return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { - return key, nil - }) -} - -func (guestHandler *guestHandler) findId(request *http.Request) bool { - matches := guestIdRe.FindStringSubmatch(request.URL.Path) - return len(matches) < 2 -} - -func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) { - var guest guest.Guest - err := json.NewDecoder(request.Body).Decode(&guest) - defer request.Body.Close() - return guest, err -} - -func (guestHandler *guestHandler) getGuests(request *http.Request) ([]byte, *appError) { - // TODO: check with admin token - if err := guestHandler.validateToken(request); err != nil { - return []byte{}, err - } - guests, err := guestHandler.store.Get() - if err != nil { - return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} - } - jsonBytes, err := json.Marshal(guests) - if err != nil { - return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} - } - return jsonBytes, nil -} - -func (guestHandler *guestHandler) postGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { - return err - } - guest, err := guestHandler.decodeGuest(request) - if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} - } - guests, err := guestHandler.store.Get() - if err != nil { - return &appError{err, "Failed to get guests", http.StatusInternalServerError} - } - if err := guestHandler.checkExistingGuests(guests, guest); err != nil { - return &appError{err, "ID already exists", http.StatusConflict} - } - if err := guestHandler.store.Add(guest); err != nil { - return &appError{err, "Failed to add guest", http.StatusInternalServerError} - } - return nil -} - -func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error { - for _, guest := range guests { - if guest.Id == newGuest.Id { - return errors.New("ID already exists") - } - } - return nil +func isPreflight(request *http.Request) bool { + return request.Method == "OPTIONS" && + request.Header.Get("Origin") != "" && + request.Header.Get("Access-Control-Request-Method") != "" } |