summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Hunteman <michael@huntm.net>2024-08-10 15:22:58 -0700
committerMichael Hunteman <michael@huntm.net>2024-08-10 15:22:58 -0700
commitb6b590ce1786289642c85a82ac35eb81fb0ee3bb (patch)
treedf0ee8910f539daf5c4c446e29dc18ca06bb6b15
parentaefab5f11ec64c3e4d8ffe506f2838553497cf65 (diff)
Refactor
-rw-r--r--server/cmd/main.go339
1 files changed, 197 insertions, 142 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go
index b747aba..53ef94e 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -22,7 +22,7 @@ type guestHandler struct {
}
type guestStore interface {
- FindGuest(creds guest.Credentials) (guest.Guest, error)
+ FindGuest(credentials guest.Credentials) (guest.Guest, error)
Get() ([]guest.Guest, error)
Add(guest guest.Guest) error
Update(guest guest.Guest) error
@@ -39,211 +39,266 @@ var (
guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
)
-func newGuestHandler(s guestStore) *guestHandler {
+func main() {
+ db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer db.Close()
+
+ store := guest.NewMemStore(db)
+ guestHandler := createGuestHandler(store)
+
+ mux := http.NewServeMux()
+ mux.Handle("/guest/", guestHandler)
+ log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
+}
+
+func createGuestHandler(s guestStore) *guestHandler {
return &guestHandler{
store: s,
}
}
-func (h *guestHandler) login(w http.ResponseWriter, r *http.Request) {
- var creds guest.Credentials
- err := json.NewDecoder(r.Body).Decode(&creds)
+func enableCors(next http.Handler) http.Handler {
+ allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
+ allowedMethods := []string{"OPTIONS", "POST", "PUT"}
+ return serveHttp(next, allowedOrigins, allowedMethods)
+}
+
+func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ method := r.Header.Get("Access-Control-Request-Method")
+ if isPreflight(r) && slices.Contains(allowedMethods, method) {
+ w.Header().Add("Access-Control-Allow-Methods", method)
+ }
+
+ origin := r.Header.Get("Origin")
+ if slices.Contains(allowedOrigins, origin) {
+ w.Header().Add("Access-Control-Allow-Origin", origin)
+ }
+
+ w.Header().Add("Access-Control-Allow-Headers", "*")
+ next.ServeHTTP(w, r)
+ })
+}
+
+func isPreflight(r *http.Request) bool {
+ return r.Method == "OPTIONS" &&
+ r.Header.Get("Origin") != "" &&
+ r.Header.Get("Access-Control-Request-Method") != ""
+}
+
+func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+ switch {
+ case request.Method == http.MethodOptions:
+ responseWriter.WriteHeader(http.StatusOK)
+ case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
+ guestHandler.login(responseWriter, request)
+ case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
+ guestHandler.putGuest(responseWriter, request)
+ case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
+ guestHandler.getGuests(responseWriter)
+ case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
+ guestHandler.postGuest(responseWriter, request)
+ default:
+ responseWriter.WriteHeader(http.StatusNotFound)
+ }
+}
+
+func (guestHandler *guestHandler) login(responseWriter http.ResponseWriter, request *http.Request) {
+ credentials := guestHandler.decodeCredentials(responseWriter, request)
+ guest := guestHandler.findGuest(responseWriter, credentials)
+ expirationTime := guestHandler.setExpirationTime()
+ claims := guestHandler.createClaims(credentials, expirationTime)
+ key := guestHandler.readKey(responseWriter)
+ token := guestHandler.createToken(responseWriter, claims, key)
+ jsonBytes := guestHandler.marshalResponse(responseWriter, guest, token)
+ responseWriter.Write(jsonBytes)
+}
+
+func (guestHandler *guestHandler) decodeCredentials(responseWriter http.ResponseWriter, request *http.Request) guest.Credentials {
+ var credentials guest.Credentials
+ err := json.NewDecoder(request.Body).Decode(&credentials)
if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Failed to unmarshal credentials")
}
- defer r.Body.Close()
+ defer request.Body.Close()
+ return credentials
+}
- weddingGuest, err := h.store.FindGuest(creds)
+func (guestHandler *guestHandler) findGuest(responseWriter http.ResponseWriter, credentials guest.Credentials) guest.Guest {
+ guest, err := guestHandler.store.FindGuest(credentials)
if err != nil {
- http.Error(w, err.Error(), http.StatusUnauthorized)
- return
+ http.Error(responseWriter, err.Error(), http.StatusUnauthorized)
+ log.Panic("Guest not found")
}
+ return guest
+}
- expirationTime := time.Now().Add(15 * time.Minute)
- claims := &guest.Claims{
- Credentials: creds,
+func (guestHandler *guestHandler) setExpirationTime() time.Time {
+ return time.Now().Add(15 * time.Minute)
+}
+
+func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims {
+ return &guest.Claims{
+ Credentials: credentials,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
},
}
+}
+func (guestHandler *guestHandler) readKey(responseWriter http.ResponseWriter) []byte {
key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
+ http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+ log.Panic("Failed to read secret key")
}
+ return key
+}
+func (guestHandler *guestHandler) createToken(responseWriter http.ResponseWriter, claims *guest.Claims, key []byte) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(key)
if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- loginResponse := &guest.LoginResponse{
- Guest: weddingGuest,
- Token: tokenString,
+ http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+ log.Panic("Failed to read token")
}
+ return tokenString
+}
+func (guestHandler *guestHandler) marshalResponse(responseWriter http.ResponseWriter, weddingGuest guest.Guest, token string) []byte {
+ loginResponse := guestHandler.createLoginResponse(weddingGuest, token)
jsonBytes, err := json.Marshal(loginResponse)
if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Failed to marshal response")
}
-
- w.WriteHeader(http.StatusOK)
- w.Write(jsonBytes)
+ return jsonBytes
}
-func (h *guestHandler) getGuests(w http.ResponseWriter, _ *http.Request) {
- guests, err := h.store.Get()
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse {
+ return &guest.LoginResponse{
+ Guest: weddingGuest,
+ Token: token,
}
+}
- jsonBytes, err := json.Marshal(guests)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
+func (guestHandler *guestHandler) putGuest(responseWriter http.ResponseWriter, request *http.Request) {
+ token := guestHandler.getToken(request)
+ claims := guestHandler.initializeClaims()
+ key := guestHandler.readKey(responseWriter)
+ guestHandler.checkClaims(responseWriter, token, claims, key)
+ guestHandler.findId(responseWriter, request)
+ guest := guestHandler.decodeGuest(responseWriter, request)
+ guestHandler.updateGuest(responseWriter, guest)
+}
- w.WriteHeader(http.StatusOK)
- w.Write(jsonBytes)
+func (guestHandler *guestHandler) getToken(request *http.Request) string {
+ return request.Header.Get("Authorization")
}
-func (h *guestHandler) createGuest(w http.ResponseWriter, r *http.Request) {
- var guest guest.Guest
- err := json.NewDecoder(r.Body).Decode(&guest)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- defer r.Body.Close()
+func (guestHandler *guestHandler) initializeClaims() *guest.Claims {
+ return &guest.Claims{}
+}
- guestSlice, err := h.store.Get()
+func (guestHandler *guestHandler) checkClaims(responseWriter http.ResponseWriter, tokenString string, claims *guest.Claims, key []byte) *jwt.Token {
+ token, err := guestHandler.parseWithClaims(tokenString, claims, key)
if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- for _, g := range guestSlice {
- if g.Id == guest.Id {
- http.Error(w, "Id already exists", http.StatusBadRequest)
- return
+ if err == jwt.ErrSignatureInvalid {
+ responseWriter.WriteHeader(http.StatusUnauthorized)
+ log.Panic("Invalid signature")
}
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Failed to parse claims")
}
-
- err = h.store.Add(guest)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+ if !token.Valid {
+ responseWriter.WriteHeader(http.StatusUnauthorized)
+ log.Panic("Invalid token")
}
-
- w.WriteHeader(http.StatusOK)
+ return token
}
-func (h *guestHandler) updateGuest(w http.ResponseWriter, r *http.Request) {
- tokenString := r.Header.Get("Authorization")
- claims := &guest.Claims{}
-
- key, err := os.ReadFile("C:\\Users\\mhunt\\skey.pem")
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
+func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) {
+ return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
return key, nil
})
- if err != nil {
- if err == jwt.ErrSignatureInvalid {
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- if !token.Valid {
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
+}
- matches := guestIdRe.FindStringSubmatch(r.URL.Path)
+func (guestHandler *guestHandler) findId(responseWriter http.ResponseWriter, request *http.Request) {
+ matches := guestIdRe.FindStringSubmatch(request.URL.Path)
if len(matches) < 2 {
- http.Error(w, "No id found", http.StatusBadRequest)
- return
+ http.Error(responseWriter, "ID not found", http.StatusBadRequest)
+ log.Panic("ID not found")
}
+}
+func (guestHandler *guestHandler) decodeGuest(responseWriter http.ResponseWriter, request *http.Request) guest.Guest {
var guest guest.Guest
- err = json.NewDecoder(r.Body).Decode(&guest)
+ err := json.NewDecoder(request.Body).Decode(&guest)
if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Invalid guest")
}
- defer r.Body.Close()
+ defer request.Body.Close()
+ return guest
+}
- err = h.store.Update(guest)
+func (guestHandler *guestHandler) updateGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
+ err := guestHandler.store.Update(guest)
if err != nil {
- http.Error(w, "Cannot update guest", http.StatusBadRequest)
- return
+ http.Error(responseWriter, "Failed to update guest", http.StatusBadRequest)
+ log.Panic("Failed to update guest")
}
}
-func (h *guestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- switch {
- case r.Method == http.MethodOptions:
- w.WriteHeader(http.StatusOK)
- case r.Method == http.MethodPost && r.URL.Path == "/guest/login":
- h.login(w, r)
- case r.Method == http.MethodGet && guestRe.MatchString(r.URL.Path):
- h.getGuests(w, r)
- case r.Method == http.MethodPost && guestRe.MatchString(r.URL.Path):
- h.createGuest(w, r)
- case r.Method == http.MethodPut && guestIdRe.MatchString(r.URL.Path):
- h.updateGuest(w, r)
- default:
- w.WriteHeader(http.StatusNotFound)
- }
+func (guestHandler *guestHandler) getGuests(responseWriter http.ResponseWriter) {
+ guests := guestHandler.findGuests(responseWriter)
+ jsonBytes := guestHandler.marshalGuests(responseWriter, guests)
+ responseWriter.Write(jsonBytes)
}
-func enableCors(next http.Handler) http.Handler {
- allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
- allowedMethods := []string{"OPTIONS", "POST", "PUT"}
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- method := r.Header.Get("Access-Control-Request-Method")
- if isPreflight(r) && slices.Contains(allowedMethods, method) {
- w.Header().Add("Access-Control-Allow-Methods", method)
- }
+func (guestHandler *guestHandler) findGuests(responseWriter http.ResponseWriter) []guest.Guest {
+ guests, err := guestHandler.store.Get()
+ if err != nil {
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Failed to find guests")
+ }
+ return guests
+}
- origin := r.Header.Get("Origin")
- if slices.Contains(allowedOrigins, origin) {
- w.Header().Add("Access-Control-Allow-Origin", origin)
- }
+func (guestHandler *guestHandler) marshalGuests(responseWriter http.ResponseWriter, guests []guest.Guest) []byte {
+ jsonBytes, err := json.Marshal(guests)
+ if err != nil {
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ log.Panic("Failed to marshal guests")
+ }
+ return jsonBytes
+}
- w.Header().Add("Access-Control-Allow-Headers", "*")
- next.ServeHTTP(w, r)
- })
+func (guestHandler *guestHandler) postGuest(responseWriter http.ResponseWriter, request *http.Request) {
+ guest := guestHandler.decodeGuest(responseWriter, request)
+ guests := guestHandler.findGuests(responseWriter)
+ guestHandler.checkExistingGuests(responseWriter, guests, guest)
+ guestHandler.addGuest(responseWriter, guest)
+ responseWriter.WriteHeader(http.StatusOK)
}
-func isPreflight(r *http.Request) bool {
- return r.Method == "OPTIONS" &&
- r.Header.Get("Origin") != "" &&
- r.Header.Get("Access-Control-Request-Method") != ""
+func (guestHandler *guestHandler) checkExistingGuests(responseWriter http.ResponseWriter, guests []guest.Guest, newGuest guest.Guest) {
+ for _, guest := range guests {
+ if guest.Id == newGuest.Id {
+ http.Error(responseWriter, "ID already exists", http.StatusBadRequest)
+ log.Panic("ID already exists")
+ }
+ }
}
-func main() {
- db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+func (guestHandler *guestHandler) addGuest(responseWriter http.ResponseWriter, guest guest.Guest) {
+ err := guestHandler.store.Add(guest)
if err != nil {
- log.Fatal(err)
+ http.Error(responseWriter, err.Error(), http.StatusBadRequest)
+ return
}
- defer db.Close()
-
- store := guest.NewMemStore(db)
- guestHandler := newGuestHandler(store)
-
- mux := http.NewServeMux()
- mux.Handle("/guest/", guestHandler)
- log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
}