diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/guest/handler.go | 156 | ||||
-rw-r--r-- | server/guest/store.go | 20 | ||||
-rw-r--r-- | server/test/guest_test.go | 60 |
3 files changed, 172 insertions, 64 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go index 418c223..a14a039 100644 --- a/server/guest/handler.go +++ b/server/guest/handler.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "regexp" + "strconv" "time" "github.com/golang-jwt/jwt/v5" @@ -17,7 +18,7 @@ var ( ) type GuestHandler struct { - guestStore GuestStore + store GuestStore } type GuestStore interface { @@ -25,6 +26,7 @@ type GuestStore interface { Get() ([]Guest, error) Add(guest Guest) error Update(guest Guest) error + Delete(id int) error } type appError struct { @@ -39,25 +41,27 @@ func NewGuestHandler(guestStore GuestStore) *GuestHandler { } } -func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { +func (handler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { switch { case request.Method == http.MethodOptions: responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guests/login": - guestHandler.handleLogIn(responseWriter, request) + handler.handleLogIn(responseWriter, request) case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path): - guestHandler.handlePut(responseWriter, request) + handler.handlePut(responseWriter, request) case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path): - guestHandler.handleGet(responseWriter, request) - case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path): - guestHandler.handlePost(responseWriter, request) + handler.handleGet(responseWriter, request) + case request.Method == http.MethodPost && guestRegex.MatchString(request.URL.Path): + handler.handlePost(responseWriter, request) + case request.Method == http.MethodDelete && guestIDRegex.MatchString(request.URL.Path): + handler.handleDelete(responseWriter, request) default: responseWriter.WriteHeader(http.StatusNotFound) } } -func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { - token, err := guestHandler.logIn(request) +func (handler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { + token, err := handler.logIn(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -65,16 +69,16 @@ func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter } } -func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.putGuest(request); err != nil { +func (handler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.putGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.WriteHeader(http.StatusOK) } } -func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { - guests, err := guestHandler.getGuests(request) +func (handler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { + guests, err := handler.getGuests(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -82,52 +86,60 @@ func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, } } -func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { - if err := guestHandler.postGuest(request); err != nil { +func (handler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.postGuest(request); err != nil { http.Error(responseWriter, err.Message, err.Code) } else { responseWriter.WriteHeader(http.StatusOK) } } -func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { - credentials, err := guestHandler.decodeCredentials(request) +func (handler *GuestHandler) handleDelete(responseWriter http.ResponseWriter, request *http.Request) { + if err := handler.deleteGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (handler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { + credentials, err := handler.decodeCredentials(request) if err != nil { return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest} } - guest, err := guestHandler.guestStore.Find(credentials) + guest, err := handler.store.Find(credentials) if err != nil { return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized} } - expirationTime := guestHandler.setExpirationTime() - claims := guestHandler.createClaims(credentials, expirationTime) - key, err := guestHandler.readGuestKey() + expirationTime := handler.setExpirationTime() + claims := handler.createClaims(credentials, expirationTime) + key, err := handler.readGuestKey() if err != nil { return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } - token, err := guestHandler.createToken(claims, key) + token, err := handler.createToken(claims, key) if err != nil { return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError} } - jsonBytes, err := guestHandler.marshalResponse(guest, token) + jsonBytes, err := handler.marshalResponse(guest, token) if err != nil { return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError} } return jsonBytes, nil } -func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) { +func (handler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) { var credentials Credentials err := json.NewDecoder(request.Body).Decode(&credentials) defer request.Body.Close() return credentials, err } -func (guestHandler *GuestHandler) setExpirationTime() time.Time { +func (handler *GuestHandler) setExpirationTime() time.Time { return time.Now().Add(15 * time.Minute) } -func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims { +func (handler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims { return &Claims{ Credentials: credentials, RegisteredClaims: jwt.RegisteredClaims{ @@ -136,57 +148,57 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati } } -func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) { +func (handler *GuestHandler) readGuestKey() ([]byte, error) { // TODO: use properties file return os.ReadFile("C:\\Users\\mhunt\\guest.pem") } -func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) { +func (handler *GuestHandler) readAdminKey() ([]byte, error) { return os.ReadFile("C:\\Users\\mhunt\\admin.pem") } -func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { +func (handler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(key) } -func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { - loginResponse := guestHandler.createLoginResponse(guest, token) +func (handler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { + loginResponse := handler.createLoginResponse(guest, token) return json.Marshal(loginResponse) } -func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse { +func (handler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse { return &LoginResponse{ Guest: weddingGuest, Token: token, } } -func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError { - guestKey, err := guestHandler.readGuestKey() +func (handler *GuestHandler) putGuest(request *http.Request) *appError { + guestKey, err := handler.readGuestKey() if err != nil { return &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, guestKey); err != nil { + if err := handler.validateToken(request, guestKey); err != nil { return err } - if guestHandler.findID(request) { + if handler.findID(request) { return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} } - guest, err := guestHandler.decodeGuest(request) + guest, err := handler.decodeGuest(request) if err != nil { return &appError{err, "invalid guest", http.StatusBadRequest} } - if err := guestHandler.guestStore.Update(guest); err != nil { + if err := handler.store.Update(guest); err != nil { return &appError{err, "failed to update guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { - authorizationHeader := guestHandler.getToken(request) - claims := guestHandler.newClaims() - token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) +func (handler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { + authorizationHeader := handler.getToken(request) + claims := handler.newClaims() + token, err := handler.parseWithClaims(authorizationHeader, claims, key) if err != nil { if err == jwt.ErrSignatureInvalid { return &appError{err, "invalid signature", http.StatusUnauthorized} @@ -199,41 +211,41 @@ func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byt return nil } -func (guestHandler *GuestHandler) getToken(request *http.Request) string { +func (handler *GuestHandler) getToken(request *http.Request) string { return request.Header.Get("Authorization") } -func (guestHandler *GuestHandler) newClaims() *Claims { +func (handler *GuestHandler) newClaims() *Claims { return &Claims{} } -func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { +func (handler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) } -func (guestHandler *GuestHandler) findID(request *http.Request) bool { +func (handler *GuestHandler) findID(request *http.Request) bool { matches := guestIDRegex.FindStringSubmatch(request.URL.Path) return len(matches) < 2 } -func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) { +func (handler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) { var guest Guest err := json.NewDecoder(request.Body).Decode(&guest) defer request.Body.Close() return guest, err } -func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { - adminKey, err := guestHandler.readAdminKey() +func (handler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { + adminKey, err := handler.readAdminKey() if err != nil { return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, adminKey); err != nil { + if err := handler.validateToken(request, adminKey); err != nil { return []byte{}, err } - guests, err := guestHandler.guestStore.Get() + guests, err := handler.store.Get() if err != nil { return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError} } @@ -244,32 +256,32 @@ func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *app return jsonBytes, nil } -func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError { - adminKey, err := guestHandler.readAdminKey() +func (handler *GuestHandler) postGuest(request *http.Request) *appError { + adminKey, err := handler.readAdminKey() if err != nil { return &appError{err, "failed to read secret key", http.StatusInternalServerError} } - if err := guestHandler.validateToken(request, adminKey); err != nil { + if err := handler.validateToken(request, adminKey); err != nil { return err } - guest, err := guestHandler.decodeGuest(request) + guest, err := handler.decodeGuest(request) if err != nil { return &appError{err, "invalid guest", http.StatusBadRequest} } - guests, err := guestHandler.guestStore.Get() + guests, err := handler.store.Get() if err != nil { return &appError{err, "failed to get guests", http.StatusInternalServerError} } - if err := guestHandler.checkExistingGuests(guests, guest); err != nil { + if err := handler.checkExistingGuests(guests, guest); err != nil { return &appError{err, "id already exists", http.StatusConflict} } - if err := guestHandler.guestStore.Add(guest); err != nil { + if err := handler.store.Add(guest); err != nil { return &appError{err, "failed to add guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { +func (handler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { for _, guest := range guests { if guest.ID == newGuest.ID { return errors.New("id already exists") @@ -277,3 +289,29 @@ func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest G } return nil } + +func (handler *GuestHandler) deleteGuest(request *http.Request) *appError { + adminKey, err := handler.readAdminKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := handler.validateToken(request, adminKey); err != nil { + return err + } + if handler.findID(request) { + return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} + } + guestID, err := getID(request) + if err != nil { + return &appError{err, "failed to parse id", http.StatusInternalServerError} + } + err = handler.store.Delete(int(guestID)) + if err != nil { + return &appError{err, "failed to get guests", http.StatusInternalServerError} + } + return nil +} + +func getID(request *http.Request) (int64, error) { + return strconv.ParseInt(guestIDRegex.FindStringSubmatch(request.URL.Path)[1], 10, 32) +} diff --git a/server/guest/store.go b/server/guest/store.go index a5b9374..4290f8f 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -157,7 +157,7 @@ func (store Store) Update(guest Guest) error { if err := store.updateGuest(guest); err != nil { return err } - if err := store.deleteOldParty(guest.ID); err != nil { + if err := store.deleteParty(guest.ID); err != nil { return err } return store.insertParty(guest) @@ -171,7 +171,7 @@ func (store Store) updateGuest(guest Guest) error { return err } -func (store Store) deleteOldParty(guestID int) error { +func (store Store) deleteParty(guestID int) error { statement := "delete from party where guest_id = $1" _, err := store.database.Exec(context.Background(), statement, guestID) return err @@ -189,3 +189,19 @@ func (store Store) insertParty(guest Guest) error { } return nil } + +func (store Store) Delete(guestID int) error { + if err := store.deleteGuest(guestID); err != nil { + return err + } + if err := store.deleteParty(guestID); err != nil { + return err + } + return nil +} + +func (store Store) deleteGuest(guestID int) error { + statement := "delete from guest where id = $1" + _, err := store.database.Exec(context.Background(), statement, guestID) + return err +} diff --git a/server/test/guest_test.go b/server/test/guest_test.go index f6cc243..26bd143 100644 --- a/server/test/guest_test.go +++ b/server/test/guest_test.go @@ -78,9 +78,8 @@ func logInGuest(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginRe func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) { response := httptest.NewRecorder() - guestWithParty := getGuestWithParty() putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guests/1", - strings.NewReader(guestWithParty)) + strings.NewReader(getGuestWithParty())) if err != nil { test.Error(err) } @@ -107,7 +106,10 @@ func logInAdmin(adminHandler *admin.AdminHandler, test *testing.T) admin.LoginRe func getGuests(guestHandler *guest.GuestHandler, token string, test *testing.T) []guest.Guest { response := httptest.NewRecorder() - getRequest, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/guests/", nil) + getRequest, err := http.NewRequest(http.MethodGet, "http://localhost:8080/guests/", nil) + if err != nil { + test.Error(err) + } getRequest.Header.Set("Authorization", token) guestHandler.ServeHTTP(response, getRequest) assertEquals(test, response.Result().StatusCode, 200) @@ -164,3 +166,55 @@ func getGuestWithoutParty() string { return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"", "email":"","message":"","partySize":1,"partyList":[]}` } + +func TestAddGuest(test *testing.T) { + databasePool, err := pgxpool.New(context.Background(), + fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database)) + if err != nil { + test.Error(err) + } + defer databasePool.Close() + guestStore := guest.NewStore(databasePool) + guestHandler := guest.NewGuestHandler(guestStore) + adminStore := admin.NewStore(databasePool) + adminHandler := admin.NewAdminHandler(adminStore, guestStore) + + adminLogin := logInAdmin(adminHandler, test) + adminToken := adminLogin.Token + guests := getGuests(guestHandler, adminToken, test) + assertEquals(test, findGuest(guests), findGuest([]guest.Guest{})) + postGuest(guestHandler, adminToken, test) + + guests = getGuests(guestHandler, adminToken, test) + guest := findGuest(guests) + assertEquals(test, guest.Attendance, "accept") + assertEquals(test, guest.Email, "mhunteman@cox.net") + assertEquals(test, guest.Message, "We'll be there!") + assertEquals(test, guest.PartySize, 2) + assertEquals(test, guest.PartyList[0].FirstName, "Madison") + assertEquals(test, guest.PartyList[0].LastName, "Rossitto") + deleteGuest(guestHandler, adminToken, test) +} + +func deleteGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) { + response := httptest.NewRecorder() + deleteRequest, err := http.NewRequest(http.MethodDelete, "http://localhost:8080/guests/1", nil) + if err != nil { + test.Error(err) + } + deleteRequest.Header.Set("Authorization", token) + guestHandler.ServeHTTP(response, deleteRequest) + assertEquals(test, response.Result().StatusCode, 200) +} + +func postGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) { + response := httptest.NewRecorder() + putRequest, err := http.NewRequest(http.MethodPost, "http://localhost:8080/guests/", + strings.NewReader(getGuestWithParty())) + if err != nil { + test.Error(err) + } + putRequest.Header.Set("Authorization", token) + guestHandler.ServeHTTP(response, putRequest) + assertEquals(test, response.Result().StatusCode, 200) +} |