From e3b6f82d1c8cdeb1edeccd429d0a6b4463ba9c0d Mon Sep 17 00:00:00 2001
From: Michael Hunteman <michael@huntm.net>
Date: Sun, 11 Aug 2024 15:11:30 -0700
Subject: Fix error handling

---
 server/cmd/main.go | 234 +++++++++++++++++++++++++----------------------------
 1 file changed, 111 insertions(+), 123 deletions(-)

(limited to 'server')

diff --git a/server/cmd/main.go b/server/cmd/main.go
index 53ef94e..6931781 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -3,6 +3,7 @@ package main
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log"
 	"net/http"
@@ -28,6 +29,12 @@ type guestStore interface {
 	Update(guest guest.Guest) error
 }
 
+type appError struct {
+	Error   error
+	Message string
+	Code    int
+}
+
 var (
 	user     = os.Getenv("USER")
 	pass     = os.Getenv("PASS")
@@ -94,47 +101,65 @@ func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter,
 	case request.Method == http.MethodOptions:
 		responseWriter.WriteHeader(http.StatusOK)
 	case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
-		guestHandler.login(responseWriter, request)
+		token, err := guestHandler.login(request)
+		if err != nil {
+			http.Error(responseWriter, err.Message, err.Code)
+		} else {
+			responseWriter.Write(token)
+		}
 	case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
-		guestHandler.putGuest(responseWriter, request)
+		if err := guestHandler.putGuest(request); err != nil {
+			http.Error(responseWriter, err.Message, err.Code)
+		}
 	case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
-		guestHandler.getGuests(responseWriter)
+		guests, err := guestHandler.getGuests()
+		if err != nil {
+			http.Error(responseWriter, err.Message, err.Code)
+		} else {
+			responseWriter.Write(guests)
+		}
 	case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
-		guestHandler.postGuest(responseWriter, request)
+		if err := guestHandler.postGuest(request); err != nil {
+			http.Error(responseWriter, err.Message, err.Code)
+		} else {
+			responseWriter.WriteHeader(http.StatusOK)
+		}
 	default:
 		responseWriter.WriteHeader(http.StatusNotFound)
 	}
 }
 
-func (guestHandler *guestHandler) login(responseWriter http.ResponseWriter, request *http.Request) {
-	credentials := guestHandler.decodeCredentials(responseWriter, request)
-	guest := guestHandler.findGuest(responseWriter, credentials)
+func (guestHandler *guestHandler) login(request *http.Request) ([]byte, *appError) {
+	credentials, err := guestHandler.decodeCredentials(request)
+	if err != nil {
+		return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
+	}
+	guest, err := guestHandler.store.FindGuest(credentials)
+	if err != nil {
+		return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
+	}
 	expirationTime := guestHandler.setExpirationTime()
 	claims := guestHandler.createClaims(credentials, expirationTime)
-	key := guestHandler.readKey(responseWriter)
-	token := guestHandler.createToken(responseWriter, claims, key)
-	jsonBytes := guestHandler.marshalResponse(responseWriter, guest, token)
-	responseWriter.Write(jsonBytes)
+	key, err := guestHandler.readKey()
+	if err != nil {
+		return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+	}
+	token, err := guestHandler.createToken(claims, key)
+	if err != nil {
+		return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
+	}
+	jsonBytes, err := guestHandler.marshalResponse(guest, token)
+	if err != nil {
+		return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
+	}
+	return jsonBytes, nil
 }
 
-func (guestHandler *guestHandler) decodeCredentials(responseWriter http.ResponseWriter, request *http.Request) guest.Credentials {
+func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) {
 	var credentials guest.Credentials
 	err := json.NewDecoder(request.Body).Decode(&credentials)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Failed to unmarshal credentials")
-	}
 	defer request.Body.Close()
-	return credentials
-}
-
-func (guestHandler *guestHandler) findGuest(responseWriter http.ResponseWriter, credentials guest.Credentials) guest.Guest {
-	guest, err := guestHandler.store.FindGuest(credentials)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusUnauthorized)
-		log.Panic("Guest not found")
-	}
-	return guest
+	return credentials, err
 }
 
 func (guestHandler *guestHandler) setExpirationTime() time.Time {
@@ -150,33 +175,18 @@ func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, ex
 	}
 }
 
-func (guestHandler *guestHandler) readKey(responseWriter http.ResponseWriter) []byte {
-	key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
-		log.Panic("Failed to read secret key")
-	}
-	return key
+func (guestHandler *guestHandler) readKey() ([]byte, error) {
+	return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
 }
 
-func (guestHandler *guestHandler) createToken(responseWriter http.ResponseWriter, claims *guest.Claims, key []byte) string {
+func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) {
 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-	tokenString, err := token.SignedString(key)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
-		log.Panic("Failed to read token")
-	}
-	return tokenString
+	return token.SignedString(key)
 }
 
-func (guestHandler *guestHandler) marshalResponse(responseWriter http.ResponseWriter, weddingGuest guest.Guest, token string) []byte {
+func (guestHandler *guestHandler) marshalResponse(weddingGuest guest.Guest, token string) ([]byte, error) {
 	loginResponse := guestHandler.createLoginResponse(weddingGuest, token)
-	jsonBytes, err := json.Marshal(loginResponse)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Failed to marshal response")
-	}
-	return jsonBytes
+	return json.Marshal(loginResponse)
 }
 
 func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse {
@@ -186,14 +196,34 @@ func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest,
 	}
 }
 
-func (guestHandler *guestHandler) putGuest(responseWriter http.ResponseWriter, request *http.Request) {
-	token := guestHandler.getToken(request)
+func (guestHandler *guestHandler) putGuest(request *http.Request) *appError {
+	authorizationHeader := guestHandler.getToken(request)
 	claims := guestHandler.initializeClaims()
-	key := guestHandler.readKey(responseWriter)
-	guestHandler.checkClaims(responseWriter, token, claims, key)
-	guestHandler.findId(responseWriter, request)
-	guest := guestHandler.decodeGuest(responseWriter, request)
-	guestHandler.updateGuest(responseWriter, guest)
+	key, err := guestHandler.readKey()
+	if err != nil {
+		return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+	}
+	token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
+	if err != nil {
+		if err == jwt.ErrSignatureInvalid {
+			return &appError{err, "Invalid signature", http.StatusUnauthorized}
+		}
+		return &appError{err, "Failed to parse claims", http.StatusBadRequest}
+	}
+	if !token.Valid {
+		return &appError{err, "Invalid token", http.StatusUnauthorized}
+	}
+	if guestHandler.findId(request) {
+		return &appError{err, "ID not found", http.StatusNotFound}
+	}
+	guest, err := guestHandler.decodeGuest(request)
+	if err != nil {
+		return &appError{err, "Invalid guest", http.StatusBadRequest}
+	}
+	if guestHandler.store.Update(guest) != nil {
+		return &appError{err, "Failed to update guest", http.StatusInternalServerError}
+	}
+	return nil
 }
 
 func (guestHandler *guestHandler) getToken(request *http.Request) string {
@@ -204,101 +234,59 @@ func (guestHandler *guestHandler) initializeClaims() *guest.Claims {
 	return &guest.Claims{}
 }
 
-func (guestHandler *guestHandler) checkClaims(responseWriter http.ResponseWriter, tokenString string, claims *guest.Claims, key []byte) *jwt.Token {
-	token, err := guestHandler.parseWithClaims(tokenString, claims, key)
-	if err != nil {
-		if err == jwt.ErrSignatureInvalid {
-			responseWriter.WriteHeader(http.StatusUnauthorized)
-			log.Panic("Invalid signature")
-		}
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Failed to parse claims")
-	}
-	if !token.Valid {
-		responseWriter.WriteHeader(http.StatusUnauthorized)
-		log.Panic("Invalid token")
-	}
-	return token
-}
-
 func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) {
 	return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
 		return key, nil
 	})
 }
 
-func (guestHandler *guestHandler) findId(responseWriter http.ResponseWriter, request *http.Request) {
+func (guestHandler *guestHandler) findId(request *http.Request) bool {
 	matches := guestIdRe.FindStringSubmatch(request.URL.Path)
-	if len(matches) < 2 {
-		http.Error(responseWriter, "ID not found", http.StatusBadRequest)
-		log.Panic("ID not found")
-	}
+	return len(matches) < 2
 }
 
-func (guestHandler *guestHandler) decodeGuest(responseWriter http.ResponseWriter, request *http.Request) guest.Guest {
+func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) {
 	var guest guest.Guest
 	err := json.NewDecoder(request.Body).Decode(&guest)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Invalid guest")
-	}
 	defer request.Body.Close()
-	return guest
-}
-
-func (guestHandler *guestHandler) updateGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
-	err := guestHandler.store.Update(guest)
-	if err != nil {
-		http.Error(responseWriter, "Failed to update guest", http.StatusBadRequest)
-		log.Panic("Failed to update guest")
-	}
-}
-
-func (guestHandler *guestHandler) getGuests(responseWriter http.ResponseWriter) {
-	guests := guestHandler.findGuests(responseWriter)
-	jsonBytes := guestHandler.marshalGuests(responseWriter, guests)
-	responseWriter.Write(jsonBytes)
+	return guest, err
 }
 
-func (guestHandler *guestHandler) findGuests(responseWriter http.ResponseWriter) []guest.Guest {
+func (guestHandler *guestHandler) getGuests() ([]byte, *appError) {
 	guests, err := guestHandler.store.Get()
 	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Failed to find guests")
+		return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
 	}
-	return guests
-}
-
-func (guestHandler *guestHandler) marshalGuests(responseWriter http.ResponseWriter, guests []guest.Guest) []byte {
 	jsonBytes, err := json.Marshal(guests)
 	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		log.Panic("Failed to marshal guests")
+		return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
 	}
-	return jsonBytes
+	return jsonBytes, nil
 }
 
-func (guestHandler *guestHandler) postGuest(responseWriter http.ResponseWriter, request *http.Request) {
-	guest := guestHandler.decodeGuest(responseWriter, request)
-	guests := guestHandler.findGuests(responseWriter)
-	guestHandler.checkExistingGuests(responseWriter, guests, guest)
-	guestHandler.addGuest(responseWriter, guest)
-	responseWriter.WriteHeader(http.StatusOK)
+func (guestHandler *guestHandler) postGuest(request *http.Request) *appError {
+	guest, err := guestHandler.decodeGuest(request)
+	if err != nil {
+		return &appError{err, "Invalid guest", http.StatusBadRequest}
+	}
+	guests, err := guestHandler.store.Get()
+	if err != nil {
+		return &appError{err, "Failed to get guests", http.StatusInternalServerError}
+	}
+	if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
+		return &appError{err, "ID already exists", http.StatusConflict}
+	}
+	if err := guestHandler.store.Add(guest); err != nil {
+		return &appError{err, "Failed to add guest", http.StatusInternalServerError}
+	}
+	return nil
 }
 
-func (guestHandler *guestHandler) checkExistingGuests(responseWriter http.ResponseWriter, guests []guest.Guest, newGuest guest.Guest) {
+func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error {
 	for _, guest := range guests {
 		if guest.Id == newGuest.Id {
-			http.Error(responseWriter, "ID already exists", http.StatusBadRequest)
-			log.Panic("ID already exists")
+			return errors.New("ID already exists")
 		}
 	}
-}
-
-func (guestHandler *guestHandler) addGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
-	err := guestHandler.store.Add(guest)
-	if err != nil {
-		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
-		return
-	}
+	return nil
 }
-- 
cgit v1.2.3