From b6b590ce1786289642c85a82ac35eb81fb0ee3bb Mon Sep 17 00:00:00 2001
From: Michael Hunteman <michael@huntm.net>
Date: Sat, 10 Aug 2024 15:22:58 -0700
Subject: Refactor

---
 server/cmd/main.go | 339 +++++++++++++++++++++++++++++++----------------------
 1 file changed, 197 insertions(+), 142 deletions(-)

(limited to 'server')

diff --git a/server/cmd/main.go b/server/cmd/main.go
index b747aba..53ef94e 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -22,7 +22,7 @@ type guestHandler struct {
 }
 
 type guestStore interface {
-	FindGuest(creds guest.Credentials) (guest.Guest, error)
+	FindGuest(credentials guest.Credentials) (guest.Guest, error)
 	Get() ([]guest.Guest, error)
 	Add(guest guest.Guest) error
 	Update(guest guest.Guest) error
@@ -39,211 +39,266 @@ var (
 	guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
 )
 
-func newGuestHandler(s guestStore) *guestHandler {
+func main() {
+	db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+
+	store := guest.NewMemStore(db)
+	guestHandler := createGuestHandler(store)
+
+	mux := http.NewServeMux()
+	mux.Handle("/guest/", guestHandler)
+	log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
+}
+
+func createGuestHandler(s guestStore) *guestHandler {
 	return &guestHandler{
 		store: s,
 	}
 }
 
-func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) {
-	var creds guest.Credentials
-	err := json.NewDecoder(r.Body).Decode(&creds)
+func enableCors(next http.Handler) http.Handler {
+	allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
+	allowedMethods := []string{"OPTIONS", "POST", "PUT"}
+	return serveHttp(next, allowedOrigins, allowedMethods)
+}
+
+func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		method := r.Header.Get("Access-Control-Request-Method")
+		if isPreflight(r) && slices.Contains(allowedMethods, method) {
+			w.Header().Add("Access-Control-Allow-Methods", method)
+		}
+
+		origin := r.Header.Get("Origin")
+		if slices.Contains(allowedOrigins, origin) {
+			w.Header().Add("Access-Control-Allow-Origin", origin)
+		}
+
+		w.Header().Add("Access-Control-Allow-Headers", "*")
+		next.ServeHTTP(w, r)
+	})
+}
+
+func isPreflight(r *http.Request) bool {
+	return r.Method == "OPTIONS" &&
+		r.Header.Get("Origin") != "" &&
+		r.Header.Get("Access-Control-Request-Method") != ""
+}
+
+func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+	switch {
+	case request.Method == http.MethodOptions:
+		responseWriter.WriteHeader(http.StatusOK)
+	case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
+		guestHandler.login(responseWriter, request)
+	case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
+		guestHandler.putGuest(responseWriter, request)
+	case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
+		guestHandler.getGuests(responseWriter)
+	case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
+		guestHandler.postGuest(responseWriter, request)
+	default:
+		responseWriter.WriteHeader(http.StatusNotFound)
+	}
+}
+
+func (guestHandler *guestHandler) login(responseWriter http.ResponseWriter, request *http.Request) {
+	credentials := guestHandler.decodeCredentials(responseWriter, request)
+	guest := guestHandler.findGuest(responseWriter, credentials)
+	expirationTime := guestHandler.setExpirationTime()
+	claims := guestHandler.createClaims(credentials, expirationTime)
+	key := guestHandler.readKey(responseWriter)
+	token := guestHandler.createToken(responseWriter, claims, key)
+	jsonBytes := guestHandler.marshalResponse(responseWriter, guest, token)
+	responseWriter.Write(jsonBytes)
+}
+
+func (guestHandler *guestHandler) decodeCredentials(responseWriter http.ResponseWriter, request *http.Request) guest.Credentials {
+	var credentials guest.Credentials
+	err := json.NewDecoder(request.Body).Decode(&credentials)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Failed to unmarshal credentials")
 	}
-	defer r.Body.Close()
+	defer request.Body.Close()
+	return credentials
+}
 
-	weddingGuest, err := h.store.FindGuest(creds)
+func (guestHandler *guestHandler) findGuest(responseWriter http.ResponseWriter, credentials guest.Credentials) guest.Guest {
+	guest, err := guestHandler.store.FindGuest(credentials)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusUnauthorized)
-		return
+		http.Error(responseWriter, err.Error(), http.StatusUnauthorized)
+		log.Panic("Guest not found")
 	}
+	return guest
+}
 
-	expirationTime := time.Now().Add(15 * time.Minute)
-	claims := &guest.Claims{
-		Credentials: creds,
+func (guestHandler *guestHandler) setExpirationTime() time.Time {
+	return time.Now().Add(15 * time.Minute)
+}
+
+func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims {
+	return &guest.Claims{
+		Credentials: credentials,
 		RegisteredClaims: jwt.RegisteredClaims{
 			ExpiresAt: jwt.NewNumericDate(expirationTime),
 		},
 	}
+}
 
+func (guestHandler *guestHandler) readKey(responseWriter http.ResponseWriter) []byte {
 	key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+		http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+		log.Panic("Failed to read secret key")
 	}
+	return key
+}
 
+func (guestHandler *guestHandler) createToken(responseWriter http.ResponseWriter, claims *guest.Claims, key []byte) string {
 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 	tokenString, err := token.SignedString(key)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
-
-	loginResponse := &guest.LoginResponse{
-		Guest: weddingGuest,
-		Token: tokenString,
+		http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+		log.Panic("Failed to read token")
 	}
+	return tokenString
+}
 
+func (guestHandler *guestHandler) marshalResponse(responseWriter http.ResponseWriter, weddingGuest guest.Guest, token string) []byte {
+	loginResponse := guestHandler.createLoginResponse(weddingGuest, token)
 	jsonBytes, err := json.Marshal(loginResponse)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Failed to marshal response")
 	}
-
-	w.WriteHeader(http.StatusOK)
-	w.Write(jsonBytes)
+	return jsonBytes
 }
 
-func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) {
-	guests, err := h.store.Get()
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
+func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse {
+	return &guest.LoginResponse{
+		Guest: weddingGuest,
+		Token: token,
 	}
+}
 
-	jsonBytes, err := json.Marshal(guests)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
-	}
+func (guestHandler *guestHandler) putGuest(responseWriter http.ResponseWriter, request *http.Request) {
+	token := guestHandler.getToken(request)
+	claims := guestHandler.initializeClaims()
+	key := guestHandler.readKey(responseWriter)
+	guestHandler.checkClaims(responseWriter, token, claims, key)
+	guestHandler.findId(responseWriter, request)
+	guest := guestHandler.decodeGuest(responseWriter, request)
+	guestHandler.updateGuest(responseWriter, guest)
+}
 
-	w.WriteHeader(http.StatusOK)
-	w.Write(jsonBytes)
+func (guestHandler *guestHandler) getToken(request *http.Request) string {
+	return request.Header.Get("Authorization")
 }
 
-func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) {
-	var guest guest.Guest
-	err := json.NewDecoder(r.Body).Decode(&guest)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
-	}
-	defer r.Body.Close()
+func (guestHandler *guestHandler) initializeClaims() *guest.Claims {
+	return &guest.Claims{}
+}
 
-	guestSlice, err := h.store.Get()
+func (guestHandler *guestHandler) checkClaims(responseWriter http.ResponseWriter, tokenString string, claims *guest.Claims, key []byte) *jwt.Token {
+	token, err := guestHandler.parseWithClaims(tokenString, claims, key)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
-	}
-
-	for _, g := range guestSlice {
-		if g.Id == guest.Id {
-			http.Error(w, "Id already exists", http.StatusBadRequest)
-			return
+		if err == jwt.ErrSignatureInvalid {
+			responseWriter.WriteHeader(http.StatusUnauthorized)
+			log.Panic("Invalid signature")
 		}
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Failed to parse claims")
 	}
-
-	err = h.store.Add(guest)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
+	if !token.Valid {
+		responseWriter.WriteHeader(http.StatusUnauthorized)
+		log.Panic("Invalid token")
 	}
-
-	w.WriteHeader(http.StatusOK)
+	return token
 }
 
-func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) {
-	tokenString := r.Header.Get("Authorization")
-	claims := &guest.Claims{}
-
-	key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
-
-	token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
+func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) {
+	return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
 		return key, nil
 	})
-	if err != nil {
-		if err == jwt.ErrSignatureInvalid {
-			w.WriteHeader(http.StatusUnauthorized)
-			return
-		}
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
-	}
-	if !token.Valid {
-		w.WriteHeader(http.StatusUnauthorized)
-		return
-	}
+}
 
-	matches := guestIdRe.FindStringSubmatch(r.URL.Path)
+func (guestHandler *guestHandler) findId(responseWriter http.ResponseWriter, request *http.Request) {
+	matches := guestIdRe.FindStringSubmatch(request.URL.Path)
 	if len(matches) < 2 {
-		http.Error(w, "No id found", http.StatusBadRequest)
-		return
+		http.Error(responseWriter, "ID not found", http.StatusBadRequest)
+		log.Panic("ID not found")
 	}
+}
 
+func (guestHandler *guestHandler) decodeGuest(responseWriter http.ResponseWriter, request *http.Request) guest.Guest {
 	var guest guest.Guest
-	err = json.NewDecoder(r.Body).Decode(&guest)
+	err := json.NewDecoder(request.Body).Decode(&guest)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusBadRequest)
-		return
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Invalid guest")
 	}
-	defer r.Body.Close()
+	defer request.Body.Close()
+	return guest
+}
 
-	err = h.store.Update(guest)
+func (guestHandler *guestHandler) updateGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
+	err := guestHandler.store.Update(guest)
 	if err != nil {
-		http.Error(w, "Cannot update guest", http.StatusBadRequest)
-		return
+		http.Error(responseWriter, "Failed to update guest", http.StatusBadRequest)
+		log.Panic("Failed to update guest")
 	}
 }
 
-func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	switch {
-	case r.Method == http.MethodOptions:
-		w.WriteHeader(http.StatusOK)
-	case r.Method == http.MethodPost && r.URL.Path == "/guest/login":
-		h.login(w, r)
-	case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path):
-		h.getGuests(w, r)
-	case r.Method == http.MethodPost && guestRe.MatchString(r.URL.Path):
-		h.createGuest(w, r)
-	case r.Method == http.MethodPut && guestIdRe.MatchString(r.URL.Path):
-		h.updateGuest(w, r)
-	default:
-		w.WriteHeader(http.StatusNotFound)
-	}
+func (guestHandler *guestHandler) getGuests(responseWriter http.ResponseWriter) {
+	guests := guestHandler.findGuests(responseWriter)
+	jsonBytes := guestHandler.marshalGuests(responseWriter, guests)
+	responseWriter.Write(jsonBytes)
 }
 
-func enableCors(next http.Handler) http.Handler {
-	allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
-	allowedMethods := []string{"OPTIONS", "POST", "PUT"}
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		method := r.Header.Get("Access-Control-Request-Method")
-		if isPreflight(r) && slices.Contains(allowedMethods, method) {
-			w.Header().Add("Access-Control-Allow-Methods", method)
-		}
+func (guestHandler *guestHandler) findGuests(responseWriter http.ResponseWriter) []guest.Guest {
+	guests, err := guestHandler.store.Get()
+	if err != nil {
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Failed to find guests")
+	}
+	return guests
+}
 
-		origin := r.Header.Get("Origin")
-		if slices.Contains(allowedOrigins, origin) {
-			w.Header().Add("Access-Control-Allow-Origin", origin)
-		}
+func (guestHandler *guestHandler) marshalGuests(responseWriter http.ResponseWriter, guests []guest.Guest) []byte {
+	jsonBytes, err := json.Marshal(guests)
+	if err != nil {
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		log.Panic("Failed to marshal guests")
+	}
+	return jsonBytes
+}
 
-		w.Header().Add("Access-Control-Allow-Headers", "*")
-		next.ServeHTTP(w, r)
-	})
+func (guestHandler *guestHandler) postGuest(responseWriter http.ResponseWriter, request *http.Request) {
+	guest := guestHandler.decodeGuest(responseWriter, request)
+	guests := guestHandler.findGuests(responseWriter)
+	guestHandler.checkExistingGuests(responseWriter, guests, guest)
+	guestHandler.addGuest(responseWriter, guest)
+	responseWriter.WriteHeader(http.StatusOK)
 }
 
-func isPreflight(r *http.Request) bool {
-	return r.Method == "OPTIONS" &&
-		r.Header.Get("Origin") != "" &&
-		r.Header.Get("Access-Control-Request-Method") != ""
+func (guestHandler *guestHandler) checkExistingGuests(responseWriter http.ResponseWriter, guests []guest.Guest, newGuest guest.Guest) {
+	for _, guest := range guests {
+		if guest.Id == newGuest.Id {
+			http.Error(responseWriter, "ID already exists", http.StatusBadRequest)
+			log.Panic("ID already exists")
+		}
+	}
 }
 
-func main() {
-	db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+func (guestHandler *guestHandler) addGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
+	err := guestHandler.store.Add(guest)
 	if err != nil {
-		log.Fatal(err)
+		http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+		return
 	}
-	defer db.Close()
-
-	store := guest.NewMemStore(db)
-	guestHandler := newGuestHandler(store)
-
-	mux := http.NewServeMux()
-	mux.Handle("/guest/", guestHandler)
-	log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
 }
-- 
cgit v1.2.3