summaryrefslogtreecommitdiff
path: root/server/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'server/cmd')
-rw-r--r--server/cmd/main.go51
1 files changed, 30 insertions, 21 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go
index b4b1c6d..83a64af 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -8,12 +8,13 @@ import (
"net/http"
"os"
"regexp"
+ "slices"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgxpool"
- "git.huntm.net/wedding/server/guests"
+ "git.huntm.net/wedding/server/guest"
)
type guestHandler struct {
@@ -21,10 +22,10 @@ type guestHandler struct {
}
type guestStore interface {
- FindGuest(creds guests.Credentials) (guests.Guest, error)
- Get() ([]guests.Guest, error)
- Add(guest guests.Guest) error
- Update(guest guests.Guest) error
+ FindGuest(creds guest.Credentials) (guest.Guest, error)
+ Get() ([]guest.Guest, error)
+ Add(guest guest.Guest) error
+ Update(guest guest.Guest) error
}
var (
@@ -34,8 +35,8 @@ var (
port = "5432"
database = "postgres"
- guestRe = regexp.MustCompile(`^/guests/*$`)
- guestIdRe = regexp.MustCompile(`^/guests/([0-9]+)$`)
+ guestRe = regexp.MustCompile(`^/guest/*$`)
+ guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
)
func newGuestHandler(s guestStore) *guestHandler {
@@ -45,7 +46,7 @@ func newGuestHandler(s guestStore) *guestHandler {
}
func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) {
- var creds guests.Credentials
+ var creds guest.Credentials
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -53,14 +54,14 @@ func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) {
}
defer r.Body.Close()
- guest, err := h.store.FindGuest(creds)
+ weddingGuest, err := h.store.FindGuest(creds)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
expirationTime := time.Now().Add(15 * time.Minute)
- claims := &guests.Claims{
+ claims := &guest.Claims{
Credentials: creds,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
@@ -80,8 +81,8 @@ func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) {
return
}
- loginResponse := &guests.LoginResponse{
- Guest: guest,
+ loginResponse := &guest.LoginResponse{
+ Guest: weddingGuest,
Token: tokenString,
}
@@ -113,7 +114,7 @@ func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) {
}
func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) {
- var guest guests.Guest
+ var guest guest.Guest
err := json.NewDecoder(r.Body).Decode(&guest)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -145,7 +146,7 @@ func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) {
func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get("Authorization")
- claims := &guests.Claims{}
+ claims := &guest.Claims{}
key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
if err != nil {
@@ -175,7 +176,7 @@ func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) {
return
}
- var guest guests.Guest
+ var guest guest.Guest
err = json.NewDecoder(r.Body).Decode(&guest)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -194,7 +195,7 @@ func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodOptions:
w.WriteHeader(http.StatusOK)
- case r.Method == http.MethodPost && r.URL.Path == "/guests/login":
+ case r.Method == http.MethodPost && r.URL.Path == "/guest/login":
h.login(w, r)
case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path):
h.getGuests(w, r)
@@ -208,11 +209,19 @@ func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func enableCors(next http.Handler) http.Handler {
+ allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
+ allowedMethods := []string{"OPTIONS", "POST", "PUT"}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if isPreflight(r) {
- w.Header().Add("Access-Control-Allow-Methods", "*")
+ method := r.Header.Get("Method")
+ if isPreflight(r) && slices.Contains(allowedMethods, method) {
+ w.Header().Add("Access-Control-Allow-Methods", method)
}
- w.Header().Add("Access-Control-Allow-Origin", "*")
+
+ origin := r.Header.Get("Origin")
+ if slices.Contains(allowedOrigins, origin) {
+ w.Header().Add("Access-Control-Allow-Origin", origin)
+ }
+
w.Header().Add("Access-Control-Allow-Headers", "*")
next.ServeHTTP(w, r)
})
@@ -231,10 +240,10 @@ func main() {
}
defer db.Close()
- store := guests.NewMemStore(db)
+ store := guest.NewMemStore(db)
guestHandler := newGuestHandler(store)
mux := http.NewServeMux()
- mux.Handle("/guests/", guestHandler)
+ mux.Handle("/guest/", guestHandler)
log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
}