From f9b3c68e31a90e8091da69b9defcb6641697e9c8 Mon Sep 17 00:00:00 2001 From: Michael Hunteman Date: Sat, 17 Aug 2024 12:28:39 -0700 Subject: Secure other endpoints --- server/cmd/main.go | 101 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/server/cmd/main.go b/server/cmd/main.go index 6931781..66f31b2 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -101,35 +101,53 @@ func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, case request.Method == http.MethodOptions: responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guest/login": - token, err := guestHandler.login(request) - if err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.Write(token) - } + guestHandler.handleLogIn(responseWriter, request) case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): - if err := guestHandler.putGuest(request); err != nil { - http.Error(responseWriter, err.Message, err.Code) - } + guestHandler.handlePut(responseWriter, request) case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): - guests, err := guestHandler.getGuests() - if err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.Write(guests) - } + guestHandler.handleGet(responseWriter, request) case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): - if err := guestHandler.postGuest(request); err != nil { - http.Error(responseWriter, err.Message, err.Code) - } else { - responseWriter.WriteHeader(http.StatusOK) - } + guestHandler.handlePost(responseWriter, request) default: responseWriter.WriteHeader(http.StatusNotFound) } } -func (guestHandler *guestHandler) login(request *http.Request) ([]byte, *appError) { +func (guestHandler *guestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { + token, err := guestHandler.logInGuest(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(token) + } +} + +func (guestHandler *guestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { + if err := guestHandler.putGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (guestHandler *guestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { + guests, err := guestHandler.getGuests(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(guests) + } +} + +func (guestHandler *guestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { + if err := guestHandler.postGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (guestHandler *guestHandler) logInGuest(request *http.Request) ([]byte, *appError) { credentials, err := guestHandler.decodeCredentials(request) if err != nil { return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} @@ -176,6 +194,7 @@ func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, ex } func (guestHandler *guestHandler) readKey() ([]byte, error) { + // TODO: use properties file return os.ReadFile("C:\\Users\\mhunt\\skey.pem") } @@ -184,8 +203,8 @@ func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) return token.SignedString(key) } -func (guestHandler *guestHandler) marshalResponse(weddingGuest guest.Guest, token string) ([]byte, error) { - loginResponse := guestHandler.createLoginResponse(weddingGuest, token) +func (guestHandler *guestHandler) marshalResponse(guest guest.Guest, token string) ([]byte, error) { + loginResponse := guestHandler.createLoginResponse(guest, token) return json.Marshal(loginResponse) } @@ -197,6 +216,23 @@ func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, } func (guestHandler *guestHandler) putGuest(request *http.Request) *appError { + if err := guestHandler.validateToken(request); err != nil { + return err + } + if guestHandler.findId(request) { + return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound} + } + guest, err := guestHandler.decodeGuest(request) + if err != nil { + return &appError{err, "Invalid guest", http.StatusBadRequest} + } + if err := guestHandler.store.Update(guest); err != nil { + return &appError{err, "Failed to update guest", http.StatusInternalServerError} + } + return nil +} + +func (guestHandler *guestHandler) validateToken(request *http.Request) *appError { authorizationHeader := guestHandler.getToken(request) claims := guestHandler.initializeClaims() key, err := guestHandler.readKey() @@ -213,16 +249,6 @@ func (guestHandler *guestHandler) putGuest(request *http.Request) *appError { if !token.Valid { return &appError{err, "Invalid token", http.StatusUnauthorized} } - if guestHandler.findId(request) { - return &appError{err, "ID not found", http.StatusNotFound} - } - guest, err := guestHandler.decodeGuest(request) - if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} - } - if guestHandler.store.Update(guest) != nil { - return &appError{err, "Failed to update guest", http.StatusInternalServerError} - } return nil } @@ -252,7 +278,11 @@ func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Gues return guest, err } -func (guestHandler *guestHandler) getGuests() ([]byte, *appError) { +func (guestHandler *guestHandler) getGuests(request *http.Request) ([]byte, *appError) { + // TODO: check with admin token + if err := guestHandler.validateToken(request); err != nil { + return []byte{}, err + } guests, err := guestHandler.store.Get() if err != nil { return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} @@ -265,6 +295,9 @@ func (guestHandler *guestHandler) getGuests() ([]byte, *appError) { } func (guestHandler *guestHandler) postGuest(request *http.Request) *appError { + if err := guestHandler.validateToken(request); err != nil { + return err + } guest, err := guestHandler.decodeGuest(request) if err != nil { return &appError{err, "Invalid guest", http.StatusBadRequest} -- cgit v1.2.3