diff options
author | Michael Hunteman <michael@huntm.net> | 2024-06-23 13:55:42 -0700 |
---|---|---|
committer | Michael Hunteman <michael@huntm.net> | 2024-06-23 13:55:42 -0700 |
commit | 07752babb4e692452e1cd7f2133c4d8dde1b3b1c (patch) | |
tree | b3be7698f1af43f83bccd3bbbf6e19cd03532f1b /server/cmd | |
parent | 4bf5d1a620dfe96ea9593d44cfcd0f142fcdec61 (diff) |
Authenticate UI users
Diffstat (limited to 'server/cmd')
-rw-r--r-- | server/cmd/main.go | 83 |
1 files changed, 81 insertions, 2 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go index f886e2b..5b81b66 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -2,6 +2,8 @@ package main import ( "context" + "crypto/rand" + "encoding/base64" "encoding/json" "fmt" "log" @@ -9,6 +11,7 @@ import ( "os" "regexp" + "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5/pgxpool" "git.huntm.net/wedding/server/guests" @@ -19,6 +22,7 @@ type guestHandler struct { } type guestStore interface { + FindGuest(creds guests.Credentials) (guests.Guest, error) Get() ([]guests.Guest, error) Add(guest guests.Guest) error Update(guest guests.Guest) error @@ -41,6 +45,60 @@ func newGuestHandler(s guestStore) *guestHandler { } } +func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) { + var creds guests.Credentials + err := json.NewDecoder(r.Body).Decode(&creds) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer r.Body.Close() + + guest, err := h.store.FindGuest(creds) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + claims := &guests.Claims{ + Guest: guest, + RegisteredClaims: jwt.RegisteredClaims{}, + } + + key := make([]byte, 32) + _, err = rand.Read(key) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + secretKey := []byte(base64.StdEncoding.EncodeToString(key)) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(secretKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + loginResponse := &guests.LoginResponse{ + Guest: guest, + Token: tokenString, + } + + jsonBytes, err := json.Marshal(loginResponse) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + http.SetCookie(w, &http.Cookie{ + Name: "token", + Value: tokenString, + }) + w.Write(jsonBytes) +} + func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) { guests, err := h.store.Get() if err != nil { @@ -106,13 +164,17 @@ func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) { err = h.store.Update(guest) if err != nil { - http.Error(w, "Guest not found", http.StatusBadRequest) + http.Error(w, "Cannot update guest", http.StatusBadRequest) return } } func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { + case r.Method == http.MethodOptions: + w.WriteHeader(http.StatusOK) + case r.Method == http.MethodPost && r.URL.Path == "/guests/login": + h.login(w, r) case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path): h.getGuests(w, r) case r.Method == http.MethodPost && guestRe.MatchString(r.URL.Path): @@ -124,6 +186,23 @@ func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func enableCors(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isPreflight(r) { + w.Header().Add("Access-Control-Allow-Methods", "*") + } + w.Header().Add("Access-Control-Allow-Origin", "*") + w.Header().Add("Access-Control-Allow-Headers", "*") + next.ServeHTTP(w, r) + }) +} + +func isPreflight(r *http.Request) bool { + return r.Method == "OPTIONS" && + r.Header.Get("Origin") != "" && + r.Header.Get("Access-Control-Request-Method") != "" +} + func main() { db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) if err != nil { @@ -136,5 +215,5 @@ func main() { mux := http.NewServeMux() mux.Handle("/guests/", guestHandler) - log.Fatal(http.ListenAndServe(":8080", mux)) + log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) } |