diff options
Diffstat (limited to 'server/guest')
-rw-r--r-- | server/guest/handler.go | 156 | ||||
-rw-r--r-- | server/guest/store.go | 20 |
2 files changed, 115 insertions, 61 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go index 418c223..a14a039 100644 --- a/server/guest/handler.go +++ b/server/guest/handler.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "regexp" + "strconv" "time" "github.com/golang-jwt/jwt/v5" @@ -17,7 +18,7 @@ var ( ) type GuestHandler struct { - guestStore GuestStore + store GuestStore } type GuestStore interface { @@ -25,6 +26,7 @@ type GuestStore interface { Get() ([]Guest, error) Add(guest Guest) error Update(guest Guest) error + Delete(id int) error } type appError struct { @@ -39,25 +41,27 @@ func NewGuestHandler(guestStore GuestStore) *GuestHandler { } } -func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { +func (handler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { switch { case request.Method == http.MethodOptions: responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guests/login": - guestHandler.handleLogIn(responseWriter, request) + handler.handleLogIn(responseWriter, request) case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path): - guestHandler.handlePut(responseWriter, request) + handler.handlePut(responseWriter, request) case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path): - guestHandler.handleGet(responseWriter, request) - case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path): - guestHandler.handlePost(responseWriter, request) + handler.handleGet(responseWriter, request) + case request.Method == http.MethodPost && guestRegex.MatchString(request.URL.Path): + handler.handlePost(responseWriter, request) + case request.Method == http.MethodDelete && guestIDRegex.MatchString(request.URL.Path): + handler.handleDelete(responseWriter, request) default: responseWriter.WriteHeader(http.StatusNotFound) } } -func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { - token, err := guestHandler.logIn(request) +func (handler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { + token, err := handler.logIn(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -65,16 +69,16 @@ func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter } } -func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.putGuest(request); err != nil { +func (handler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.putGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.WriteHeader(http.StatusOK) } } -func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { - guests, err := guestHandler.getGuests(request) +func (handler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { + guests, err := handler.getGuests(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -82,52 +86,60 @@ func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, } } -func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.postGuest(request); err != nil { +func (handler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.postGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.WriteHeader(http.StatusOK) } } -func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { - credentials, err := guestHandler.decodeCredentials(request) +func (handler *GuestHandler) handleDelete(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.deleteGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (handler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { + credentials, err := handler.decodeCredentials(request) if err != nil { return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest} } - guest, err := guestHandler.guestStore.Find(credentials) + guest, err := handler.store.Find(credentials) if err != nil { return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized} } - expirationTime := guestHandler.setExpirationTime() - claims := guestHandler.createClaims(credentials, expirationTime) - key, err := guestHandler.readGuestKey() + expirationTime := handler.setExpirationTime() + claims := handler.createClaims(credentials, expirationTime) + key, err := handler.readGuestKey() if err != nil { return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } - token, err := guestHandler.createToken(claims, key) + token, err := handler.createToken(claims, key) if err != nil { return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError} } - jsonBytes, err := guestHandler.marshalResponse(guest, token) + jsonBytes, err := handler.marshalResponse(guest, token) if err != nil { return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError} } return jsonBytes, nil } -func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) { +func (handler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) { var credentials Credentials err := json.NewDecoder(request.Body).Decode(&credentials) defer request.Body.Close() return credentials, err } -func (guestHandler *GuestHandler) setExpirationTime() time.Time { +func (handler *GuestHandler) setExpirationTime() time.Time { return time.Now().Add(15 * time.Minute) } -func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims { +func (handler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims { return &Claims{ Credentials: credentials, RegisteredClaims: jwt.RegisteredClaims{ @@ -136,57 +148,57 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati } } -func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) { +func (handler *GuestHandler) readGuestKey() ([]byte, error) { // TODO: use properties file return os.ReadFile("C:\\Users\\mhunt\\guest.pem") } -func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) { +func (handler *GuestHandler) readAdminKey() ([]byte, error) { return os.ReadFile("C:\\Users\\mhunt\\admin.pem") } -func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { +func (handler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(key) } -func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { - loginResponse := guestHandler.createLoginResponse(guest, token) +func (handler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { + loginResponse := handler.createLoginResponse(guest, token) return json.Marshal(loginResponse) } -func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse { +func (handler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse { return &LoginResponse{ Guest: weddingGuest, Token: token, } } -func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError { - guestKey, err := guestHandler.readGuestKey() +func (handler *GuestHandler) putGuest(request *http.Request) *appError { + guestKey, err := handler.readGuestKey() if err != nil { return &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, guestKey); err != nil { + if err := handler.validateToken(request, guestKey); err != nil { return err } - if guestHandler.findID(request) { + if handler.findID(request) { return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} } - guest, err := guestHandler.decodeGuest(request) + guest, err := handler.decodeGuest(request) if err != nil { return &appError{err, "invalid guest", http.StatusBadRequest} } - if err := guestHandler.guestStore.Update(guest); err != nil { + if err := handler.store.Update(guest); err != nil { return &appError{err, "failed to update guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { - authorizationHeader := guestHandler.getToken(request) - claims := guestHandler.newClaims() - token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) +func (handler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { + authorizationHeader := handler.getToken(request) + claims := handler.newClaims() + token, err := handler.parseWithClaims(authorizationHeader, claims, key) if err != nil { if err == jwt.ErrSignatureInvalid { return &appError{err, "invalid signature", http.StatusUnauthorized} @@ -199,41 +211,41 @@ func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byt return nil } -func (guestHandler *GuestHandler) getToken(request *http.Request) string { +func (handler *GuestHandler) getToken(request *http.Request) string { return request.Header.Get("Authorization") } -func (guestHandler *GuestHandler) newClaims() *Claims { +func (handler *GuestHandler) newClaims() *Claims { return &Claims{} } -func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { +func (handler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) } -func (guestHandler *GuestHandler) findID(request *http.Request) bool { +func (handler *GuestHandler) findID(request *http.Request) bool { matches := guestIDRegex.FindStringSubmatch(request.URL.Path) return len(matches) < 2 } -func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) { +func (handler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) { var guest Guest err := json.NewDecoder(request.Body).Decode(&guest) defer request.Body.Close() return guest, err } -func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { - adminKey, err := guestHandler.readAdminKey() +func (handler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { + adminKey, err := handler.readAdminKey() if err != nil { return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, adminKey); err != nil { + if err := handler.validateToken(request, adminKey); err != nil { return []byte{}, err } - guests, err := guestHandler.guestStore.Get() + guests, err := handler.store.Get() if err != nil { return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError} } @@ -244,32 +256,32 @@ func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *app return jsonBytes, nil } -func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError { - adminKey, err := guestHandler.readAdminKey() +func (handler *GuestHandler) postGuest(request *http.Request) *appError { + adminKey, err := handler.readAdminKey() if err != nil { return &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, adminKey); err != nil { + if err := handler.validateToken(request, adminKey); err != nil { return err } - guest, err := guestHandler.decodeGuest(request) + guest, err := handler.decodeGuest(request) if err != nil { return &appError{err, "invalid guest", http.StatusBadRequest} } - guests, err := guestHandler.guestStore.Get() + guests, err := handler.store.Get() if err != nil { return &appError{err, "failed to get guests", http.StatusInternalServerError} } - if err := guestHandler.checkExistingGuests(guests, guest); err != nil { + if err := handler.checkExistingGuests(guests, guest); err != nil { return &appError{err, "id already exists", http.StatusConflict} } - if err := guestHandler.guestStore.Add(guest); err != nil { + if err := handler.store.Add(guest); err != nil { return &appError{err, "failed to add guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { +func (handler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { for _, guest := range guests { if guest.ID == newGuest.ID { return errors.New("id already exists") @@ -277,3 +289,29 @@ func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest G } return nil } + +func (handler *GuestHandler) deleteGuest(request *http.Request) *appError { + adminKey, err := handler.readAdminKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := handler.validateToken(request, adminKey); err != nil { + return err + } + if handler.findID(request) { + return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} + } + guestID, err := getID(request) + if err != nil { + return &appError{err, "failed to parse id", http.StatusInternalServerError} + } + err = handler.store.Delete(int(guestID)) + if err != nil { + return &appError{err, "failed to get guests", http.StatusInternalServerError} + } + return nil +} + +func getID(request *http.Request) (int64, error) { + return strconv.ParseInt(guestIDRegex.FindStringSubmatch(request.URL.Path)[1], 10, 32) +} diff --git a/server/guest/store.go b/server/guest/store.go index a5b9374..4290f8f 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -157,7 +157,7 @@ func (store Store) Update(guest Guest) error { if err := store.updateGuest(guest); err != nil { return err } - if err := store.deleteOldParty(guest.ID); err != nil { + if err := store.deleteParty(guest.ID); err != nil { return err } return store.insertParty(guest) @@ -171,7 +171,7 @@ func (store Store) updateGuest(guest Guest) error { return err } -func (store Store) deleteOldParty(guestID int) error { +func (store Store) deleteParty(guestID int) error { statement := "delete from party where guest_id = $1" _, err := store.database.Exec(context.Background(), statement, guestID) return err @@ -189,3 +189,19 @@ func (store Store) insertParty(guest Guest) error { } return nil } + +func (store Store) Delete(guestID int) error { + if err := store.deleteGuest(guestID); err != nil { + return err + } + if err := store.deleteParty(guestID); err != nil { + return err + } + return nil +} + +func (store Store) deleteGuest(guestID int) error { + statement := "delete from guest where id = $1" + _, err := store.database.Exec(context.Background(), statement, guestID) + return err +} |