package main import ( "context" "encoding/json" "errors" "fmt" "log" "net/http" "os" "regexp" "slices" "time" "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5/pgxpool" "git.huntm.net/wedding/server/guest" ) type guestHandler struct { store guestStore } type guestStore interface { FindGuest(credentials guest.Credentials) (guest.Guest, error) Get() ([]guest.Guest, error) Add(guest guest.Guest) error Update(guest guest.Guest) error } type appError struct { Error error Message string Code int } var ( user = os.Getenv("USER") pass = os.Getenv("PASS") host = "localhost" port = "5432" database = "postgres" guestRe = regexp.MustCompile(`^/guest/*$`) guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) ) func main() { db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) if err != nil { log.Fatal(err) } defer db.Close() store := guest.NewMemStore(db) guestHandler := createGuestHandler(store) mux := http.NewServeMux() mux.Handle("/guest/", guestHandler) log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) } func createGuestHandler(s guestStore) *guestHandler { return &guestHandler{ store: s, } } func enableCors(next http.Handler) http.Handler { allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} allowedMethods := []string{"OPTIONS", "POST", "PUT"} return serveHttp(next, allowedOrigins, allowedMethods) } func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { method := r.Header.Get("Access-Control-Request-Method") if isPreflight(r) && slices.Contains(allowedMethods, method) { w.Header().Add("Access-Control-Allow-Methods", method) } origin := r.Header.Get("Origin") if slices.Contains(allowedOrigins, origin) { w.Header().Add("Access-Control-Allow-Origin", origin) } w.Header().Add("Access-Control-Allow-Headers", "*") next.ServeHTTP(w, r) }) } func isPreflight(r *http.Request) bool { return r.Method == "OPTIONS" && r.Header.Get("Origin") != "" && r.Header.Get("Access-Control-Request-Method") != "" } func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { switch { case request.Method == http.MethodOptions: responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guest/login": token, err := guestHandler.login(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.Write(token) } case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): if err := guestHandler.putGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): guests, err := guestHandler.getGuests() if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.Write(guests) } case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): if err := guestHandler.postGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.WriteHeader(http.StatusOK) } default: responseWriter.WriteHeader(http.StatusNotFound) } } func (guestHandler *guestHandler) login(request *http.Request) ([]byte, *appError) { credentials, err := guestHandler.decodeCredentials(request) if err != nil { return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} } guest, err := guestHandler.store.FindGuest(credentials) if err != nil { return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} } expirationTime := guestHandler.setExpirationTime() claims := guestHandler.createClaims(credentials, expirationTime) key, err := guestHandler.readKey() if err != nil { return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} } token, err := guestHandler.createToken(claims, key) if err != nil { return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} } jsonBytes, err := guestHandler.marshalResponse(guest, token) if err != nil { return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} } return jsonBytes, nil } func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) { var credentials guest.Credentials err := json.NewDecoder(request.Body).Decode(&credentials) defer request.Body.Close() return credentials, err } func (guestHandler *guestHandler) setExpirationTime() time.Time { return time.Now().Add(15 * time.Minute) } func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims { return &guest.Claims{ Credentials: credentials, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), }, } } func (guestHandler *guestHandler) readKey() ([]byte, error) { return os.ReadFile("C:\\Users\\mhunt\\skey.pem") } func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(key) } func (guestHandler *guestHandler) marshalResponse(weddingGuest guest.Guest, token string) ([]byte, error) { loginResponse := guestHandler.createLoginResponse(weddingGuest, token) return json.Marshal(loginResponse) } func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse { return &guest.LoginResponse{ Guest: weddingGuest, Token: token, } } func (guestHandler *guestHandler) putGuest(request *http.Request) *appError { authorizationHeader := guestHandler.getToken(request) claims := guestHandler.initializeClaims() key, err := guestHandler.readKey() if err != nil { return &appError{err, "Failed to read secret key", http.StatusInternalServerError} } token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) if err != nil { if err == jwt.ErrSignatureInvalid { return &appError{err, "Invalid signature", http.StatusUnauthorized} } return &appError{err, "Failed to parse claims", http.StatusBadRequest} } if !token.Valid { return &appError{err, "Invalid token", http.StatusUnauthorized} } if guestHandler.findId(request) { return &appError{err, "ID not found", http.StatusNotFound} } guest, err := guestHandler.decodeGuest(request) if err != nil { return &appError{err, "Invalid guest", http.StatusBadRequest} } if guestHandler.store.Update(guest) != nil { return &appError{err, "Failed to update guest", http.StatusInternalServerError} } return nil } func (guestHandler *guestHandler) getToken(request *http.Request) string { return request.Header.Get("Authorization") } func (guestHandler *guestHandler) initializeClaims() *guest.Claims { return &guest.Claims{} } func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) { return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) } func (guestHandler *guestHandler) findId(request *http.Request) bool { matches := guestIdRe.FindStringSubmatch(request.URL.Path) return len(matches) < 2 } func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) { var guest guest.Guest err := json.NewDecoder(request.Body).Decode(&guest) defer request.Body.Close() return guest, err } func (guestHandler *guestHandler) getGuests() ([]byte, *appError) { guests, err := guestHandler.store.Get() if err != nil { return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} } jsonBytes, err := json.Marshal(guests) if err != nil { return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} } return jsonBytes, nil } func (guestHandler *guestHandler) postGuest(request *http.Request) *appError { guest, err := guestHandler.decodeGuest(request) if err != nil { return &appError{err, "Invalid guest", http.StatusBadRequest} } guests, err := guestHandler.store.Get() if err != nil { return &appError{err, "Failed to get guests", http.StatusInternalServerError} } if err := guestHandler.checkExistingGuests(guests, guest); err != nil { return &appError{err, "ID already exists", http.StatusConflict} } if err := guestHandler.store.Add(guest); err != nil { return &appError{err, "Failed to add guest", http.StatusInternalServerError} } return nil } func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error { for _, guest := range guests { if guest.Id == newGuest.Id { return errors.New("ID already exists") } } return nil }