summaryrefslogtreecommitdiff
path: root/server/cmd/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/cmd/main.go')
-rw-r--r--server/cmd/main.go293
1 files changed, 17 insertions, 276 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go
index 66f31b2..439ed0b 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -2,48 +2,23 @@ package main
import (
"context"
- "encoding/json"
- "errors"
"fmt"
"log"
"net/http"
"os"
- "regexp"
"slices"
- "time"
- "github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgxpool"
"git.huntm.net/wedding/server/guest"
)
-type guestHandler struct {
- store guestStore
-}
-
-type guestStore interface {
- FindGuest(credentials guest.Credentials) (guest.Guest, error)
- Get() ([]guest.Guest, error)
- Add(guest guest.Guest) error
- Update(guest guest.Guest) error
-}
-
-type appError struct {
- Error error
- Message string
- Code int
-}
-
var (
user = os.Getenv("USER")
pass = os.Getenv("PASS")
host = "localhost"
port = "5432"
database = "postgres"
-
- guestRe = regexp.MustCompile(`^/guest/*$`)
- guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
)
func main() {
@@ -54,272 +29,38 @@ func main() {
defer db.Close()
store := guest.NewMemStore(db)
- guestHandler := createGuestHandler(store)
+ guestHandler := guest.NewGuestHandler(store)
mux := http.NewServeMux()
mux.Handle("/guest/", guestHandler)
- log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
+ log.Fatal(http.ListenAndServe(":8080", enableCORS(mux)))
}
-func createGuestHandler(s guestStore) *guestHandler {
- return &guestHandler{
- store: s,
- }
-}
-
-func enableCors(next http.Handler) http.Handler {
+func enableCORS(handler http.Handler) http.Handler {
allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
allowedMethods := []string{"OPTIONS", "POST", "PUT"}
- return serveHttp(next, allowedOrigins, allowedMethods)
+ return serveHTTP(handler, allowedOrigins, allowedMethods)
}
-func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- method := r.Header.Get("Access-Control-Request-Method")
- if isPreflight(r) && slices.Contains(allowedMethods, method) {
- w.Header().Add("Access-Control-Allow-Methods", method)
+func serveHTTP(handler http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
+ return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
+ method := request.Header.Get("Access-Control-Request-Method")
+ if isPreflight(request) && slices.Contains(allowedMethods, method) {
+ responseWriter.Header().Add("Access-Control-Allow-Methods", method)
}
- origin := r.Header.Get("Origin")
+ origin := request.Header.Get("Origin")
if slices.Contains(allowedOrigins, origin) {
- w.Header().Add("Access-Control-Allow-Origin", origin)
+ responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
}
- w.Header().Add("Access-Control-Allow-Headers", "*")
- next.ServeHTTP(w, r)
+ responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
+ handler.ServeHTTP(responseWriter, request)
})
}
-func isPreflight(r *http.Request) bool {
- return r.Method == "OPTIONS" &&
- r.Header.Get("Origin") != "" &&
- r.Header.Get("Access-Control-Request-Method") != ""
-}
-
-func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
- switch {
- case request.Method == http.MethodOptions:
- responseWriter.WriteHeader(http.StatusOK)
- case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
- guestHandler.handleLogIn(responseWriter, request)
- case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
- guestHandler.handlePut(responseWriter, request)
- case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
- guestHandler.handleGet(responseWriter, request)
- case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
- guestHandler.handlePost(responseWriter, request)
- default:
- responseWriter.WriteHeader(http.StatusNotFound)
- }
-}
-
-func (guestHandler *guestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
- token, err := guestHandler.logInGuest(request)
- if err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.Write(token)
- }
-}
-
-func (guestHandler *guestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.putGuest(request); err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.WriteHeader(http.StatusOK)
- }
-}
-
-func (guestHandler *guestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
- guests, err := guestHandler.getGuests(request)
- if err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.Write(guests)
- }
-}
-
-func (guestHandler *guestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.postGuest(request); err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.WriteHeader(http.StatusOK)
- }
-}
-
-func (guestHandler *guestHandler) logInGuest(request *http.Request) ([]byte, *appError) {
- credentials, err := guestHandler.decodeCredentials(request)
- if err != nil {
- return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
- }
- guest, err := guestHandler.store.FindGuest(credentials)
- if err != nil {
- return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
- }
- expirationTime := guestHandler.setExpirationTime()
- claims := guestHandler.createClaims(credentials, expirationTime)
- key, err := guestHandler.readKey()
- if err != nil {
- return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
- }
- token, err := guestHandler.createToken(claims, key)
- if err != nil {
- return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
- }
- jsonBytes, err := guestHandler.marshalResponse(guest, token)
- if err != nil {
- return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
- }
- return jsonBytes, nil
-}
-
-func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) {
- var credentials guest.Credentials
- err := json.NewDecoder(request.Body).Decode(&credentials)
- defer request.Body.Close()
- return credentials, err
-}
-
-func (guestHandler *guestHandler) setExpirationTime() time.Time {
- return time.Now().Add(15 * time.Minute)
-}
-
-func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims {
- return &guest.Claims{
- Credentials: credentials,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(expirationTime),
- },
- }
-}
-
-func (guestHandler *guestHandler) readKey() ([]byte, error) {
- // TODO: use properties file
- return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
-}
-
-func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(key)
-}
-
-func (guestHandler *guestHandler) marshalResponse(guest guest.Guest, token string) ([]byte, error) {
- loginResponse := guestHandler.createLoginResponse(guest, token)
- return json.Marshal(loginResponse)
-}
-
-func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse {
- return &guest.LoginResponse{
- Guest: weddingGuest,
- Token: token,
- }
-}
-
-func (guestHandler *guestHandler) putGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
- return err
- }
- if guestHandler.findId(request) {
- return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound}
- }
- guest, err := guestHandler.decodeGuest(request)
- if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
- }
- if err := guestHandler.store.Update(guest); err != nil {
- return &appError{err, "Failed to update guest", http.StatusInternalServerError}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) validateToken(request *http.Request) *appError {
- authorizationHeader := guestHandler.getToken(request)
- claims := guestHandler.initializeClaims()
- key, err := guestHandler.readKey()
- if err != nil {
- return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
- }
- token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
- if err != nil {
- if err == jwt.ErrSignatureInvalid {
- return &appError{err, "Invalid signature", http.StatusUnauthorized}
- }
- return &appError{err, "Failed to parse claims", http.StatusBadRequest}
- }
- if !token.Valid {
- return &appError{err, "Invalid token", http.StatusUnauthorized}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) getToken(request *http.Request) string {
- return request.Header.Get("Authorization")
-}
-
-func (guestHandler *guestHandler) initializeClaims() *guest.Claims {
- return &guest.Claims{}
-}
-
-func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) {
- return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
- return key, nil
- })
-}
-
-func (guestHandler *guestHandler) findId(request *http.Request) bool {
- matches := guestIdRe.FindStringSubmatch(request.URL.Path)
- return len(matches) < 2
-}
-
-func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) {
- var guest guest.Guest
- err := json.NewDecoder(request.Body).Decode(&guest)
- defer request.Body.Close()
- return guest, err
-}
-
-func (guestHandler *guestHandler) getGuests(request *http.Request) ([]byte, *appError) {
- // TODO: check with admin token
- if err := guestHandler.validateToken(request); err != nil {
- return []byte{}, err
- }
- guests, err := guestHandler.store.Get()
- if err != nil {
- return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
- }
- jsonBytes, err := json.Marshal(guests)
- if err != nil {
- return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
- }
- return jsonBytes, nil
-}
-
-func (guestHandler *guestHandler) postGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
- return err
- }
- guest, err := guestHandler.decodeGuest(request)
- if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
- }
- guests, err := guestHandler.store.Get()
- if err != nil {
- return &appError{err, "Failed to get guests", http.StatusInternalServerError}
- }
- if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
- return &appError{err, "ID already exists", http.StatusConflict}
- }
- if err := guestHandler.store.Add(guest); err != nil {
- return &appError{err, "Failed to add guest", http.StatusInternalServerError}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error {
- for _, guest := range guests {
- if guest.Id == newGuest.Id {
- return errors.New("ID already exists")
- }
- }
- return nil
+func isPreflight(request *http.Request) bool {
+ return request.Method == "OPTIONS" &&
+ request.Header.Get("Origin") != "" &&
+ request.Header.Get("Access-Control-Request-Method") != ""
}