From e3b6f82d1c8cdeb1edeccd429d0a6b4463ba9c0d Mon Sep 17 00:00:00 2001 From: Michael Hunteman Date: Sun, 11 Aug 2024 15:11:30 -0700 Subject: Fix error handling --- server/cmd/main.go | 234 +++++++++++++++++++++++++---------------------------- 1 file changed, 111 insertions(+), 123 deletions(-) diff --git a/server/cmd/main.go b/server/cmd/main.go index 53ef94e..6931781 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "log" "net/http" @@ -28,6 +29,12 @@ type guestStore interface { Update(guest guest.Guest) error } +type appError struct { + Error error + Message string + Code int +} + var ( user = os.Getenv("USER") pass = os.Getenv("PASS") @@ -94,47 +101,65 @@ func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, case request.Method == http.MethodOptions: responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guest/login": - guestHandler.login(responseWriter, request) + token, err := guestHandler.login(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(token) + } case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): - guestHandler.putGuest(responseWriter, request) + if err := guestHandler.putGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): - guestHandler.getGuests(responseWriter) + guests, err := guestHandler.getGuests() + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(guests) + } case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): - guestHandler.postGuest(responseWriter, request) + if err := guestHandler.postGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } default: responseWriter.WriteHeader(http.StatusNotFound) } } -func (guestHandler *guestHandler) login(responseWriter http.ResponseWriter, request *http.Request) { - credentials := guestHandler.decodeCredentials(responseWriter, request) - guest := guestHandler.findGuest(responseWriter, credentials) +func (guestHandler *guestHandler) login(request *http.Request) ([]byte, *appError) { + credentials, err := guestHandler.decodeCredentials(request) + if err != nil { + return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} + } + guest, err := guestHandler.store.FindGuest(credentials) + if err != nil { + return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} + } expirationTime := guestHandler.setExpirationTime() claims := guestHandler.createClaims(credentials, expirationTime) - key := guestHandler.readKey(responseWriter) - token := guestHandler.createToken(responseWriter, claims, key) - jsonBytes := guestHandler.marshalResponse(responseWriter, guest, token) - responseWriter.Write(jsonBytes) + key, err := guestHandler.readKey() + if err != nil { + return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} + } + token, err := guestHandler.createToken(claims, key) + if err != nil { + return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} + } + jsonBytes, err := guestHandler.marshalResponse(guest, token) + if err != nil { + return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} + } + return jsonBytes, nil } -func (guestHandler *guestHandler) decodeCredentials(responseWriter http.ResponseWriter, request *http.Request) guest.Credentials { +func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) { var credentials guest.Credentials err := json.NewDecoder(request.Body).Decode(&credentials) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Failed to unmarshal credentials") - } defer request.Body.Close() - return credentials -} - -func (guestHandler *guestHandler) findGuest(responseWriter http.ResponseWriter, credentials guest.Credentials) guest.Guest { - guest, err := guestHandler.store.FindGuest(credentials) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusUnauthorized) - log.Panic("Guest not found") - } - return guest + return credentials, err } func (guestHandler *guestHandler) setExpirationTime() time.Time { @@ -150,33 +175,18 @@ func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, ex } } -func (guestHandler *guestHandler) readKey(responseWriter http.ResponseWriter) []byte { - key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem") - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - log.Panic("Failed to read secret key") - } - return key +func (guestHandler *guestHandler) readKey() ([]byte, error) { + return os.ReadFile("C:\\Users\\mhunt\\skey.pem") } -func (guestHandler *guestHandler) createToken(responseWriter http.ResponseWriter, claims *guest.Claims, key []byte) string { +func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString(key) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusInternalServerError) - log.Panic("Failed to read token") - } - return tokenString + return token.SignedString(key) } -func (guestHandler *guestHandler) marshalResponse(responseWriter http.ResponseWriter, weddingGuest guest.Guest, token string) []byte { +func (guestHandler *guestHandler) marshalResponse(weddingGuest guest.Guest, token string) ([]byte, error) { loginResponse := guestHandler.createLoginResponse(weddingGuest, token) - jsonBytes, err := json.Marshal(loginResponse) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Failed to marshal response") - } - return jsonBytes + return json.Marshal(loginResponse) } func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse { @@ -186,14 +196,34 @@ func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, } } -func (guestHandler *guestHandler) putGuest(responseWriter http.ResponseWriter, request *http.Request) { - token := guestHandler.getToken(request) +func (guestHandler *guestHandler) putGuest(request *http.Request) *appError { + authorizationHeader := guestHandler.getToken(request) claims := guestHandler.initializeClaims() - key := guestHandler.readKey(responseWriter) - guestHandler.checkClaims(responseWriter, token, claims, key) - guestHandler.findId(responseWriter, request) - guest := guestHandler.decodeGuest(responseWriter, request) - guestHandler.updateGuest(responseWriter, guest) + key, err := guestHandler.readKey() + if err != nil { + return &appError{err, "Failed to read secret key", http.StatusInternalServerError} + } + token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return &appError{err, "Invalid signature", http.StatusUnauthorized} + } + return &appError{err, "Failed to parse claims", http.StatusBadRequest} + } + if !token.Valid { + return &appError{err, "Invalid token", http.StatusUnauthorized} + } + if guestHandler.findId(request) { + return &appError{err, "ID not found", http.StatusNotFound} + } + guest, err := guestHandler.decodeGuest(request) + if err != nil { + return &appError{err, "Invalid guest", http.StatusBadRequest} + } + if guestHandler.store.Update(guest) != nil { + return &appError{err, "Failed to update guest", http.StatusInternalServerError} + } + return nil } func (guestHandler *guestHandler) getToken(request *http.Request) string { @@ -204,101 +234,59 @@ func (guestHandler *guestHandler) initializeClaims() *guest.Claims { return &guest.Claims{} } -func (guestHandler *guestHandler) checkClaims(responseWriter http.ResponseWriter, tokenString string, claims *guest.Claims, key []byte) *jwt.Token { - token, err := guestHandler.parseWithClaims(tokenString, claims, key) - if err != nil { - if err == jwt.ErrSignatureInvalid { - responseWriter.WriteHeader(http.StatusUnauthorized) - log.Panic("Invalid signature") - } - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Failed to parse claims") - } - if !token.Valid { - responseWriter.WriteHeader(http.StatusUnauthorized) - log.Panic("Invalid token") - } - return token -} - func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) { return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) } -func (guestHandler *guestHandler) findId(responseWriter http.ResponseWriter, request *http.Request) { +func (guestHandler *guestHandler) findId(request *http.Request) bool { matches := guestIdRe.FindStringSubmatch(request.URL.Path) - if len(matches) < 2 { - http.Error(responseWriter, "ID not found", http.StatusBadRequest) - log.Panic("ID not found") - } + return len(matches) < 2 } -func (guestHandler *guestHandler) decodeGuest(responseWriter http.ResponseWriter, request *http.Request) guest.Guest { +func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) { var guest guest.Guest err := json.NewDecoder(request.Body).Decode(&guest) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Invalid guest") - } defer request.Body.Close() - return guest -} - -func (guestHandler *guestHandler) updateGuest(responseWriter http.ResponseWriter, guest guest.Guest) { - err := guestHandler.store.Update(guest) - if err != nil { - http.Error(responseWriter, "Failed to update guest", http.StatusBadRequest) - log.Panic("Failed to update guest") - } -} - -func (guestHandler *guestHandler) getGuests(responseWriter http.ResponseWriter) { - guests := guestHandler.findGuests(responseWriter) - jsonBytes := guestHandler.marshalGuests(responseWriter, guests) - responseWriter.Write(jsonBytes) + return guest, err } -func (guestHandler *guestHandler) findGuests(responseWriter http.ResponseWriter) []guest.Guest { +func (guestHandler *guestHandler) getGuests() ([]byte, *appError) { guests, err := guestHandler.store.Get() if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Failed to find guests") + return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} } - return guests -} - -func (guestHandler *guestHandler) marshalGuests(responseWriter http.ResponseWriter, guests []guest.Guest) []byte { jsonBytes, err := json.Marshal(guests) if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - log.Panic("Failed to marshal guests") + return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} } - return jsonBytes + return jsonBytes, nil } -func (guestHandler *guestHandler) postGuest(responseWriter http.ResponseWriter, request *http.Request) { - guest := guestHandler.decodeGuest(responseWriter, request) - guests := guestHandler.findGuests(responseWriter) - guestHandler.checkExistingGuests(responseWriter, guests, guest) - guestHandler.addGuest(responseWriter, guest) - responseWriter.WriteHeader(http.StatusOK) +func (guestHandler *guestHandler) postGuest(request *http.Request) *appError { + guest, err := guestHandler.decodeGuest(request) + if err != nil { + return &appError{err, "Invalid guest", http.StatusBadRequest} + } + guests, err := guestHandler.store.Get() + if err != nil { + return &appError{err, "Failed to get guests", http.StatusInternalServerError} + } + if err := guestHandler.checkExistingGuests(guests, guest); err != nil { + return &appError{err, "ID already exists", http.StatusConflict} + } + if err := guestHandler.store.Add(guest); err != nil { + return &appError{err, "Failed to add guest", http.StatusInternalServerError} + } + return nil } -func (guestHandler *guestHandler) checkExistingGuests(responseWriter http.ResponseWriter, guests []guest.Guest, newGuest guest.Guest) { +func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error { for _, guest := range guests { if guest.Id == newGuest.Id { - http.Error(responseWriter, "ID already exists", http.StatusBadRequest) - log.Panic("ID already exists") + return errors.New("ID already exists") } } -} - -func (guestHandler *guestHandler) addGuest(responseWriter http.ResponseWriter, guest guest.Guest) { - err := guestHandler.store.Add(guest) - if err != nil { - http.Error(responseWriter, err.Error(), http.StatusBadRequest) - return - } + return nil } -- cgit v1.2.3