summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/guest/handler.go156
-rw-r--r--server/guest/store.go20
-rw-r--r--server/test/guest_test.go60
3 files changed, 172 insertions, 64 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go
index 418c223..a14a039 100644
--- a/server/guest/handler.go
+++ b/server/guest/handler.go
@@ -6,6 +6,7 @@ import (
"net/http"
"os"
"regexp"
+ "strconv"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -17,7 +18,7 @@ var (
)
type GuestHandler struct {
- guestStore GuestStore
+ store GuestStore
}
type GuestStore interface {
@@ -25,6 +26,7 @@ type GuestStore interface {
Get() ([]Guest, error)
Add(guest Guest) error
Update(guest Guest) error
+ Delete(id int) error
}
type appError struct {
@@ -39,25 +41,27 @@ func NewGuestHandler(guestStore GuestStore) *GuestHandler {
}
}
-func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+func (handler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
switch {
case request.Method == http.MethodOptions:
responseWriter.WriteHeader(http.StatusOK)
case request.Method == http.MethodPost && request.URL.Path == "/guests/login":
- guestHandler.handleLogIn(responseWriter, request)
+ handler.handleLogIn(responseWriter, request)
case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path):
- guestHandler.handlePut(responseWriter, request)
+ handler.handlePut(responseWriter, request)
case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path):
- guestHandler.handleGet(responseWriter, request)
- case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path):
- guestHandler.handlePost(responseWriter, request)
+ handler.handleGet(responseWriter, request)
+ case request.Method == http.MethodPost && guestRegex.MatchString(request.URL.Path):
+ handler.handlePost(responseWriter, request)
+ case request.Method == http.MethodDelete && guestIDRegex.MatchString(request.URL.Path):
+ handler.handleDelete(responseWriter, request)
default:
responseWriter.WriteHeader(http.StatusNotFound)
}
}
-func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
- token, err := guestHandler.logIn(request)
+func (handler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
+ token, err := handler.logIn(request)
if err != nil {
http.Error(responseWriter, err.Message, err.Code)
} else {
@@ -65,16 +69,16 @@ func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter
}
}
-func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.putGuest(request); err != nil {
+func (handler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := handler.putGuest(request); err != nil {
http.Error(responseWriter, err.Message, err.Code)
} else {
responseWriter.WriteHeader(http.StatusOK)
}
}
-func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
- guests, err := guestHandler.getGuests(request)
+func (handler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
+ guests, err := handler.getGuests(request)
if err != nil {
http.Error(responseWriter, err.Message, err.Code)
} else {
@@ -82,52 +86,60 @@ func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter,
}
}
-func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
- if err := guestHandler.postGuest(request); err != nil {
+func (handler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := handler.postGuest(request); err != nil {
http.Error(responseWriter, err.Message, err.Code)
} else {
responseWriter.WriteHeader(http.StatusOK)
}
}
-func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) {
- credentials, err := guestHandler.decodeCredentials(request)
+func (handler *GuestHandler) handleDelete(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := handler.deleteGuest(request); err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.WriteHeader(http.StatusOK)
+ }
+}
+
+func (handler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) {
+ credentials, err := handler.decodeCredentials(request)
if err != nil {
return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest}
}
- guest, err := guestHandler.guestStore.Find(credentials)
+ guest, err := handler.store.Find(credentials)
if err != nil {
return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized}
}
- expirationTime := guestHandler.setExpirationTime()
- claims := guestHandler.createClaims(credentials, expirationTime)
- key, err := guestHandler.readGuestKey()
+ expirationTime := handler.setExpirationTime()
+ claims := handler.createClaims(credentials, expirationTime)
+ key, err := handler.readGuestKey()
if err != nil {
return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError}
}
- token, err := guestHandler.createToken(claims, key)
+ token, err := handler.createToken(claims, key)
if err != nil {
return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError}
}
- jsonBytes, err := guestHandler.marshalResponse(guest, token)
+ jsonBytes, err := handler.marshalResponse(guest, token)
if err != nil {
return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError}
}
return jsonBytes, nil
}
-func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) {
+func (handler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) {
var credentials Credentials
err := json.NewDecoder(request.Body).Decode(&credentials)
defer request.Body.Close()
return credentials, err
}
-func (guestHandler *GuestHandler) setExpirationTime() time.Time {
+func (handler *GuestHandler) setExpirationTime() time.Time {
return time.Now().Add(15 * time.Minute)
}
-func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims {
+func (handler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims {
return &Claims{
Credentials: credentials,
RegisteredClaims: jwt.RegisteredClaims{
@@ -136,57 +148,57 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati
}
}
-func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) {
+func (handler *GuestHandler) readGuestKey() ([]byte, error) {
// TODO: use properties file
return os.ReadFile("C:\\Users\\mhunt\\guest.pem")
}
-func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) {
+func (handler *GuestHandler) readAdminKey() ([]byte, error) {
return os.ReadFile("C:\\Users\\mhunt\\admin.pem")
}
-func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
+func (handler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(key)
}
-func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) {
- loginResponse := guestHandler.createLoginResponse(guest, token)
+func (handler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) {
+ loginResponse := handler.createLoginResponse(guest, token)
return json.Marshal(loginResponse)
}
-func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse {
+func (handler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse {
return &LoginResponse{
Guest: weddingGuest,
Token: token,
}
}
-func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError {
- guestKey, err := guestHandler.readGuestKey()
+func (handler *GuestHandler) putGuest(request *http.Request) *appError {
+ guestKey, err := handler.readGuestKey()
if err != nil {
return &appError{err, "failed to read secret key", http.StatusInternalServerError}
}
- if err := guestHandler.validateToken(request, guestKey); err != nil {
+ if err := handler.validateToken(request, guestKey); err != nil {
return err
}
- if guestHandler.findID(request) {
+ if handler.findID(request) {
return &appError{errors.New("id not found"), "id not found", http.StatusNotFound}
}
- guest, err := guestHandler.decodeGuest(request)
+ guest, err := handler.decodeGuest(request)
if err != nil {
return &appError{err, "invalid guest", http.StatusBadRequest}
}
- if err := guestHandler.guestStore.Update(guest); err != nil {
+ if err := handler.store.Update(guest); err != nil {
return &appError{err, "failed to update guest", http.StatusInternalServerError}
}
return nil
}
-func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError {
- authorizationHeader := guestHandler.getToken(request)
- claims := guestHandler.newClaims()
- token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
+func (handler *GuestHandler) validateToken(request *http.Request, key []byte) *appError {
+ authorizationHeader := handler.getToken(request)
+ claims := handler.newClaims()
+ token, err := handler.parseWithClaims(authorizationHeader, claims, key)
if err != nil {
if err == jwt.ErrSignatureInvalid {
return &appError{err, "invalid signature", http.StatusUnauthorized}
@@ -199,41 +211,41 @@ func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byt
return nil
}
-func (guestHandler *GuestHandler) getToken(request *http.Request) string {
+func (handler *GuestHandler) getToken(request *http.Request) string {
return request.Header.Get("Authorization")
}
-func (guestHandler *GuestHandler) newClaims() *Claims {
+func (handler *GuestHandler) newClaims() *Claims {
return &Claims{}
}
-func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) {
+func (handler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) {
return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
return key, nil
})
}
-func (guestHandler *GuestHandler) findID(request *http.Request) bool {
+func (handler *GuestHandler) findID(request *http.Request) bool {
matches := guestIDRegex.FindStringSubmatch(request.URL.Path)
return len(matches) < 2
}
-func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) {
+func (handler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) {
var guest Guest
err := json.NewDecoder(request.Body).Decode(&guest)
defer request.Body.Close()
return guest, err
}
-func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) {
- adminKey, err := guestHandler.readAdminKey()
+func (handler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) {
+ adminKey, err := handler.readAdminKey()
if err != nil {
return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError}
}
- if err := guestHandler.validateToken(request, adminKey); err != nil {
+ if err := handler.validateToken(request, adminKey); err != nil {
return []byte{}, err
}
- guests, err := guestHandler.guestStore.Get()
+ guests, err := handler.store.Get()
if err != nil {
return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError}
}
@@ -244,32 +256,32 @@ func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *app
return jsonBytes, nil
}
-func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError {
- adminKey, err := guestHandler.readAdminKey()
+func (handler *GuestHandler) postGuest(request *http.Request) *appError {
+ adminKey, err := handler.readAdminKey()
if err != nil {
return &appError{err, "failed to read secret key", http.StatusInternalServerError}
}
- if err := guestHandler.validateToken(request, adminKey); err != nil {
+ if err := handler.validateToken(request, adminKey); err != nil {
return err
}
- guest, err := guestHandler.decodeGuest(request)
+ guest, err := handler.decodeGuest(request)
if err != nil {
return &appError{err, "invalid guest", http.StatusBadRequest}
}
- guests, err := guestHandler.guestStore.Get()
+ guests, err := handler.store.Get()
if err != nil {
return &appError{err, "failed to get guests", http.StatusInternalServerError}
}
- if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
+ if err := handler.checkExistingGuests(guests, guest); err != nil {
return &appError{err, "id already exists", http.StatusConflict}
}
- if err := guestHandler.guestStore.Add(guest); err != nil {
+ if err := handler.store.Add(guest); err != nil {
return &appError{err, "failed to add guest", http.StatusInternalServerError}
}
return nil
}
-func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error {
+func (handler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error {
for _, guest := range guests {
if guest.ID == newGuest.ID {
return errors.New("id already exists")
@@ -277,3 +289,29 @@ func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest G
}
return nil
}
+
+func (handler *GuestHandler) deleteGuest(request *http.Request) *appError {
+ adminKey, err := handler.readAdminKey()
+ if err != nil {
+ return &appError{err, "failed to read secret key", http.StatusInternalServerError}
+ }
+ if err := handler.validateToken(request, adminKey); err != nil {
+ return err
+ }
+ if handler.findID(request) {
+ return &appError{errors.New("id not found"), "id not found", http.StatusNotFound}
+ }
+ guestID, err := getID(request)
+ if err != nil {
+ return &appError{err, "failed to parse id", http.StatusInternalServerError}
+ }
+ err = handler.store.Delete(int(guestID))
+ if err != nil {
+ return &appError{err, "failed to get guests", http.StatusInternalServerError}
+ }
+ return nil
+}
+
+func getID(request *http.Request) (int64, error) {
+ return strconv.ParseInt(guestIDRegex.FindStringSubmatch(request.URL.Path)[1], 10, 32)
+}
diff --git a/server/guest/store.go b/server/guest/store.go
index a5b9374..4290f8f 100644
--- a/server/guest/store.go
+++ b/server/guest/store.go
@@ -157,7 +157,7 @@ func (store Store) Update(guest Guest) error {
if err := store.updateGuest(guest); err != nil {
return err
}
- if err := store.deleteOldParty(guest.ID); err != nil {
+ if err := store.deleteParty(guest.ID); err != nil {
return err
}
return store.insertParty(guest)
@@ -171,7 +171,7 @@ func (store Store) updateGuest(guest Guest) error {
return err
}
-func (store Store) deleteOldParty(guestID int) error {
+func (store Store) deleteParty(guestID int) error {
statement := "delete from party where guest_id = $1"
_, err := store.database.Exec(context.Background(), statement, guestID)
return err
@@ -189,3 +189,19 @@ func (store Store) insertParty(guest Guest) error {
}
return nil
}
+
+func (store Store) Delete(guestID int) error {
+ if err := store.deleteGuest(guestID); err != nil {
+ return err
+ }
+ if err := store.deleteParty(guestID); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (store Store) deleteGuest(guestID int) error {
+ statement := "delete from guest where id = $1"
+ _, err := store.database.Exec(context.Background(), statement, guestID)
+ return err
+}
diff --git a/server/test/guest_test.go b/server/test/guest_test.go
index f6cc243..26bd143 100644
--- a/server/test/guest_test.go
+++ b/server/test/guest_test.go
@@ -78,9 +78,8 @@ func logInGuest(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginRe
func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
response := httptest.NewRecorder()
- guestWithParty := getGuestWithParty()
putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guests/1",
- strings.NewReader(guestWithParty))
+ strings.NewReader(getGuestWithParty()))
if err != nil {
test.Error(err)
}
@@ -107,7 +106,10 @@ func logInAdmin(adminHandler *admin.AdminHandler, test *testing.T) admin.LoginRe
func getGuests(guestHandler *guest.GuestHandler, token string, test *testing.T) []guest.Guest {
response := httptest.NewRecorder()
- getRequest, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/guests/", nil)
+ getRequest, err := http.NewRequest(http.MethodGet, "http://localhost:8080/guests/", nil)
+ if err != nil {
+ test.Error(err)
+ }
getRequest.Header.Set("Authorization", token)
guestHandler.ServeHTTP(response, getRequest)
assertEquals(test, response.Result().StatusCode, 200)
@@ -164,3 +166,55 @@ func getGuestWithoutParty() string {
return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"",
"email":"","message":"","partySize":1,"partyList":[]}`
}
+
+func TestAddGuest(test *testing.T) {
+ databasePool, err := pgxpool.New(context.Background(),
+ fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database))
+ if err != nil {
+ test.Error(err)
+ }
+ defer databasePool.Close()
+ guestStore := guest.NewStore(databasePool)
+ guestHandler := guest.NewGuestHandler(guestStore)
+ adminStore := admin.NewStore(databasePool)
+ adminHandler := admin.NewAdminHandler(adminStore, guestStore)
+
+ adminLogin := logInAdmin(adminHandler, test)
+ adminToken := adminLogin.Token
+ guests := getGuests(guestHandler, adminToken, test)
+ assertEquals(test, findGuest(guests), findGuest([]guest.Guest{}))
+ postGuest(guestHandler, adminToken, test)
+
+ guests = getGuests(guestHandler, adminToken, test)
+ guest := findGuest(guests)
+ assertEquals(test, guest.Attendance, "accept")
+ assertEquals(test, guest.Email, "mhunteman@cox.net")
+ assertEquals(test, guest.Message, "We'll be there!")
+ assertEquals(test, guest.PartySize, 2)
+ assertEquals(test, guest.PartyList[0].FirstName, "Madison")
+ assertEquals(test, guest.PartyList[0].LastName, "Rossitto")
+ deleteGuest(guestHandler, adminToken, test)
+}
+
+func deleteGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
+ response := httptest.NewRecorder()
+ deleteRequest, err := http.NewRequest(http.MethodDelete, "http://localhost:8080/guests/1", nil)
+ if err != nil {
+ test.Error(err)
+ }
+ deleteRequest.Header.Set("Authorization", token)
+ guestHandler.ServeHTTP(response, deleteRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+}
+
+func postGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
+ response := httptest.NewRecorder()
+ putRequest, err := http.NewRequest(http.MethodPost, "http://localhost:8080/guests/",
+ strings.NewReader(getGuestWithParty()))
+ if err != nil {
+ test.Error(err)
+ }
+ putRequest.Header.Set("Authorization", token)
+ guestHandler.ServeHTTP(response, putRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+}