summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorMichael Hunteman <michael@huntm.net>2024-08-18 17:48:37 -0700
committerMichael Hunteman <michael@huntm.net>2024-08-18 17:52:35 -0700
commit42aa6dc74a736f8c96bd1b1bdf37ad3ff905f08c (patch)
tree0c5232e33fdef411bcc9dc81396c6a397e04c652 /server
parentf9b3c68e31a90e8091da69b9defcb6641697e9c8 (diff)
Add integration test
Diffstat (limited to 'server')
-rw-r--r--server/cmd/main.go293
-rw-r--r--server/guest/handler.go268
-rw-r--r--server/guest/store.go2
-rw-r--r--server/test/guest_test.go129
4 files changed, 415 insertions, 277 deletions
diff --git a/server/cmd/main.go b/server/cmd/main.go
index 66f31b2..439ed0b 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -2,48 +2,23 @@ package main
import (
"context"
- "encoding/json"
- "errors"
"fmt"
"log"
"net/http"
"os"
- "regexp"
"slices"
- "time"
- "github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgxpool"
"git.huntm.net/wedding/server/guest"
)
-type guestHandler struct {
- store guestStore
-}
-
-type guestStore interface {
- FindGuest(credentials guest.Credentials) (guest.Guest, error)
- Get() ([]guest.Guest, error)
- Add(guest guest.Guest) error
- Update(guest guest.Guest) error
-}
-
-type appError struct {
- Error error
- Message string
- Code int
-}
-
var (
user = os.Getenv("USER")
pass = os.Getenv("PASS")
host = "localhost"
port = "5432"
database = "postgres"
-
- guestRe = regexp.MustCompile(`^/guest/*$`)
- guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
)
func main() {
@@ -54,272 +29,38 @@ func main() {
defer db.Close()
store := guest.NewMemStore(db)
- guestHandler := createGuestHandler(store)
+ guestHandler := guest.NewGuestHandler(store)
mux := http.NewServeMux()
mux.Handle("/guest/", guestHandler)
- log.Fatal(http.ListenAndServe(":8080", enableCors(mux)))
+ log.Fatal(http.ListenAndServe(":8080", enableCORS(mux)))
}
-func createGuestHandler(s guestStore) *guestHandler {
- return &guestHandler{
- store: s,
- }
-}
-
-func enableCors(next http.Handler) http.Handler {
+func enableCORS(handler http.Handler) http.Handler {
allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
allowedMethods := []string{"OPTIONS", "POST", "PUT"}
- return serveHttp(next, allowedOrigins, allowedMethods)
+ return serveHTTP(handler, allowedOrigins, allowedMethods)
}
-func serveHttp(next http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- method := r.Header.Get("Access-Control-Request-Method")
- if isPreflight(r) && slices.Contains(allowedMethods, method) {
- w.Header().Add("Access-Control-Allow-Methods", method)
+func serveHTTP(handler http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
+ return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
+ method := request.Header.Get("Access-Control-Request-Method")
+ if isPreflight(request) && slices.Contains(allowedMethods, method) {
+ responseWriter.Header().Add("Access-Control-Allow-Methods", method)
}
- origin := r.Header.Get("Origin")
+ origin := request.Header.Get("Origin")
if slices.Contains(allowedOrigins, origin) {
- w.Header().Add("Access-Control-Allow-Origin", origin)
+ responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
}
- w.Header().Add("Access-Control-Allow-Headers", "*")
- next.ServeHTTP(w, r)
+ responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
+ handler.ServeHTTP(responseWriter, request)
})
}
-func isPreflight(r *http.Request) bool {
- return r.Method == "OPTIONS" &&
- r.Header.Get("Origin") != "" &&
- r.Header.Get("Access-Control-Request-Method") != ""
-}
-
-func (guestHandler *guestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
- switch {
- case request.Method == http.MethodOptions:
- responseWriter.WriteHeader(http.StatusOK)
- case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
- guestHandler.handleLogIn(responseWriter, request)
- case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
- guestHandler.handlePut(responseWriter, request)
- case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
- guestHandler.handleGet(responseWriter, request)
- case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
- guestHandler.handlePost(responseWriter, request)
- default:
- responseWriter.WriteHeader(http.StatusNotFound)
- }
-}
-
-func (guestHandler *guestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
- token, err := guestHandler.logInGuest(request)
- if err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.Write(token)
- }
-}
-
-func (guestHandler *guestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.putGuest(request); err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.WriteHeader(http.StatusOK)
- }
-}
-
-func (guestHandler *guestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
- guests, err := guestHandler.getGuests(request)
- if err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.Write(guests)
- }
-}
-
-func (guestHandler *guestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.postGuest(request); err != nil {
- http.Error(responseWriter, err.Message, err.Code)
- } else {
- responseWriter.WriteHeader(http.StatusOK)
- }
-}
-
-func (guestHandler *guestHandler) logInGuest(request *http.Request) ([]byte, *appError) {
- credentials, err := guestHandler.decodeCredentials(request)
- if err != nil {
- return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
- }
- guest, err := guestHandler.store.FindGuest(credentials)
- if err != nil {
- return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
- }
- expirationTime := guestHandler.setExpirationTime()
- claims := guestHandler.createClaims(credentials, expirationTime)
- key, err := guestHandler.readKey()
- if err != nil {
- return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
- }
- token, err := guestHandler.createToken(claims, key)
- if err != nil {
- return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
- }
- jsonBytes, err := guestHandler.marshalResponse(guest, token)
- if err != nil {
- return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
- }
- return jsonBytes, nil
-}
-
-func (guestHandler *guestHandler) decodeCredentials(request *http.Request) (guest.Credentials, error) {
- var credentials guest.Credentials
- err := json.NewDecoder(request.Body).Decode(&credentials)
- defer request.Body.Close()
- return credentials, err
-}
-
-func (guestHandler *guestHandler) setExpirationTime() time.Time {
- return time.Now().Add(15 * time.Minute)
-}
-
-func (guestHandler *guestHandler) createClaims(credentials guest.Credentials, expirationTime time.Time) *guest.Claims {
- return &guest.Claims{
- Credentials: credentials,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(expirationTime),
- },
- }
-}
-
-func (guestHandler *guestHandler) readKey() ([]byte, error) {
- // TODO: use properties file
- return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
-}
-
-func (guestHandler *guestHandler) createToken(claims *guest.Claims, key []byte) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(key)
-}
-
-func (guestHandler *guestHandler) marshalResponse(guest guest.Guest, token string) ([]byte, error) {
- loginResponse := guestHandler.createLoginResponse(guest, token)
- return json.Marshal(loginResponse)
-}
-
-func (guestHandler *guestHandler) createLoginResponse(weddingGuest guest.Guest, token string) *guest.LoginResponse {
- return &guest.LoginResponse{
- Guest: weddingGuest,
- Token: token,
- }
-}
-
-func (guestHandler *guestHandler) putGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
- return err
- }
- if guestHandler.findId(request) {
- return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound}
- }
- guest, err := guestHandler.decodeGuest(request)
- if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
- }
- if err := guestHandler.store.Update(guest); err != nil {
- return &appError{err, "Failed to update guest", http.StatusInternalServerError}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) validateToken(request *http.Request) *appError {
- authorizationHeader := guestHandler.getToken(request)
- claims := guestHandler.initializeClaims()
- key, err := guestHandler.readKey()
- if err != nil {
- return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
- }
- token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
- if err != nil {
- if err == jwt.ErrSignatureInvalid {
- return &appError{err, "Invalid signature", http.StatusUnauthorized}
- }
- return &appError{err, "Failed to parse claims", http.StatusBadRequest}
- }
- if !token.Valid {
- return &appError{err, "Invalid token", http.StatusUnauthorized}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) getToken(request *http.Request) string {
- return request.Header.Get("Authorization")
-}
-
-func (guestHandler *guestHandler) initializeClaims() *guest.Claims {
- return &guest.Claims{}
-}
-
-func (guestHandler *guestHandler) parseWithClaims(token string, claims *guest.Claims, key []byte) (*jwt.Token, error) {
- return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
- return key, nil
- })
-}
-
-func (guestHandler *guestHandler) findId(request *http.Request) bool {
- matches := guestIdRe.FindStringSubmatch(request.URL.Path)
- return len(matches) < 2
-}
-
-func (guestHandler *guestHandler) decodeGuest(request *http.Request) (guest.Guest, error) {
- var guest guest.Guest
- err := json.NewDecoder(request.Body).Decode(&guest)
- defer request.Body.Close()
- return guest, err
-}
-
-func (guestHandler *guestHandler) getGuests(request *http.Request) ([]byte, *appError) {
- // TODO: check with admin token
- if err := guestHandler.validateToken(request); err != nil {
- return []byte{}, err
- }
- guests, err := guestHandler.store.Get()
- if err != nil {
- return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
- }
- jsonBytes, err := json.Marshal(guests)
- if err != nil {
- return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
- }
- return jsonBytes, nil
-}
-
-func (guestHandler *guestHandler) postGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
- return err
- }
- guest, err := guestHandler.decodeGuest(request)
- if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
- }
- guests, err := guestHandler.store.Get()
- if err != nil {
- return &appError{err, "Failed to get guests", http.StatusInternalServerError}
- }
- if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
- return &appError{err, "ID already exists", http.StatusConflict}
- }
- if err := guestHandler.store.Add(guest); err != nil {
- return &appError{err, "Failed to add guest", http.StatusInternalServerError}
- }
- return nil
-}
-
-func (guestHandler *guestHandler) checkExistingGuests(guests []guest.Guest, newGuest guest.Guest) error {
- for _, guest := range guests {
- if guest.Id == newGuest.Id {
- return errors.New("ID already exists")
- }
- }
- return nil
+func isPreflight(request *http.Request) bool {
+ return request.Method == "OPTIONS" &&
+ request.Header.Get("Origin") != "" &&
+ request.Header.Get("Access-Control-Request-Method") != ""
}
diff --git a/server/guest/handler.go b/server/guest/handler.go
new file mode 100644
index 0000000..46b8a45
--- /dev/null
+++ b/server/guest/handler.go
@@ -0,0 +1,268 @@
+package guest
+
+import (
+ "encoding/json"
+ "errors"
+ "net/http"
+ "os"
+ "regexp"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+var (
+ guestRe = regexp.MustCompile(`^/guest/*$`)
+ guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
+)
+
+type GuestHandler struct {
+ store guestStore
+}
+
+type guestStore interface {
+ Find(credentials Credentials) (Guest, error)
+ Get() ([]Guest, error)
+ Add(guest Guest) error
+ Update(guest Guest) error
+}
+
+type appError struct {
+ Error error
+ Message string
+ Code int
+}
+
+func NewGuestHandler(s guestStore) *GuestHandler {
+ return &GuestHandler{
+ store: s,
+ }
+}
+
+func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+ switch {
+ case request.Method == http.MethodOptions:
+ responseWriter.WriteHeader(http.StatusOK)
+ case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
+ guestHandler.handleLogIn(responseWriter, request)
+ case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
+ guestHandler.handlePut(responseWriter, request)
+ case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
+ guestHandler.handleGet(responseWriter, request)
+ case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
+ guestHandler.handlePost(responseWriter, request)
+ default:
+ responseWriter.WriteHeader(http.StatusNotFound)
+ }
+}
+
+func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
+ token, err := guestHandler.logInGuest(request)
+ if err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.Write(token)
+ }
+}
+
+func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := guestHandler.putGuest(request); err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.WriteHeader(http.StatusOK)
+ }
+}
+
+func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
+ guests, err := guestHandler.getGuests(request)
+ if err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.Write(guests)
+ }
+}
+
+func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := guestHandler.postGuest(request); err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.WriteHeader(http.StatusOK)
+ }
+}
+
+func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) {
+ credentials, err := guestHandler.decodeCredentials(request)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
+ }
+ guest, err := guestHandler.store.Find(credentials)
+ if err != nil {
+ return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
+ }
+ expirationTime := guestHandler.setExpirationTime()
+ claims := guestHandler.createClaims(credentials, expirationTime)
+ key, err := guestHandler.readKey()
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+ }
+ token, err := guestHandler.createToken(claims, key)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
+ }
+ jsonBytes, err := guestHandler.marshalResponse(guest, token)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
+ }
+ return jsonBytes, nil
+}
+
+func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) {
+ var credentials Credentials
+ err := json.NewDecoder(request.Body).Decode(&credentials)
+ defer request.Body.Close()
+ return credentials, err
+}
+
+func (guestHandler *GuestHandler) setExpirationTime() time.Time {
+ return time.Now().Add(15 * time.Minute)
+}
+
+func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims {
+ return &Claims{
+ Credentials: credentials,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expirationTime),
+ },
+ }
+}
+
+func (guestHandler *GuestHandler) readKey() ([]byte, error) {
+ // TODO: use properties file
+ return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
+}
+
+func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString(key)
+}
+
+func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) {
+ loginResponse := guestHandler.createLoginResponse(guest, token)
+ return json.Marshal(loginResponse)
+}
+
+func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse {
+ return &LoginResponse{
+ Guest: weddingGuest,
+ Token: token,
+ }
+}
+
+func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError {
+ if err := guestHandler.validateToken(request); err != nil {
+ return err
+ }
+ if guestHandler.findId(request) {
+ return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound}
+ }
+ guest, err := guestHandler.decodeGuest(request)
+ if err != nil {
+ return &appError{err, "Invalid guest", http.StatusBadRequest}
+ }
+ if err := guestHandler.store.Update(guest); err != nil {
+ return &appError{err, "Failed to update guest", http.StatusInternalServerError}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError {
+ authorizationHeader := guestHandler.getToken(request)
+ claims := guestHandler.initializeClaims()
+ key, err := guestHandler.readKey()
+ if err != nil {
+ return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+ }
+ token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
+ if err != nil {
+ if err == jwt.ErrSignatureInvalid {
+ return &appError{err, "Invalid signature", http.StatusUnauthorized}
+ }
+ return &appError{err, "Failed to parse claims", http.StatusBadRequest}
+ }
+ if !token.Valid {
+ return &appError{err, "Invalid token", http.StatusUnauthorized}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) getToken(request *http.Request) string {
+ return request.Header.Get("Authorization")
+}
+
+func (guestHandler *GuestHandler) initializeClaims() *Claims {
+ return &Claims{}
+}
+
+func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) {
+ return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
+ return key, nil
+ })
+}
+
+func (guestHandler *GuestHandler) findId(request *http.Request) bool {
+ matches := guestIdRe.FindStringSubmatch(request.URL.Path)
+ return len(matches) < 2
+}
+
+func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) {
+ var guest Guest
+ err := json.NewDecoder(request.Body).Decode(&guest)
+ defer request.Body.Close()
+ return guest, err
+}
+
+func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) {
+ // TODO: check with admin token
+ if err := guestHandler.validateToken(request); err != nil {
+ return []byte{}, err
+ }
+ guests, err := guestHandler.store.Get()
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ }
+ jsonBytes, err := json.Marshal(guests)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
+ }
+ return jsonBytes, nil
+}
+
+func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError {
+ if err := guestHandler.validateToken(request); err != nil {
+ return err
+ }
+ guest, err := guestHandler.decodeGuest(request)
+ if err != nil {
+ return &appError{err, "Invalid guest", http.StatusBadRequest}
+ }
+ guests, err := guestHandler.store.Get()
+ if err != nil {
+ return &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ }
+ if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
+ return &appError{err, "ID already exists", http.StatusConflict}
+ }
+ if err := guestHandler.store.Add(guest); err != nil {
+ return &appError{err, "Failed to add guest", http.StatusInternalServerError}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error {
+ for _, guest := range guests {
+ if guest.Id == newGuest.Id {
+ return errors.New("ID already exists")
+ }
+ }
+ return nil
+}
diff --git a/server/guest/store.go b/server/guest/store.go
index db9fc3d..1a07161 100644
--- a/server/guest/store.go
+++ b/server/guest/store.go
@@ -17,7 +17,7 @@ func NewMemStore(db *pgxpool.Pool) *MemStore {
}
}
-func (m MemStore) FindGuest(creds Credentials) (Guest, error) {
+func (m MemStore) Find(creds Credentials) (Guest, error) {
rows, err := m.db.Query(context.Background(), "select * from guest")
var guest Guest
if err != nil {
diff --git a/server/test/guest_test.go b/server/test/guest_test.go
new file mode 100644
index 0000000..c675ddb
--- /dev/null
+++ b/server/test/guest_test.go
@@ -0,0 +1,129 @@
+package test
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+
+ "git.huntm.net/wedding/server/guest"
+ "github.com/jackc/pgx/v5/pgxpool"
+)
+
+var (
+ user = os.Getenv("USER")
+ pass = os.Getenv("PASS")
+ host = "localhost"
+ port = "5432"
+ database = "postgres"
+)
+
+func TestUpdateRSVP(test *testing.T) {
+ db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+ if err != nil {
+ test.Error(err)
+ }
+ defer db.Close()
+ store := guest.NewMemStore(db)
+ guestHandler := guest.NewGuestHandler(store)
+
+ login := logIn(guestHandler, test)
+ token := login.Token
+
+ addPartyGuest(guestHandler, token, test)
+ guestSlice := getGuests(guestHandler, token, test)
+ guest := guestSlice[0]
+ assertEquals(test, guest.Attendance, "accept")
+ assertEquals(test, guest.Email, "mhunteman@yahoo.com")
+ assertEquals(test, guest.Message, "We'll be there!")
+ assertEquals(test, guest.PartySize, 2)
+ assertEquals(test, guest.PartyList[0].FirstName, "Madison")
+ assertEquals(test, guest.PartyList[0].LastName, "Rossitto")
+
+ deletePartyGuest(guestHandler, token, test)
+ guestSlice = getGuests(guestHandler, token, test)
+ guest = guestSlice[0]
+ assertEquals(test, guest.Attendance, "decline")
+ assertEquals(test, guest.Email, "")
+ assertEquals(test, guest.Message, "")
+ assertEquals(test, guest.PartySize, 1)
+}
+
+func logIn(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginResponse {
+ response := httptest.NewRecorder()
+ loginRequest, err := http.NewRequest(http.MethodPost, "http://localhost:8080/guest/login",
+ strings.NewReader(getCredentials()))
+ if err != nil {
+ test.Error(err)
+ }
+ guestHandler.ServeHTTP(response, loginRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+ var login guest.LoginResponse
+ if err = json.NewDecoder(response.Body).Decode(&login); err != nil {
+ log.Fatal(err)
+ }
+ return login
+}
+
+func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
+ response := httptest.NewRecorder()
+ guestWithParty := getGuestWithParty()
+ putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(guestWithParty))
+ if err != nil {
+ test.Error(err)
+ }
+ putRequest.Header.Set("Authorization", token)
+ guestHandler.ServeHTTP(response, putRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+}
+
+func getGuests(guestHandler *guest.GuestHandler, token string, test *testing.T) []guest.Guest {
+ response := httptest.NewRecorder()
+ getRequest, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/guest/", nil)
+ getRequest.Header.Set("Authorization", token)
+ guestHandler.ServeHTTP(response, getRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+ var guestSlice []guest.Guest
+ if err := json.NewDecoder(response.Body).Decode(&guestSlice); err != nil {
+ test.Error(err)
+ }
+ return guestSlice
+}
+
+func deletePartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
+ response := httptest.NewRecorder()
+ putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(getGuestWithoutParty()))
+ if err != nil {
+ test.Error(err)
+ }
+ putRequest.Header.Set("Authorization", token)
+ guestHandler.ServeHTTP(response, putRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+}
+
+func assertEquals(test testing.TB, actual any, expected any) {
+ test.Helper()
+ if actual != expected {
+ test.Errorf("Actual: %s, expected: %s", actual, expected)
+ }
+}
+
+func getCredentials() string {
+ return "{ \"firstName\": \"Michael\", \"lastName\": \"Hunteman\"}"
+}
+
+func getGuestWithParty() string {
+ return `{"id":1,"firstName":"Michael","lastName":"Hunteman", "attendance":"accept",
+ "email":"mhunteman@yahoo.com","message":"We'll be there!","partySize":2,
+ "partyList":[{"firstName":"Madison","lastName":"Rossitto"}]}`
+}
+
+func getGuestWithoutParty() string {
+ return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"decline",
+ "email":"","message":"","partySize":1,"partyList":[]}`
+}