package main import ( "context" "encoding/json" "fmt" "log" "net/http" "os" "regexp" "slices" "time" "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5/pgxpool" "git.huntm.net/wedding/server/guest" ) type guestHandler struct { store guestStore } type guestStore interface { FindGuest(creds guest.Credentials) (guest.Guest, error) Get() ([]guest.Guest, error) Add(guest guest.Guest) error Update(guest guest.Guest) error } var ( user = os.Getenv("USER") pass = os.Getenv("PASS") host = "localhost" port = "5432" database = "postgres" guestRe = regexp.MustCompile(`^/guest/*$`) guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) ) func newGuestHandler(s guestStore) *guestHandler { return &guestHandler{ store: s, } } func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) { var creds guest.Credentials err := json.NewDecoder(r.Body).Decode(&creds) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } defer r.Body.Close() weddingGuest, err := h.store.FindGuest(creds) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } expirationTime := time.Now().Add(15 * time.Minute) claims := &guest.Claims{ Credentials: creds, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), }, } key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(key) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } loginResponse := &guest.LoginResponse{ Guest: weddingGuest, Token: tokenString, } jsonBytes, err := json.Marshal(loginResponse) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) w.Write(jsonBytes) } func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) { guests, err := h.store.Get() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } jsonBytes, err := json.Marshal(guests) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) w.Write(jsonBytes) } func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) { var guest guest.Guest err := json.NewDecoder(r.Body).Decode(&guest) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } defer r.Body.Close() guestSlice, err := h.store.Get() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } for _, g := range guestSlice { if g.Id == guest.Id { http.Error(w, "Id already exists", http.StatusBadRequest) return } } err = h.store.Add(guest) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) } func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get("Authorization") claims := &guest.Claims{} key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { return key, nil }) if err != nil { if err == jwt.ErrSignatureInvalid { w.WriteHeader(http.StatusUnauthorized) return } http.Error(w, err.Error(), http.StatusBadRequest) return } if !token.Valid { w.WriteHeader(http.StatusUnauthorized) return } matches := guestIdRe.FindStringSubmatch(r.URL.Path) if len(matches) < 2 { http.Error(w, "No id found", http.StatusBadRequest) return } var guest guest.Guest err = json.NewDecoder(r.Body).Decode(&guest) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } defer r.Body.Close() err = h.store.Update(guest) if err != nil { http.Error(w, "Cannot update guest", http.StatusBadRequest) return } } func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodOptions: w.WriteHeader(http.StatusOK) case r.Method == http.MethodPost && r.URL.Path == "/guest/login": h.login(w, r) case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path): h.getGuests(w, r) case r.Method == http.MethodPost && guestRe.MatchString(r.URL.Path): h.createGuest(w, r) case r.Method == http.MethodPut && guestIdRe.MatchString(r.URL.Path): h.updateGuest(w, r) default: w.WriteHeader(http.StatusNotFound) } } func enableCors(next http.Handler) http.Handler { allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} allowedMethods := []string{"OPTIONS", "POST", "PUT"} return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { method := r.Header.Get("Access-Control-Request-Method") if isPreflight(r) && slices.Contains(allowedMethods, method) { w.Header().Add("Access-Control-Allow-Methods", method) } origin := r.Header.Get("Origin") if slices.Contains(allowedOrigins, origin) { w.Header().Add("Access-Control-Allow-Origin", origin) } w.Header().Add("Access-Control-Allow-Headers", "*") next.ServeHTTP(w, r) }) } func isPreflight(r *http.Request) bool { return r.Method == "OPTIONS" && r.Header.Get("Origin") != "" && r.Header.Get("Access-Control-Request-Method") != "" } func main() { db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) if err != nil { log.Fatal(err) } defer db.Close() store := guest.NewMemStore(db) guestHandler := newGuestHandler(store) mux := http.NewServeMux() mux.Handle("/guest/", guestHandler) log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) }