diff options
author | Michael Hunteman <michael@huntm.net> | 2024-08-25 12:44:32 -0700 |
---|---|---|
committer | Michael Hunteman <michael@huntm.net> | 2024-08-25 12:44:32 -0700 |
commit | 096a08708e2310becba56a237ef63b5cf6e3c4c4 (patch) | |
tree | 2924f9aecdcf035599558552cfdb20c2cc18f7d1 /server | |
parent | 6aee47e76d7e25206b3778aeebcc341d7b705035 (diff) |
Add admin dashboard
Diffstat (limited to 'server')
-rw-r--r-- | server/admin/handler.go | 119 | ||||
-rw-r--r-- | server/admin/models.go | 21 | ||||
-rw-r--r-- | server/admin/store.go | 49 | ||||
-rw-r--r-- | server/cmd/main.go | 58 | ||||
-rw-r--r-- | server/go.mod | 2 | ||||
-rw-r--r-- | server/go.sum | 4 | ||||
-rw-r--r-- | server/guest/handler.go | 115 | ||||
-rw-r--r-- | server/guest/models.go | 2 | ||||
-rw-r--r-- | server/guest/store.go | 204 | ||||
-rw-r--r-- | server/schema.sql | 8 | ||||
-rw-r--r-- | server/test/guest_test.go | 89 |
11 files changed, 488 insertions, 183 deletions
diff --git a/server/admin/handler.go b/server/admin/handler.go new file mode 100644 index 0000000..5fd1fee --- /dev/null +++ b/server/admin/handler.go @@ -0,0 +1,119 @@ +package admin + +import ( + "encoding/json" + "net/http" + "os" + "time" + + "git.huntm.net/wedding/server/guest" + "github.com/golang-jwt/jwt/v5" +) + +type AdminHandler struct { + adminStore adminStore + guestStore guest.GuestStore +} + +type adminStore interface { + Find(admin Admin) (Admin, error) +} + +type appError struct { + Error error + Message string + Code int +} + +func NewAdminHandler(adminStore adminStore, guestStore guest.GuestStore) *AdminHandler { + return &AdminHandler{adminStore, guestStore} +} + +func (adminHandler *AdminHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + switch { + case request.Method == http.MethodOptions: + responseWriter.WriteHeader(http.StatusOK) + case request.Method == http.MethodPost && request.URL.Path == "/admin/login": + adminHandler.handleLogIn(responseWriter, request) + default: + responseWriter.WriteHeader(http.StatusNotFound) + } +} + +func (adminHandler *AdminHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { + token, err := adminHandler.logIn(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(token) + } +} + +func (adminHandler *AdminHandler) logIn(request *http.Request) ([]byte, *appError) { + requestAdmin, err := adminHandler.decodeCredentials(request) + if err != nil { + return []byte{}, &appError{err, "failed to unmarshal request", http.StatusBadRequest} + } + _, err = adminHandler.adminStore.Find(requestAdmin) + if err != nil { + return []byte{}, &appError{err, "admin not found", http.StatusUnauthorized} + } + expirationTime := adminHandler.setExpirationTime() + claims := adminHandler.createClaims(requestAdmin, expirationTime) + key, err := adminHandler.readKey() + if err != nil { + return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + token, err := adminHandler.createToken(claims, key) + if err != nil { + return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError} + } + guests, err := adminHandler.guestStore.Get() + if err != nil { + return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError} + } + jsonBytes, err := adminHandler.marshalResponse(guests, token) + if err != nil { + return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError} + } + return jsonBytes, nil +} + +func (adminHandler *AdminHandler) decodeCredentials(request *http.Request) (Admin, error) { + var admin Admin + err := json.NewDecoder(request.Body).Decode(&admin) + defer request.Body.Close() + return admin, err +} + +func (adminHandler *AdminHandler) setExpirationTime() time.Time { + return time.Now().Add(15 * time.Minute) +} + +func (adminHandler *AdminHandler) createClaims(admin Admin, expirationTime time.Time) *Claims { + return &Claims{ + admin, + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } +} + +func (adminHandler *AdminHandler) readKey() ([]byte, error) { + // TODO: use properties file + return os.ReadFile("C:\\Users\\mhunt\\admin.pem") +} + +func (adminHandler *AdminHandler) createToken(claims *Claims, key []byte) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(key) +} + +func (adminHandler *AdminHandler) marshalResponse(guests []guest.Guest, token string) ([]byte, error) { + loginResponse := adminHandler.createLoginResponse(guests, token) + return json.Marshal(loginResponse) +} + +func (adminHandler *AdminHandler) createLoginResponse(guests []guest.Guest, token string) *LoginResponse { + return &LoginResponse{guests, token} +} diff --git a/server/admin/models.go b/server/admin/models.go new file mode 100644 index 0000000..d9b8232 --- /dev/null +++ b/server/admin/models.go @@ -0,0 +1,21 @@ +package admin + +import ( + "git.huntm.net/wedding/server/guest" + "github.com/golang-jwt/jwt/v5" +) + +type Admin struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type Claims struct { + Admin Admin `json:"admin"` + jwt.RegisteredClaims +} + +type LoginResponse struct { + Guests []guest.Guest `json:"guests"` + Token string `json:"token"` +} diff --git a/server/admin/store.go b/server/admin/store.go new file mode 100644 index 0000000..437e6af --- /dev/null +++ b/server/admin/store.go @@ -0,0 +1,49 @@ +package admin + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +type Store struct { + database *pgxpool.Pool +} + +func NewStore(database *pgxpool.Pool) *Store { + return &Store{ + database, + } +} + +func (store Store) Find(requestAdmin Admin) (Admin, error) { + adminRows, err := store.database.Query(context.Background(), + "select * from admin") + if err != nil { + return Admin{}, err + } + defer adminRows.Close() + admin, found := createAdmin(requestAdmin, adminRows) + + if found { + return admin, nil + } + return Admin{}, errors.New("admin not found") +} + +func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) { + var databaseAdmin Admin + for adminRows.Next() { + err := adminRows.Scan(&databaseAdmin.Username, &databaseAdmin.Password) + if err != nil { + return Admin{}, false + } + if databaseAdmin.Username == requestAdmin.Username && + databaseAdmin.Password == requestAdmin.Password { + return databaseAdmin, true + } + } + return Admin{}, false +} diff --git a/server/cmd/main.go b/server/cmd/main.go index 439ed0b..38ccf37 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -10,53 +10,65 @@ import ( "github.com/jackc/pgx/v5/pgxpool" + "git.huntm.net/wedding/server/admin" "git.huntm.net/wedding/server/guest" ) var ( user = os.Getenv("USER") - pass = os.Getenv("PASS") + password = os.Getenv("PASS") host = "localhost" port = "5432" database = "postgres" ) func main() { - db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) + databasePool, err := pgxpool.New(context.Background(), + fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database)) if err != nil { log.Fatal(err) } - defer db.Close() + defer databasePool.Close() - store := guest.NewMemStore(db) - guestHandler := guest.NewGuestHandler(store) + guestStore := guest.NewStore(databasePool) + guestHandler := guest.NewGuestHandler(guestStore) + adminStore := admin.NewStore(databasePool) + adminHandler := admin.NewAdminHandler(adminStore, guestStore) mux := http.NewServeMux() mux.Handle("/guest/", guestHandler) - log.Fatal(http.ListenAndServe(":8080", enableCORS(mux))) + mux.Handle("/admin/", adminHandler) + log.Fatal(http.ListenAndServe(":8080", serveHTTP(mux))) } -func enableCORS(handler http.Handler) http.Handler { - allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"} - allowedMethods := []string{"OPTIONS", "POST", "PUT"} - return serveHTTP(handler, allowedOrigins, allowedMethods) +func serveHTTP(handler http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, + request *http.Request) { + writeMethods(responseWriter, request) + writeOrigins(responseWriter, request) + writeHeaders(responseWriter) + handler.ServeHTTP(responseWriter, request) + }) } -func serveHTTP(handler http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler { - return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { - method := request.Header.Get("Access-Control-Request-Method") - if isPreflight(request) && slices.Contains(allowedMethods, method) { - responseWriter.Header().Add("Access-Control-Allow-Methods", method) - } +func writeMethods(responseWriter http.ResponseWriter, request *http.Request) { + allowedMethods := []string{"OPTIONS", "POST", "PUT"} + method := request.Header.Get("Access-Control-Request-Method") + if isPreflight(request) && slices.Contains(allowedMethods, method) { + responseWriter.Header().Add("Access-Control-Allow-Methods", method) + } +} - origin := request.Header.Get("Origin") - if slices.Contains(allowedOrigins, origin) { - responseWriter.Header().Add("Access-Control-Allow-Origin", origin) - } +func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) { + allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.18:5173"} + origin := request.Header.Get("Origin") + if slices.Contains(allowedOrigins, origin) { + responseWriter.Header().Add("Access-Control-Allow-Origin", origin) + } +} - responseWriter.Header().Add("Access-Control-Allow-Headers", "*") - handler.ServeHTTP(responseWriter, request) - }) +func writeHeaders(responseWriter http.ResponseWriter) { + responseWriter.Header().Add("Access-Control-Allow-Headers", "*") } func isPreflight(request *http.Request) bool { diff --git a/server/go.mod b/server/go.mod index 49af49e..8ae7015 100644 --- a/server/go.mod +++ b/server/go.mod @@ -4,7 +4,7 @@ go 1.22.2 require ( github.com/golang-jwt/jwt/v5 v5.2.1 - github.com/jackc/pgx/v5 v5.5.5 + github.com/jackc/pgx/v5 v5.6.0 ) require ( diff --git a/server/go.sum b/server/go.sum index b93b74b..590fe16 100644 --- a/server/go.sum +++ b/server/go.sum @@ -7,8 +7,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/server/guest/handler.go b/server/guest/handler.go index 46b8a45..153a633 100644 --- a/server/guest/handler.go +++ b/server/guest/handler.go @@ -12,15 +12,15 @@ import ( ) var ( - guestRe = regexp.MustCompile(`^/guest/*$`) - guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) + guestRegex = regexp.MustCompile(`^/guest/*$`) + guestIDRegex = regexp.MustCompile(`^/guest/([0-9]+)$`) ) type GuestHandler struct { - store guestStore + guestStore GuestStore } -type guestStore interface { +type GuestStore interface { Find(credentials Credentials) (Guest, error) Get() ([]Guest, error) Add(guest Guest) error @@ -33,9 +33,9 @@ type appError struct { Code int } -func NewGuestHandler(s guestStore) *GuestHandler { +func NewGuestHandler(guestStore GuestStore) *GuestHandler { return &GuestHandler{ - store: s, + guestStore, } } @@ -45,11 +45,11 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, responseWriter.WriteHeader(http.StatusOK) case request.Method == http.MethodPost && request.URL.Path == "/guest/login": guestHandler.handleLogIn(responseWriter, request) - case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): + case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path): guestHandler.handlePut(responseWriter, request) - case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): + case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path): guestHandler.handleGet(responseWriter, request) - case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): + case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path): guestHandler.handlePost(responseWriter, request) default: responseWriter.WriteHeader(http.StatusNotFound) @@ -57,7 +57,7 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, } func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { - token, err := guestHandler.logInGuest(request) + token, err := guestHandler.logIn(request) if err != nil { http.Error(responseWriter, err.Message, err.Code) } else { @@ -90,28 +90,28 @@ func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, } } -func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) { +func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) { credentials, err := guestHandler.decodeCredentials(request) if err != nil { - return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} + return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest} } - guest, err := guestHandler.store.Find(credentials) + guest, err := guestHandler.guestStore.Find(credentials) if err != nil { - return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} + return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized} } expirationTime := guestHandler.setExpirationTime() claims := guestHandler.createClaims(credentials, expirationTime) - key, err := guestHandler.readKey() + key, err := guestHandler.readGuestKey() if err != nil { - return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} } token, err := guestHandler.createToken(claims, key) if err != nil { - return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError} } jsonBytes, err := guestHandler.marshalResponse(guest, token) if err != nil { - return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError} } return jsonBytes, nil } @@ -136,9 +136,13 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati } } -func (guestHandler *GuestHandler) readKey() ([]byte, error) { +func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) { // TODO: use properties file - return os.ReadFile("C:\\Users\\mhunt\\skey.pem") + return os.ReadFile("C:\\Users\\mhunt\\guest.pem") +} + +func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) { + return os.ReadFile("C:\\Users\\mhunt\\admin.pem") } func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { @@ -159,38 +163,38 @@ func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token } func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { + guestKey, err := guestHandler.readGuestKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, guestKey); err != nil { return err } - if guestHandler.findId(request) { - return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound} + if guestHandler.findID(request) { + return &appError{errors.New("id not found"), "id not found", http.StatusNotFound} } guest, err := guestHandler.decodeGuest(request) if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} + return &appError{err, "invalid guest", http.StatusBadRequest} } - if err := guestHandler.store.Update(guest); err != nil { - return &appError{err, "Failed to update guest", http.StatusInternalServerError} + if err := guestHandler.guestStore.Update(guest); err != nil { + return &appError{err, "failed to update guest", http.StatusInternalServerError} } return nil } -func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError { +func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError { authorizationHeader := guestHandler.getToken(request) - claims := guestHandler.initializeClaims() - key, err := guestHandler.readKey() - if err != nil { - return &appError{err, "Failed to read secret key", http.StatusInternalServerError} - } + claims := guestHandler.newClaims() token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) if err != nil { if err == jwt.ErrSignatureInvalid { - return &appError{err, "Invalid signature", http.StatusUnauthorized} + return &appError{err, "invalid signature", http.StatusUnauthorized} } - return &appError{err, "Failed to parse claims", http.StatusBadRequest} + return &appError{err, "failed to parse claims", http.StatusBadRequest} } if !token.Valid { - return &appError{err, "Invalid token", http.StatusUnauthorized} + return &appError{err, "invalid token", http.StatusUnauthorized} } return nil } @@ -199,7 +203,7 @@ func (guestHandler *GuestHandler) getToken(request *http.Request) string { return request.Header.Get("Authorization") } -func (guestHandler *GuestHandler) initializeClaims() *Claims { +func (guestHandler *GuestHandler) newClaims() *Claims { return &Claims{} } @@ -209,8 +213,8 @@ func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, }) } -func (guestHandler *GuestHandler) findId(request *http.Request) bool { - matches := guestIdRe.FindStringSubmatch(request.URL.Path) +func (guestHandler *GuestHandler) findID(request *http.Request) bool { + matches := guestIDRegex.FindStringSubmatch(request.URL.Path) return len(matches) < 2 } @@ -222,46 +226,53 @@ func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, err } func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { - // TODO: check with admin token - if err := guestHandler.validateToken(request); err != nil { + adminKey, err := guestHandler.readAdminKey() + if err != nil { + return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, adminKey); err != nil { return []byte{}, err } - guests, err := guestHandler.store.Get() + guests, err := guestHandler.guestStore.Get() if err != nil { - return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError} } jsonBytes, err := json.Marshal(guests) if err != nil { - return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} + return []byte{}, &appError{err, "failed to marshal guests", http.StatusInternalServerError} } return jsonBytes, nil } func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError { - if err := guestHandler.validateToken(request); err != nil { + adminKey, err := guestHandler.readAdminKey() + if err != nil { + return &appError{err, "failed to read secret key", http.StatusInternalServerError} + } + if err := guestHandler.validateToken(request, adminKey); err != nil { return err } guest, err := guestHandler.decodeGuest(request) if err != nil { - return &appError{err, "Invalid guest", http.StatusBadRequest} + return &appError{err, "invalid guest", http.StatusBadRequest} } - guests, err := guestHandler.store.Get() + guests, err := guestHandler.guestStore.Get() if err != nil { - return &appError{err, "Failed to get guests", http.StatusInternalServerError} + return &appError{err, "failed to get guests", http.StatusInternalServerError} } if err := guestHandler.checkExistingGuests(guests, guest); err != nil { - return &appError{err, "ID already exists", http.StatusConflict} + return &appError{err, "id already exists", http.StatusConflict} } - if err := guestHandler.store.Add(guest); err != nil { - return &appError{err, "Failed to add guest", http.StatusInternalServerError} + if err := guestHandler.guestStore.Add(guest); err != nil { + return &appError{err, "failed to add guest", http.StatusInternalServerError} } return nil } func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { for _, guest := range guests { - if guest.Id == newGuest.Id { - return errors.New("ID already exists") + if guest.ID == newGuest.ID { + return errors.New("id already exists") } } return nil diff --git a/server/guest/models.go b/server/guest/models.go index aa2c56f..5cd1b8b 100644 --- a/server/guest/models.go +++ b/server/guest/models.go @@ -3,7 +3,7 @@ package guest import "github.com/golang-jwt/jwt/v5" type Guest struct { - Id int `json:"id"` + ID int `json:"id"` FirstName string `json:"firstName"` LastName string `json:"lastName"` Attendance string `json:"attendance"` diff --git a/server/guest/store.go b/server/guest/store.go index 1a07161..a5b9374 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -4,135 +4,185 @@ import ( "context" "errors" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) -type MemStore struct { - db *pgxpool.Pool +type Store struct { + database *pgxpool.Pool } -func NewMemStore(db *pgxpool.Pool) *MemStore { - return &MemStore{ - db, +func NewStore(database *pgxpool.Pool) *Store { + return &Store{ + database, } } -func (m MemStore) Find(creds Credentials) (Guest, error) { - rows, err := m.db.Query(context.Background(), "select * from guest") - var guest Guest +func (store Store) Find(credentials Credentials) (Guest, error) { + guestRows, err := store.database.Query(context.Background(), + "select * from guest") if err != nil { - return guest, err + return Guest{}, err } - defer rows.Close() + defer guestRows.Close() + guest, found := createGuest(credentials, guestRows) - found := false - for rows.Next() { - err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) - if err != nil { - return guest, err - } - if guest.FirstName == creds.FirstName && guest.LastName == creds.LastName { - found = true - break - } + partyRows, err := store.database.Query(context.Background(), + "select * from party") + if err != nil { + return Guest{}, err } + defer partyRows.Close() - rows, err = m.db.Query(context.Background(), "select * from party") + guest, err = addParty(guest, partyRows) if err != nil { - return guest, err + return Guest{}, err } - defer rows.Close() - for rows.Next() { - var guestId int - var partyGuest PartyGuest - err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName) + if found { + return guest, nil + } + return Guest{}, errors.New("guest not found") +} + +func createGuest(credentials Credentials, guestRows pgx.Rows) (Guest, bool) { + var guest Guest + for guestRows.Next() { + err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName, + &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) if err != nil { - return guest, err + return Guest{}, false } - if guestId == guest.Id { - guest.PartyList = append(guest.PartyList, partyGuest) + if guest.FirstName == credentials.FirstName && + guest.LastName == credentials.LastName { + return guest, true } } + return Guest{}, false +} - if found { - return guest, nil +func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) { + guestWithParty := guestWithoutParty + for partyRows.Next() { + var guestID int + var partyGuest PartyGuest + err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName) + if err != nil { + return Guest{}, err + } + if guestID == guestWithParty.ID { + guestWithParty.PartyList = append(guestWithParty.PartyList, partyGuest) + } } - return guest, errors.New("Guest not found") + return guestWithParty, nil } -func (m MemStore) Get() ([]Guest, error) { - rows, err := m.db.Query(context.Background(), "select * from guest") +func (store Store) Get() ([]Guest, error) { + guestRows, err := store.database.Query(context.Background(), + "select * from guest") if err != nil { return nil, err } - defer rows.Close() + defer guestRows.Close() - guestSlice := []Guest{} - for rows.Next() { - var guest Guest - err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) - if err != nil { - return guestSlice, err - } - guestSlice = append(guestSlice, guest) + guestsWithoutParty, err := store.createGuestSlice(guestRows) + if err != nil { + return []Guest{}, err } - rows, err = m.db.Query(context.Background(), "select * from party") + partyRows, err := store.database.Query(context.Background(), + "select * from party") if err != nil { - return guestSlice, err + return []Guest{}, err } - defer rows.Close() + defer partyRows.Close() - for rows.Next() { - var guestId int + guestsWithParty, err := addPartySlice(guestsWithoutParty, partyRows) + if err != nil { + return []Guest{}, err + } + return guestsWithParty, nil +} + +func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) { + guests := []Guest{} + for guestRows.Next() { + var guest Guest + err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName, + &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize) + if err != nil { + return []Guest{}, err + } + guests = append(guests, guest) + } + return guests, nil +} + +func addPartySlice(guestsWithoutParty []Guest, + partyRows pgx.Rows) ([]Guest, error) { + guestsWithParty := guestsWithoutParty + for partyRows.Next() { + var guestID int var partyGuest PartyGuest - err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName) + err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName) if err != nil { - return guestSlice, err + return []Guest{}, err } - for i, g := range guestSlice { - if guestId == g.Id { - guestSlice[i].PartyList = append(g.PartyList, partyGuest) + for i, guest := range guestsWithParty { + if guestID == guest.ID { + guestsWithParty[i].PartyList = append(guest.PartyList, partyGuest) } } } - return guestSlice, nil + return guestsWithParty, nil } -func (m MemStore) Add(guest Guest) error { - statement := "insert into guest (id, first_name, last_name, attendance, email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)" - _, err := m.db.Exec(context.Background(), statement, guest.Id, guest.FirstName, guest.LastName, guest.Attendance, guest.Email, guest.Message, guest.PartySize) - if err != nil { +func (store Store) Add(guest Guest) error { + if err := store.insertGuest(guest); err != nil { return err } + return store.insertParty(guest) +} - statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)" - for _, pg := range guest.PartyList { - _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName) - if err != nil { - return err - } - } - return nil +func (store Store) insertGuest(guest Guest) error { + statement := `insert into guest (id, first_name, last_name, attendance, + email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)` + _, err := store.database.Exec(context.Background(), statement, guest.ID, + guest.FirstName, guest.LastName, guest.Attendance, guest.Email, + guest.Message, guest.PartySize) + return err } -func (m MemStore) Update(guest Guest) error { - statement := "update guest set attendance = $1, email = $2, message = $3, party_size = $4 where id = $5" - _, err := m.db.Exec(context.Background(), statement, guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.Id) - if err != nil { +func (store Store) Update(guest Guest) error { + if err := store.updateGuest(guest); err != nil { return err } - - statement = "delete from party where guest_id = $1" - _, err = m.db.Exec(context.Background(), statement, guest.Id) - if err != nil { + if err := store.deleteOldParty(guest.ID); err != nil { return err } + return store.insertParty(guest) +} + +func (store Store) updateGuest(guest Guest) error { + statement := `update guest set attendance = $1, email = $2, message = $3, + party_size = $4 where id = $5` + _, err := store.database.Exec(context.Background(), statement, + guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.ID) + return err +} + +func (store Store) deleteOldParty(guestID int) error { + statement := "delete from party where guest_id = $1" + _, err := store.database.Exec(context.Background(), statement, guestID) + return err +} - statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)" +func (store Store) insertParty(guest Guest) error { + statement := `insert into party (guest_id, first_name, last_name) + values ($1, $2, $3)` for _, pg := range guest.PartyList { - _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName) + _, err := store.database.Exec(context.Background(), statement, guest.ID, + pg.FirstName, pg.LastName) if err != nil { return err } diff --git a/server/schema.sql b/server/schema.sql index fdc589a..809d8c9 100644 --- a/server/schema.sql +++ b/server/schema.sql @@ -12,4 +12,10 @@ create table party ( guest_id integer NOT NULL references guest(id) ON DELETE CASCADE, first_name varchar(64), last_name varchar(64) -);
\ No newline at end of file +); + +create table admin ( + id serial PRIMARY KEY, + username varchar(64), + password varchar(64) +)
\ No newline at end of file diff --git a/server/test/guest_test.go b/server/test/guest_test.go index c675ddb..206837a 100644 --- a/server/test/guest_test.go +++ b/server/test/guest_test.go @@ -11,53 +11,59 @@ import ( "strings" "testing" + "git.huntm.net/wedding/server/admin" "git.huntm.net/wedding/server/guest" "github.com/jackc/pgx/v5/pgxpool" ) var ( user = os.Getenv("USER") - pass = os.Getenv("PASS") + password = os.Getenv("PASS") host = "localhost" port = "5432" database = "postgres" ) func TestUpdateRSVP(test *testing.T) { - db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database)) + databasePool, err := pgxpool.New(context.Background(), + fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database)) if err != nil { test.Error(err) } - defer db.Close() - store := guest.NewMemStore(db) - guestHandler := guest.NewGuestHandler(store) + defer databasePool.Close() + guestStore := guest.NewStore(databasePool) + guestHandler := guest.NewGuestHandler(guestStore) + adminStore := admin.NewStore(databasePool) + adminHandler := admin.NewAdminHandler(adminStore, guestStore) - login := logIn(guestHandler, test) - token := login.Token + guestLogin := logInGuest(guestHandler, test) + guestToken := guestLogin.Token + addPartyGuest(guestHandler, guestToken, test) - addPartyGuest(guestHandler, token, test) - guestSlice := getGuests(guestHandler, token, test) - guest := guestSlice[0] + adminLogin := logInAdmin(adminHandler, test) + adminToken := adminLogin.Token + guests := getGuests(guestHandler, adminToken, test) + guest := findGuest(guests) assertEquals(test, guest.Attendance, "accept") - assertEquals(test, guest.Email, "mhunteman@yahoo.com") + assertEquals(test, guest.Email, "mhunteman@cox.net") assertEquals(test, guest.Message, "We'll be there!") assertEquals(test, guest.PartySize, 2) assertEquals(test, guest.PartyList[0].FirstName, "Madison") assertEquals(test, guest.PartyList[0].LastName, "Rossitto") - deletePartyGuest(guestHandler, token, test) - guestSlice = getGuests(guestHandler, token, test) - guest = guestSlice[0] - assertEquals(test, guest.Attendance, "decline") + deletePartyGuest(guestHandler, guestToken, test) + guests = getGuests(guestHandler, adminToken, test) + guest = findGuest(guests) + assertEquals(test, guest.Attendance, "") assertEquals(test, guest.Email, "") assertEquals(test, guest.Message, "") assertEquals(test, guest.PartySize, 1) } -func logIn(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginResponse { +func logInGuest(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginResponse { response := httptest.NewRecorder() - loginRequest, err := http.NewRequest(http.MethodPost, "http://localhost:8080/guest/login", - strings.NewReader(getCredentials())) + loginRequest, err := http.NewRequest(http.MethodPost, + "http://localhost:8080/guest/login", strings.NewReader(getName())) if err != nil { test.Error(err) } @@ -73,7 +79,8 @@ func logIn(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginRespons func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) { response := httptest.NewRecorder() guestWithParty := getGuestWithParty() - putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(guestWithParty)) + putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", + strings.NewReader(guestWithParty)) if err != nil { test.Error(err) } @@ -82,22 +89,39 @@ func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing assertEquals(test, response.Result().StatusCode, 200) } +func logInAdmin(adminHandler *admin.AdminHandler, test *testing.T) admin.LoginResponse { + response := httptest.NewRecorder() + loginRequest, err := http.NewRequest(http.MethodPost, + "http://localhost:8080/admin/login", strings.NewReader(getCredentials())) + if err != nil { + test.Error(err) + } + adminHandler.ServeHTTP(response, loginRequest) + assertEquals(test, response.Result().StatusCode, 200) + var login admin.LoginResponse + if err = json.NewDecoder(response.Body).Decode(&login); err != nil { + log.Fatal(err) + } + return login +} + func getGuests(guestHandler *guest.GuestHandler, token string, test *testing.T) []guest.Guest { response := httptest.NewRecorder() getRequest, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/guest/", nil) getRequest.Header.Set("Authorization", token) guestHandler.ServeHTTP(response, getRequest) assertEquals(test, response.Result().StatusCode, 200) - var guestSlice []guest.Guest - if err := json.NewDecoder(response.Body).Decode(&guestSlice); err != nil { + var guests []guest.Guest + if err := json.NewDecoder(response.Body).Decode(&guests); err != nil { test.Error(err) } - return guestSlice + return guests } func deletePartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) { response := httptest.NewRecorder() - putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(getGuestWithoutParty())) + putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", + strings.NewReader(getGuestWithoutParty())) if err != nil { test.Error(err) } @@ -106,6 +130,15 @@ func deletePartyGuest(guestHandler *guest.GuestHandler, token string, test *test assertEquals(test, response.Result().StatusCode, 200) } +func findGuest(guests []guest.Guest) *guest.Guest { + for _, guest := range guests { + if guest.ID == 1 { + return &guest + } + } + return nil +} + func assertEquals(test testing.TB, actual any, expected any) { test.Helper() if actual != expected { @@ -113,17 +146,21 @@ func assertEquals(test testing.TB, actual any, expected any) { } } -func getCredentials() string { +func getName() string { return "{ \"firstName\": \"Michael\", \"lastName\": \"Hunteman\"}" } +func getCredentials() string { + return "{ \"username\": \"mhunteman\", \"password\": \"password\"}" +} + func getGuestWithParty() string { return `{"id":1,"firstName":"Michael","lastName":"Hunteman", "attendance":"accept", - "email":"mhunteman@yahoo.com","message":"We'll be there!","partySize":2, + "email":"mhunteman@cox.net","message":"We'll be there!","partySize":2, "partyList":[{"firstName":"Madison","lastName":"Rossitto"}]}` } func getGuestWithoutParty() string { - return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"decline", + return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"", "email":"","message":"","partySize":1,"partyList":[]}` } |