From b6b590ce1786289642c85a82ac35eb81fb0ee3bb Mon Sep 17 00:00:00 2001 From: Michael Hunteman Date: Sat, 10 Aug 2024 15:22:58 -0700 Subject: Refactor --- server/cmd/main.go | 339 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 197 insertions(+), 142 deletions(-) (limited to 'server/cmd') diff --git a/server/cmd/main.go b/server/cmd/main.go index b747aba..53ef94e 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -22,7 +22,7 @@ type guestHandler struct { } type guestStore interface { - FindGuest(creds guest.Credentials) (guest.Guest, error) + FindGuest(credentials guest.Credentials) (guest.Guest, error) Get() ([]guest.Guest, error) Add(guest guest.Guest) error Update(guest guest.Guest) error @@ -39,211 +39,266 @@ var ( guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) ) -func newGuestHandler(s guestStore) *guestHandler { +func main() { + db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + store := guest.NewMemStore(db) + guestHandler := createGuestHandler(store) + + mux := http.NewServeMux() + mux.Handle("/guest/", guestHandler) + log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) +} + +func createGuestHandler(s guestStore) *guestHandler { return &guestHandler{ store: s, } } -func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) { - var creds guest.Credentials - err := json.NewDecoder(r.Body).Decode(&creds) +func enableCors(next http.Handler) http.Handler { + allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} + allowedMethods := []string{"OPTIONS", "POST", "PUT"} + return serveHttp(next, allowedOrigins, allowedMethods) +} + +func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + method := r.Header.Get("Access-Control-Request-Method") + if isPreflight(r) && slices.Contains(allowedMethods, method) { + w.Header().Add("Access-Control-Allow-Methods", method) + } + + origin := r.Header.Get("Origin") + if slices.Contains(allowedOrigins, origin) { + w.Header().Add("Access-Control-Allow-Origin", origin) + } + + w.Header().Add("Access-Control-Allow-Headers", "*") + next.ServeHTTP(w, r) + }) +} + +func isPreflight(r *http.Request) bool { + return r.Method == "OPTIONS" && + r.Header.Get("Origin") != "" && + r.Header.Get("Access-Control-Request-Method") != "" +} + +func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + switch { + case request.Method == http.MethodOptions: + responseWriter.WriteHeader(http.StatusOK) + case request.Method == http.MethodPost && request.URL.Path == "/guest/login": + guestHandler.login(responseWriter, request) + case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): + guestHandler.putGuest(responseWriter, request) + case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): + guestHandler.getGuests(responseWriter) + case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): + guestHandler.postGuest(responseWriter, request) + default: + responseWriter.WriteHeader(http.StatusNotFound) + } +} + +func (guestHandler *guestHandler) login(responseWriter http.ResponseWriter, request *http.Request) { + credentials := guestHandler.decodeCredentials(responseWriter, request) + guest := guestHandler.findGuest(responseWriter, credentials) + expirationTime := guestHandler.setExpirationTime() + claims := guestHandler.createClaims(credentials, expirationTime) + key := guestHandler.readKey(responseWriter) + token := guestHandler.createToken(responseWriter, claims, key) + jsonBytes := guestHandler.marshalResponse(responseWriter, guest, token) + responseWriter.Write(jsonBytes) +} + +func (guestHandler *guestHandler) decodeCredentials(responseWriter http.ResponseWriter, request *http.Request) guest.Credentials { + var credentials guest.Credentials + err := json.NewDecoder(request.Body).Decode(&credentials) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Failed to unmarshal credentials") } - defer r.Body.Close() + defer request.Body.Close() + return credentials +} - weddingGuest, err := h.store.FindGuest(creds) +func (guestHandler *guestHandler) findGuest(responseWriter http.ResponseWriter, credentials guest.Credentials) guest.Guest { + guest, err := guestHandler.store.FindGuest(credentials) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return + http.Error(responseWriter, err.Error(), http.StatusUnauthorized) + log.Panic("Guest not found") } + return guest +} - expirationTime := time.Now().Add(15 * time.Minute) - claims := &guest.Claims{ - Credentials: creds, +func (guestHandler *guestHandler) setExpirationTime() time.Time { + return time.Now().Add(15 * time.Minute) +} + +func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims { + return &guest.Claims{ + Credentials: credentials, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), }, } +} +func (guestHandler *guestHandler) readKey(responseWriter http.ResponseWriter) []byte { key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem") if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + http.Error(responseWriter, err.Error(), http.StatusInternalServerError) + log.Panic("Failed to read secret key") } + return key +} +func (guestHandler *guestHandler) createToken(responseWriter http.ResponseWriter, claims *guest.Claims, key []byte) string { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(key) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - loginResponse := &guest.LoginResponse{ - Guest: weddingGuest, - Token: tokenString, + http.Error(responseWriter, err.Error(), http.StatusInternalServerError) + log.Panic("Failed to read token") } + return tokenString +} +func (guestHandler *guestHandler) marshalResponse(responseWriter http.ResponseWriter, weddingGuest guest.Guest, token string) []byte { + loginResponse := guestHandler.createLoginResponse(weddingGuest, token) jsonBytes, err := json.Marshal(loginResponse) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Failed to marshal response") } - - w.WriteHeader(http.StatusOK) - w.Write(jsonBytes) + return jsonBytes } -func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) { - guests, err := h.store.Get() - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return +func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse { + return &guest.LoginResponse{ + Guest: weddingGuest, + Token: token, } +} - jsonBytes, err := json.Marshal(guests) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } +func (guestHandler *guestHandler) putGuest(responseWriter http.ResponseWriter, request *http.Request) { + token := guestHandler.getToken(request) + claims := guestHandler.initializeClaims() + key := guestHandler.readKey(responseWriter) + guestHandler.checkClaims(responseWriter, token, claims, key) + guestHandler.findId(responseWriter, request) + guest := guestHandler.decodeGuest(responseWriter, request) + guestHandler.updateGuest(responseWriter, guest) +} - w.WriteHeader(http.StatusOK) - w.Write(jsonBytes) +func (guestHandler *guestHandler) getToken(request *http.Request) string { + return request.Header.Get("Authorization") } -func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) { - var guest guest.Guest - err := json.NewDecoder(r.Body).Decode(&guest) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - defer r.Body.Close() +func (guestHandler *guestHandler) initializeClaims() *guest.Claims { + return &guest.Claims{} +} - guestSlice, err := h.store.Get() +func (guestHandler *guestHandler) checkClaims(responseWriter http.ResponseWriter, tokenString string, claims *guest.Claims, key []byte) *jwt.Token { + token, err := guestHandler.parseWithClaims(tokenString, claims, key) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - for _, g := range guestSlice { - if g.Id == guest.Id { - http.Error(w, "Id already exists", http.StatusBadRequest) - return + if err == jwt.ErrSignatureInvalid { + responseWriter.WriteHeader(http.StatusUnauthorized) + log.Panic("Invalid signature") } + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Failed to parse claims") } - - err = h.store.Add(guest) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + if !token.Valid { + responseWriter.WriteHeader(http.StatusUnauthorized) + log.Panic("Invalid token") } - - w.WriteHeader(http.StatusOK) + return token } -func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) { - tokenString := r.Header.Get("Authorization") - claims := &guest.Claims{} - - key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { +func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) { + return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) - if err != nil { - if err == jwt.ErrSignatureInvalid { - w.WriteHeader(http.StatusUnauthorized) - return - } - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !token.Valid { - w.WriteHeader(http.StatusUnauthorized) - return - } +} - matches := guestIdRe.FindStringSubmatch(r.URL.Path) +func (guestHandler *guestHandler) findId(responseWriter http.ResponseWriter, request *http.Request) { + matches := guestIdRe.FindStringSubmatch(request.URL.Path) if len(matches) < 2 { - http.Error(w, "No id found", http.StatusBadRequest) - return + http.Error(responseWriter, "ID not found", http.StatusBadRequest) + log.Panic("ID not found") } +} +func (guestHandler *guestHandler) decodeGuest(responseWriter http.ResponseWriter, request *http.Request) guest.Guest { var guest guest.Guest - err = json.NewDecoder(r.Body).Decode(&guest) + err := json.NewDecoder(request.Body).Decode(&guest) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Invalid guest") } - defer r.Body.Close() + defer request.Body.Close() + return guest +} - err = h.store.Update(guest) +func (guestHandler *guestHandler) updateGuest(responseWriter http.ResponseWriter, guest guest.Guest) { + err := guestHandler.store.Update(guest) if err != nil { - http.Error(w, "Cannot update guest", http.StatusBadRequest) - return + http.Error(responseWriter, "Failed to update guest", http.StatusBadRequest) + log.Panic("Failed to update guest") } } -func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodOptions: - w.WriteHeader(http.StatusOK) - case r.Method == http.MethodPost && r.URL.Path == "/guest/login": - h.login(w, r) - case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path): - h.getGuests(w, r) - case r.Method == http.MethodPost && guestRe.MatchString(r.URL.Path): - h.createGuest(w, r) - case r.Method == http.MethodPut && guestIdRe.MatchString(r.URL.Path): - h.updateGuest(w, r) - default: - w.WriteHeader(http.StatusNotFound) - } +func (guestHandler *guestHandler) getGuests(responseWriter http.ResponseWriter) { + guests := guestHandler.findGuests(responseWriter) + jsonBytes := guestHandler.marshalGuests(responseWriter, guests) + responseWriter.Write(jsonBytes) } -func enableCors(next http.Handler) http.Handler { - allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} - allowedMethods := []string{"OPTIONS", "POST", "PUT"} - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - method := r.Header.Get("Access-Control-Request-Method") - if isPreflight(r) && slices.Contains(allowedMethods, method) { - w.Header().Add("Access-Control-Allow-Methods", method) - } +func (guestHandler *guestHandler) findGuests(responseWriter http.ResponseWriter) []guest.Guest { + guests, err := guestHandler.store.Get() + if err != nil { + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Failed to find guests") + } + return guests +} - origin := r.Header.Get("Origin") - if slices.Contains(allowedOrigins, origin) { - w.Header().Add("Access-Control-Allow-Origin", origin) - } +func (guestHandler *guestHandler) marshalGuests(responseWriter http.ResponseWriter, guests []guest.Guest) []byte { + jsonBytes, err := json.Marshal(guests) + if err != nil { + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + log.Panic("Failed to marshal guests") + } + return jsonBytes +} - w.Header().Add("Access-Control-Allow-Headers", "*") - next.ServeHTTP(w, r) - }) +func (guestHandler *guestHandler) postGuest(responseWriter http.ResponseWriter, request *http.Request) { + guest := guestHandler.decodeGuest(responseWriter, request) + guests := guestHandler.findGuests(responseWriter) + guestHandler.checkExistingGuests(responseWriter, guests, guest) + guestHandler.addGuest(responseWriter, guest) + responseWriter.WriteHeader(http.StatusOK) } -func isPreflight(r *http.Request) bool { - return r.Method == "OPTIONS" && - r.Header.Get("Origin") != "" && - r.Header.Get("Access-Control-Request-Method") != "" +func (guestHandler *guestHandler) checkExistingGuests(responseWriter http.ResponseWriter, guests []guest.Guest, newGuest guest.Guest) { + for _, guest := range guests { + if guest.Id == newGuest.Id { + http.Error(responseWriter, "ID already exists", http.StatusBadRequest) + log.Panic("ID already exists") + } + } } -func main() { - db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) +func (guestHandler *guestHandler) addGuest(responseWriter http.ResponseWriter, guest guest.Guest) { + err := guestHandler.store.Add(guest) if err != nil { - log.Fatal(err) + http.Error(responseWriter, err.Error(), http.StatusBadRequest) + return } - defer db.Close() - - store := guest.NewMemStore(db) - guestHandler := newGuestHandler(store) - - mux := http.NewServeMux() - mux.Handle("/guest/", guestHandler) - log.Fatal(http.ListenAndServe(":8080", enableCors(mux))) } -- cgit v1.2.3