diff options
Diffstat (limited to 'server/guest')
-rw-r--r-- | server/guest/handler.go | 115 | ||||
-rw-r--r-- | server/guest/models.go | 2 | ||||
-rw-r--r-- | server/guest/store.go | 204 |
3 files changed, 191 insertions, 130 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go index 46b8a45..153a633 100644 --- a/server/guest/handler.go +++ b/server/guest/handler.go @@ -12,15 +12,15 @@ import ( ) var ( - guestRe = regexp.MustCompile(`^/guest/*$`) - guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) + guestRegex = regexp.MustCompile(`^/guest/*$`) + guestIDRegex = regexp.MustCompile(`^/guest/([0-9]+)$`) ) type GuestHandler struct { - store guestStore + guestStore GuestStore } -type guestStore interface { +type GuestStore interface { Find(credentials Credentials) (Guest, error) Get() ([]Guest, error) Add(guest Guest) error @@ -33,9 +33,9 @@ type appError struct { Code int } -func NewGuestHandler(s guestStore) *GuestHandler { +func NewGuestHandler(guestStore GuestStore) *GuestHandler { return &GuestHandler{ - store: s, + guestStore, } } @@ -45,11 +45,11 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guest/login": guestHandler.handleLogIn(responseWriter, request) - case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): + case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path): guestHandler.handlePut(responseWriter, request) - case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): + case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path): guestHandler.handleGet(responseWriter, request) - case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): + case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path): guestHandler.handlePost(responseWriter, request) default: responseWriter.WriteHeader(http.StatusNotFound) @@ -57,7 +57,7 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, } func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { - token, err := guestHandler.logInGuest(request) + token, err := guestHandler.logIn(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -90,28 +90,28 @@ func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, } } -func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) { +func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { credentials, err := guestHandler.decodeCredentials(request) if err != nil { - return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} + return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest} } - guest, err := guestHandler.store.Find(credentials) + guest, err := guestHandler.guestStore.Find(credentials) if err != nil { - return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} + return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized} } expirationTime := guestHandler.setExpirationTime() claims := guestHandler.createClaims(credentials, expirationTime) - key, err := guestHandler.readKey() + key, err := guestHandler.readGuestKey() if err != nil { - return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } token, err := guestHandler.createToken(claims, key) if err != nil { - return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError} } jsonBytes, err := guestHandler.marshalResponse(guest, token) if err != nil { - return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError} } return jsonBytes, nil } @@ -136,9 +136,13 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati } } -func (guestHandler *GuestHandler) readKey() ([]byte, error) { +func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) { // TODO: use properties file - return os.ReadFile("C:\\Users\\mhunt\\skey.pem") + return os.ReadFile("C:\\Users\\mhunt\\guest.pem") +} + +func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) { + return os.ReadFile("C:\\Users\\mhunt\\admin.pem") } func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { @@ -159,38 +163,38 @@ func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token } func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { + guestKey, err := guestHandler.readGuestKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, guestKey); err != nil { return err } - if guestHandler.findId(request) { - return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound} + if guestHandler.findID(request) { + return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} } guest, err := guestHandler.decodeGuest(request) if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} + return &appError{err, "invalid guest", http.StatusBadRequest} } - if err := guestHandler.store.Update(guest); err != nil { - return &appError{err, "Failed to update guest", http.StatusInternalServerError} + if err := guestHandler.guestStore.Update(guest); err != nil { + return &appError{err, "failed to update guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError { +func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { authorizationHeader := guestHandler.getToken(request) - claims := guestHandler.initializeClaims() - key, err := guestHandler.readKey() - if err != nil { - return &appError{err, "Failed to read secret key", http.StatusInternalServerError} - } + claims := guestHandler.newClaims() token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) if err != nil { if err == jwt.ErrSignatureInvalid { - return &appError{err, "Invalid signature", http.StatusUnauthorized} + return &appError{err, "invalid signature", http.StatusUnauthorized} } - return &appError{err, "Failed to parse claims", http.StatusBadRequest} + return &appError{err, "failed to parse claims", http.StatusBadRequest} } if !token.Valid { - return &appError{err, "Invalid token", http.StatusUnauthorized} + return &appError{err, "invalid token", http.StatusUnauthorized} } return nil } @@ -199,7 +203,7 @@ func (guestHandler *GuestHandler) getToken(request *http.Request) string { return request.Header.Get("Authorization") } -func (guestHandler *GuestHandler) initializeClaims() *Claims { +func (guestHandler *GuestHandler) newClaims() *Claims { return &Claims{} } @@ -209,8 +213,8 @@ func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, }) } -func (guestHandler *GuestHandler) findId(request *http.Request) bool { - matches := guestIdRe.FindStringSubmatch(request.URL.Path) +func (guestHandler *GuestHandler) findID(request *http.Request) bool { + matches := guestIDRegex.FindStringSubmatch(request.URL.Path) return len(matches) < 2 } @@ -222,46 +226,53 @@ func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, err } func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { - // TODO: check with admin token - if err := guestHandler.validateToken(request); err != nil { + adminKey, err := guestHandler.readAdminKey() + if err != nil { + return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, adminKey); err != nil { return []byte{}, err } - guests, err := guestHandler.store.Get() + guests, err := guestHandler.guestStore.Get() if err != nil { - return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError} } jsonBytes, err := json.Marshal(guests) if err != nil { - return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to marshal guests", http.StatusInternalServerError} } return jsonBytes, nil } func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { + adminKey, err := guestHandler.readAdminKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, adminKey); err != nil { return err } guest, err := guestHandler.decodeGuest(request) if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} + return &appError{err, "invalid guest", http.StatusBadRequest} } - guests, err := guestHandler.store.Get() + guests, err := guestHandler.guestStore.Get() if err != nil { - return &appError{err, "Failed to get guests", http.StatusInternalServerError} + return &appError{err, "failed to get guests", http.StatusInternalServerError} } if err := guestHandler.checkExistingGuests(guests, guest); err != nil { - return &appError{err, "ID already exists", http.StatusConflict} + return &appError{err, "id already exists", http.StatusConflict} } - if err := guestHandler.store.Add(guest); err != nil { - return &appError{err, "Failed to add guest", http.StatusInternalServerError} + if err := guestHandler.guestStore.Add(guest); err != nil { + return &appError{err, "failed to add guest", http.StatusInternalServerError} } return nil } func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { for _, guest := range guests { - if guest.Id == newGuest.Id { - return errors.New("ID already exists") + if guest.ID == newGuest.ID { + return errors.New("id already exists") } } return nil diff --git a/server/guest/models.go b/server/guest/models.go index aa2c56f..5cd1b8b 100644 --- a/server/guest/models.go +++ b/server/guest/models.go @@ -3,7 +3,7 @@ package guest import "github.com/golang-jwt/jwt/v5" type Guest struct { - Id int `json:"id"` + ID int `json:"id"` FirstName string `json:"firstName"` LastName string `json:"lastName"` Attendance string `json:"attendance"` diff --git a/server/guest/store.go b/server/guest/store.go index 1a07161..a5b9374 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -4,135 +4,185 @@ import ( "context" "errors" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) -type MemStore struct { - db *pgxpool.Pool +type Store struct { + database *pgxpool.Pool } -func NewMemStore(db *pgxpool.Pool) *MemStore { - return &MemStore{ - db, +func NewStore(database *pgxpool.Pool) *Store { + return &Store{ + database, } } -func (m MemStore) Find(creds Credentials) (Guest, error) { - rows, err := m.db.Query(context.Background(), "select * from guest") - var guest Guest +func (store Store) Find(credentials Credentials) (Guest, error) { + guestRows, err := store.database.Query(context.Background(), + "select * from guest") if err != nil { - return guest, err + return Guest{}, err } - defer rows.Close() + defer guestRows.Close() + guest, found := createGuest(credentials, guestRows) - found := false - for rows.Next() { - err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) - if err != nil { - return guest, err - } - if guest.FirstName == creds.FirstName && guest.LastName == creds.LastName { - found = true - break - } + partyRows, err := store.database.Query(context.Background(), + "select * from party") + if err != nil { + return Guest{}, err } + defer partyRows.Close() - rows, err = m.db.Query(context.Background(), "select * from party") + guest, err = addParty(guest, partyRows) if err != nil { - return guest, err + return Guest{}, err } - defer rows.Close() - for rows.Next() { - var guestId int - var partyGuest PartyGuest - err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName) + if found { + return guest, nil + } + return Guest{}, errors.New("guest not found") +} + +func createGuest(credentials Credentials, guestRows pgx.Rows) (Guest, bool) { + var guest Guest + for guestRows.Next() { + err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName, + &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) if err != nil { - return guest, err + return Guest{}, false } - if guestId == guest.Id { - guest.PartyList = append(guest.PartyList, partyGuest) + if guest.FirstName == credentials.FirstName && + guest.LastName == credentials.LastName { + return guest, true } } + return Guest{}, false +} - if found { - return guest, nil +func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) { + guestWithParty := guestWithoutParty + for partyRows.Next() { + var guestID int + var partyGuest PartyGuest + err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName) + if err != nil { + return Guest{}, err + } + if guestID == guestWithParty.ID { + guestWithParty.PartyList = append(guestWithParty.PartyList, partyGuest) + } } - return guest, errors.New("Guest not found") + return guestWithParty, nil } -func (m MemStore) Get() ([]Guest, error) { - rows, err := m.db.Query(context.Background(), "select * from guest") +func (store Store) Get() ([]Guest, error) { + guestRows, err := store.database.Query(context.Background(), + "select * from guest") if err != nil { return nil, err } - defer rows.Close() + defer guestRows.Close() - guestSlice := []Guest{} - for rows.Next() { - var guest Guest - err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) - if err != nil { - return guestSlice, err - } - guestSlice = append(guestSlice, guest) + guestsWithoutParty, err := store.createGuestSlice(guestRows) + if err != nil { + return []Guest{}, err } - rows, err = m.db.Query(context.Background(), "select * from party") + partyRows, err := store.database.Query(context.Background(), + "select * from party") if err != nil { - return guestSlice, err + return []Guest{}, err } - defer rows.Close() + defer partyRows.Close() - for rows.Next() { - var guestId int + guestsWithParty, err := addPartySlice(guestsWithoutParty, partyRows) + if err != nil { + return []Guest{}, err + } + return guestsWithParty, nil +} + +func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) { + guests := []Guest{} + for guestRows.Next() { + var guest Guest + err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName, + &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) + if err != nil { + return []Guest{}, err + } + guests = append(guests, guest) + } + return guests, nil +} + +func addPartySlice(guestsWithoutParty []Guest, + partyRows pgx.Rows) ([]Guest, error) { + guestsWithParty := guestsWithoutParty + for partyRows.Next() { + var guestID int var partyGuest PartyGuest - err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName) + err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName) if err != nil { - return guestSlice, err + return []Guest{}, err } - for i, g := range guestSlice { - if guestId == g.Id { - guestSlice[i].PartyList = append(g.PartyList, partyGuest) + for i, guest := range guestsWithParty { + if guestID == guest.ID { + guestsWithParty[i].PartyList = append(guest.PartyList, partyGuest) } } } - return guestSlice, nil + return guestsWithParty, nil } -func (m MemStore) Add(guest Guest) error { - statement := "insert into guest (id, first_name, last_name, attendance, email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)" - _, err := m.db.Exec(context.Background(), statement, guest.Id, guest.FirstName, guest.LastName, guest.Attendance, guest.Email, guest.Message, guest.PartySize) - if err != nil { +func (store Store) Add(guest Guest) error { + if err := store.insertGuest(guest); err != nil { return err } + return store.insertParty(guest) +} - statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)" - for _, pg := range guest.PartyList { - _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName) - if err != nil { - return err - } - } - return nil +func (store Store) insertGuest(guest Guest) error { + statement := `insert into guest (id, first_name, last_name, attendance, + email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)` + _, err := store.database.Exec(context.Background(), statement, guest.ID, + guest.FirstName, guest.LastName, guest.Attendance, guest.Email, + guest.Message, guest.PartySize) + return err } -func (m MemStore) Update(guest Guest) error { - statement := "update guest set attendance = $1, email = $2, message = $3, party_size = $4 where id = $5" - _, err := m.db.Exec(context.Background(), statement, guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.Id) - if err != nil { +func (store Store) Update(guest Guest) error { + if err := store.updateGuest(guest); err != nil { return err } - - statement = "delete from party where guest_id = $1" - _, err = m.db.Exec(context.Background(), statement, guest.Id) - if err != nil { + if err := store.deleteOldParty(guest.ID); err != nil { return err } + return store.insertParty(guest) +} + +func (store Store) updateGuest(guest Guest) error { + statement := `update guest set attendance = $1, email = $2, message = $3, + party_size = $4 where id = $5` + _, err := store.database.Exec(context.Background(), statement, + guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.ID) + return err +} + +func (store Store) deleteOldParty(guestID int) error { + statement := "delete from party where guest_id = $1" + _, err := store.database.Exec(context.Background(), statement, guestID) + return err +} - statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)" +func (store Store) insertParty(guest Guest) error { + statement := `insert into party (guest_id, first_name, last_name) + values ($1, $2, $3)` for _, pg := range guest.PartyList { - _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName) + _, err := store.database.Exec(context.Background(), statement, guest.ID, + pg.FirstName, pg.LastName) if err != nil { return err } |