summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/admin/handler.go119
-rw-r--r--server/admin/models.go21
-rw-r--r--server/admin/store.go49
-rw-r--r--server/cmd/main.go58
-rw-r--r--server/go.mod2
-rw-r--r--server/go.sum4
-rw-r--r--server/guest/handler.go115
-rw-r--r--server/guest/models.go2
-rw-r--r--server/guest/store.go204
-rw-r--r--server/schema.sql8
-rw-r--r--server/test/guest_test.go89
11 files changed, 488 insertions, 183 deletions
diff --git a/server/admin/handler.go b/server/admin/handler.go
new file mode 100644
index 0000000..5fd1fee
--- /dev/null
+++ b/server/admin/handler.go
@@ -0,0 +1,119 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "os"
+ "time"
+
+ "git.huntm.net/wedding/server/guest"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+type AdminHandler struct {
+ adminStore adminStore
+ guestStore guest.GuestStore
+}
+
+type adminStore interface {
+ Find(admin Admin) (Admin, error)
+}
+
+type appError struct {
+ Error error
+ Message string
+ Code int
+}
+
+func NewAdminHandler(adminStore adminStore, guestStore guest.GuestStore) *AdminHandler {
+ return &AdminHandler{adminStore, guestStore}
+}
+
+func (adminHandler *AdminHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+ switch {
+ case request.Method == http.MethodOptions:
+ responseWriter.WriteHeader(http.StatusOK)
+ case request.Method == http.MethodPost && request.URL.Path == "/admin/login":
+ adminHandler.handleLogIn(responseWriter, request)
+ default:
+ responseWriter.WriteHeader(http.StatusNotFound)
+ }
+}
+
+func (adminHandler *AdminHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
+ token, err := adminHandler.logIn(request)
+ if err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.Write(token)
+ }
+}
+
+func (adminHandler *AdminHandler) logIn(request *http.Request) ([]byte, *appError) {
+ requestAdmin, err := adminHandler.decodeCredentials(request)
+ if err != nil {
+ return []byte{}, &appError{err, "failed to unmarshal request", http.StatusBadRequest}
+ }
+ _, err = adminHandler.adminStore.Find(requestAdmin)
+ if err != nil {
+ return []byte{}, &appError{err, "admin not found", http.StatusUnauthorized}
+ }
+ expirationTime := adminHandler.setExpirationTime()
+ claims := adminHandler.createClaims(requestAdmin, expirationTime)
+ key, err := adminHandler.readKey()
+ if err != nil {
+ return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError}
+ }
+ token, err := adminHandler.createToken(claims, key)
+ if err != nil {
+ return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError}
+ }
+ guests, err := adminHandler.guestStore.Get()
+ if err != nil {
+ return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError}
+ }
+ jsonBytes, err := adminHandler.marshalResponse(guests, token)
+ if err != nil {
+ return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError}
+ }
+ return jsonBytes, nil
+}
+
+func (adminHandler *AdminHandler) decodeCredentials(request *http.Request) (Admin, error) {
+ var admin Admin
+ err := json.NewDecoder(request.Body).Decode(&admin)
+ defer request.Body.Close()
+ return admin, err
+}
+
+func (adminHandler *AdminHandler) setExpirationTime() time.Time {
+ return time.Now().Add(15 * time.Minute)
+}
+
+func (adminHandler *AdminHandler) createClaims(admin Admin, expirationTime time.Time) *Claims {
+ return &Claims{
+ admin,
+ jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expirationTime),
+ },
+ }
+}
+
+func (adminHandler *AdminHandler) readKey() ([]byte, error) {
+ // TODO: use properties file
+ return os.ReadFile("C:\\Users\\mhunt\\admin.pem")
+}
+
+func (adminHandler *AdminHandler) createToken(claims *Claims, key []byte) (string, error) {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString(key)
+}
+
+func (adminHandler *AdminHandler) marshalResponse(guests []guest.Guest, token string) ([]byte, error) {
+ loginResponse := adminHandler.createLoginResponse(guests, token)
+ return json.Marshal(loginResponse)
+}
+
+func (adminHandler *AdminHandler) createLoginResponse(guests []guest.Guest, token string) *LoginResponse {
+ return &LoginResponse{guests, token}
+}
diff --git a/server/admin/models.go b/server/admin/models.go
new file mode 100644
index 0000000..d9b8232
--- /dev/null
+++ b/server/admin/models.go
@@ -0,0 +1,21 @@
+package admin
+
+import (
+ "git.huntm.net/wedding/server/guest"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+type Admin struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+}
+
+type Claims struct {
+ Admin Admin `json:"admin"`
+ jwt.RegisteredClaims
+}
+
+type LoginResponse struct {
+ Guests []guest.Guest `json:"guests"`
+ Token string `json:"token"`
+}
diff --git a/server/admin/store.go b/server/admin/store.go
new file mode 100644
index 0000000..437e6af
--- /dev/null
+++ b/server/admin/store.go
@@ -0,0 +1,49 @@
+package admin
+
+import (
+ "context"
+ "errors"
+
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgxpool"
+)
+
+type Store struct {
+ database *pgxpool.Pool
+}
+
+func NewStore(database *pgxpool.Pool) *Store {
+ return &Store{
+ database,
+ }
+}
+
+func (store Store) Find(requestAdmin Admin) (Admin, error) {
+ adminRows, err := store.database.Query(context.Background(),
+ "select * from admin")
+ if err != nil {
+ return Admin{}, err
+ }
+ defer adminRows.Close()
+ admin, found := createAdmin(requestAdmin, adminRows)
+
+ if found {
+ return admin, nil
+ }
+ return Admin{}, errors.New("admin not found")
+}
+
+func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) {
+ var databaseAdmin Admin
+ for adminRows.Next() {
+ err := adminRows.Scan(&databaseAdmin.Username, &databaseAdmin.Password)
+ if err != nil {
+ return Admin{}, false
+ }
+ if databaseAdmin.Username == requestAdmin.Username &&
+ databaseAdmin.Password == requestAdmin.Password {
+ return databaseAdmin, true
+ }
+ }
+ return Admin{}, false
+}
diff --git a/server/cmd/main.go b/server/cmd/main.go
index 439ed0b..38ccf37 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -10,53 +10,65 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
+ "git.huntm.net/wedding/server/admin"
"git.huntm.net/wedding/server/guest"
)
var (
user = os.Getenv("USER")
- pass = os.Getenv("PASS")
+ password = os.Getenv("PASS")
host = "localhost"
port = "5432"
database = "postgres"
)
func main() {
- db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+ databasePool, err := pgxpool.New(context.Background(),
+ fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database))
if err != nil {
log.Fatal(err)
}
- defer db.Close()
+ defer databasePool.Close()
- store := guest.NewMemStore(db)
- guestHandler := guest.NewGuestHandler(store)
+ guestStore := guest.NewStore(databasePool)
+ guestHandler := guest.NewGuestHandler(guestStore)
+ adminStore := admin.NewStore(databasePool)
+ adminHandler := admin.NewAdminHandler(adminStore, guestStore)
mux := http.NewServeMux()
mux.Handle("/guest/", guestHandler)
- log.Fatal(http.ListenAndServe(":8080", enableCORS(mux)))
+ mux.Handle("/admin/", adminHandler)
+ log.Fatal(http.ListenAndServe(":8080", serveHTTP(mux)))
}
-func enableCORS(handler http.Handler) http.Handler {
- allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.41:5173"}
- allowedMethods := []string{"OPTIONS", "POST", "PUT"}
- return serveHTTP(handler, allowedOrigins, allowedMethods)
+func serveHTTP(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(responseWriter http.ResponseWriter,
+ request *http.Request) {
+ writeMethods(responseWriter, request)
+ writeOrigins(responseWriter, request)
+ writeHeaders(responseWriter)
+ handler.ServeHTTP(responseWriter, request)
+ })
}
-func serveHTTP(handler http.Handler, allowedOrigins []string, allowedMethods []string) http.Handler {
- return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
- method := request.Header.Get("Access-Control-Request-Method")
- if isPreflight(request) && slices.Contains(allowedMethods, method) {
- responseWriter.Header().Add("Access-Control-Allow-Methods", method)
- }
+func writeMethods(responseWriter http.ResponseWriter, request *http.Request) {
+ allowedMethods := []string{"OPTIONS", "POST", "PUT"}
+ method := request.Header.Get("Access-Control-Request-Method")
+ if isPreflight(request) && slices.Contains(allowedMethods, method) {
+ responseWriter.Header().Add("Access-Control-Allow-Methods", method)
+ }
+}
- origin := request.Header.Get("Origin")
- if slices.Contains(allowedOrigins, origin) {
- responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
- }
+func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) {
+ allowedOrigins := []string{"http://localhost:5173", "http://192.168.1.18:5173"}
+ origin := request.Header.Get("Origin")
+ if slices.Contains(allowedOrigins, origin) {
+ responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
+ }
+}
- responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
- handler.ServeHTTP(responseWriter, request)
- })
+func writeHeaders(responseWriter http.ResponseWriter) {
+ responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
}
func isPreflight(request *http.Request) bool {
diff --git a/server/go.mod b/server/go.mod
index 49af49e..8ae7015 100644
--- a/server/go.mod
+++ b/server/go.mod
@@ -4,7 +4,7 @@ go 1.22.2
require (
github.com/golang-jwt/jwt/v5 v5.2.1
- github.com/jackc/pgx/v5 v5.5.5
+ github.com/jackc/pgx/v5 v5.6.0
)
require (
diff --git a/server/go.sum b/server/go.sum
index b93b74b..590fe16 100644
--- a/server/go.sum
+++ b/server/go.sum
@@ -7,8 +7,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
-github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
+github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
+github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
diff --git a/server/guest/handler.go b/server/guest/handler.go
index 46b8a45..153a633 100644
--- a/server/guest/handler.go
+++ b/server/guest/handler.go
@@ -12,15 +12,15 @@ import (
)
var (
- guestRe = regexp.MustCompile(`^/guest/*$`)
- guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
+ guestRegex = regexp.MustCompile(`^/guest/*$`)
+ guestIDRegex = regexp.MustCompile(`^/guest/([0-9]+)$`)
)
type GuestHandler struct {
- store guestStore
+ guestStore GuestStore
}
-type guestStore interface {
+type GuestStore interface {
Find(credentials Credentials) (Guest, error)
Get() ([]Guest, error)
Add(guest Guest) error
@@ -33,9 +33,9 @@ type appError struct {
Code int
}
-func NewGuestHandler(s guestStore) *GuestHandler {
+func NewGuestHandler(guestStore GuestStore) *GuestHandler {
return &GuestHandler{
- store: s,
+ guestStore,
}
}
@@ -45,11 +45,11 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter,
responseWriter.WriteHeader(http.StatusOK)
case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
guestHandler.handleLogIn(responseWriter, request)
- case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
+ case request.Method == http.MethodPut && guestIDRegex.MatchString(request.URL.Path):
guestHandler.handlePut(responseWriter, request)
- case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
+ case request.Method == http.MethodGet && guestRegex.MatchString(request.URL.Path):
guestHandler.handleGet(responseWriter, request)
- case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
+ case request.Method == http.MethodPost && guestIDRegex.MatchString(request.URL.Path):
guestHandler.handlePost(responseWriter, request)
default:
responseWriter.WriteHeader(http.StatusNotFound)
@@ -57,7 +57,7 @@ func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter,
}
func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
- token, err := guestHandler.logInGuest(request)
+ token, err := guestHandler.logIn(request)
if err != nil {
http.Error(responseWriter, err.Message, err.Code)
} else {
@@ -90,28 +90,28 @@ func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter,
}
}
-func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) {
+func (guestHandler *GuestHandler) logIn(request *http.Request) ([]byte, *appError) {
credentials, err := guestHandler.decodeCredentials(request)
if err != nil {
- return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
+ return []byte{}, &appError{err, "failed to unmarshal credentials", http.StatusBadRequest}
}
- guest, err := guestHandler.store.Find(credentials)
+ guest, err := guestHandler.guestStore.Find(credentials)
if err != nil {
- return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
+ return []byte{}, &appError{err, "guest not found", http.StatusUnauthorized}
}
expirationTime := guestHandler.setExpirationTime()
claims := guestHandler.createClaims(credentials, expirationTime)
- key, err := guestHandler.readKey()
+ key, err := guestHandler.readGuestKey()
if err != nil {
- return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+ return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError}
}
token, err := guestHandler.createToken(claims, key)
if err != nil {
- return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
+ return []byte{}, &appError{err, "failed to create token", http.StatusInternalServerError}
}
jsonBytes, err := guestHandler.marshalResponse(guest, token)
if err != nil {
- return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
+ return []byte{}, &appError{err, "failed to marshal response", http.StatusInternalServerError}
}
return jsonBytes, nil
}
@@ -136,9 +136,13 @@ func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirati
}
}
-func (guestHandler *GuestHandler) readKey() ([]byte, error) {
+func (guestHandler *GuestHandler) readGuestKey() ([]byte, error) {
// TODO: use properties file
- return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
+ return os.ReadFile("C:\\Users\\mhunt\\guest.pem")
+}
+
+func (guestHandler *GuestHandler) readAdminKey() ([]byte, error) {
+ return os.ReadFile("C:\\Users\\mhunt\\admin.pem")
}
func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
@@ -159,38 +163,38 @@ func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token
}
func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
+ guestKey, err := guestHandler.readGuestKey()
+ if err != nil {
+ return &appError{err, "failed to read secret key", http.StatusInternalServerError}
+ }
+ if err := guestHandler.validateToken(request, guestKey); err != nil {
return err
}
- if guestHandler.findId(request) {
- return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound}
+ if guestHandler.findID(request) {
+ return &appError{errors.New("id not found"), "id not found", http.StatusNotFound}
}
guest, err := guestHandler.decodeGuest(request)
if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
+ return &appError{err, "invalid guest", http.StatusBadRequest}
}
- if err := guestHandler.store.Update(guest); err != nil {
- return &appError{err, "Failed to update guest", http.StatusInternalServerError}
+ if err := guestHandler.guestStore.Update(guest); err != nil {
+ return &appError{err, "failed to update guest", http.StatusInternalServerError}
}
return nil
}
-func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError {
+func (guestHandler *GuestHandler) validateToken(request *http.Request, key []byte) *appError {
authorizationHeader := guestHandler.getToken(request)
- claims := guestHandler.initializeClaims()
- key, err := guestHandler.readKey()
- if err != nil {
- return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
- }
+ claims := guestHandler.newClaims()
token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
if err != nil {
if err == jwt.ErrSignatureInvalid {
- return &appError{err, "Invalid signature", http.StatusUnauthorized}
+ return &appError{err, "invalid signature", http.StatusUnauthorized}
}
- return &appError{err, "Failed to parse claims", http.StatusBadRequest}
+ return &appError{err, "failed to parse claims", http.StatusBadRequest}
}
if !token.Valid {
- return &appError{err, "Invalid token", http.StatusUnauthorized}
+ return &appError{err, "invalid token", http.StatusUnauthorized}
}
return nil
}
@@ -199,7 +203,7 @@ func (guestHandler *GuestHandler) getToken(request *http.Request) string {
return request.Header.Get("Authorization")
}
-func (guestHandler *GuestHandler) initializeClaims() *Claims {
+func (guestHandler *GuestHandler) newClaims() *Claims {
return &Claims{}
}
@@ -209,8 +213,8 @@ func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims,
})
}
-func (guestHandler *GuestHandler) findId(request *http.Request) bool {
- matches := guestIdRe.FindStringSubmatch(request.URL.Path)
+func (guestHandler *GuestHandler) findID(request *http.Request) bool {
+ matches := guestIDRegex.FindStringSubmatch(request.URL.Path)
return len(matches) < 2
}
@@ -222,46 +226,53 @@ func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, err
}
func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) {
- // TODO: check with admin token
- if err := guestHandler.validateToken(request); err != nil {
+ adminKey, err := guestHandler.readAdminKey()
+ if err != nil {
+ return []byte{}, &appError{err, "failed to read secret key", http.StatusInternalServerError}
+ }
+ if err := guestHandler.validateToken(request, adminKey); err != nil {
return []byte{}, err
}
- guests, err := guestHandler.store.Get()
+ guests, err := guestHandler.guestStore.Get()
if err != nil {
- return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ return []byte{}, &appError{err, "failed to get guests", http.StatusInternalServerError}
}
jsonBytes, err := json.Marshal(guests)
if err != nil {
- return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
+ return []byte{}, &appError{err, "failed to marshal guests", http.StatusInternalServerError}
}
return jsonBytes, nil
}
func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError {
- if err := guestHandler.validateToken(request); err != nil {
+ adminKey, err := guestHandler.readAdminKey()
+ if err != nil {
+ return &appError{err, "failed to read secret key", http.StatusInternalServerError}
+ }
+ if err := guestHandler.validateToken(request, adminKey); err != nil {
return err
}
guest, err := guestHandler.decodeGuest(request)
if err != nil {
- return &appError{err, "Invalid guest", http.StatusBadRequest}
+ return &appError{err, "invalid guest", http.StatusBadRequest}
}
- guests, err := guestHandler.store.Get()
+ guests, err := guestHandler.guestStore.Get()
if err != nil {
- return &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ return &appError{err, "failed to get guests", http.StatusInternalServerError}
}
if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
- return &appError{err, "ID already exists", http.StatusConflict}
+ return &appError{err, "id already exists", http.StatusConflict}
}
- if err := guestHandler.store.Add(guest); err != nil {
- return &appError{err, "Failed to add guest", http.StatusInternalServerError}
+ if err := guestHandler.guestStore.Add(guest); err != nil {
+ return &appError{err, "failed to add guest", http.StatusInternalServerError}
}
return nil
}
func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error {
for _, guest := range guests {
- if guest.Id == newGuest.Id {
- return errors.New("ID already exists")
+ if guest.ID == newGuest.ID {
+ return errors.New("id already exists")
}
}
return nil
diff --git a/server/guest/models.go b/server/guest/models.go
index aa2c56f..5cd1b8b 100644
--- a/server/guest/models.go
+++ b/server/guest/models.go
@@ -3,7 +3,7 @@ package guest
import "github.com/golang-jwt/jwt/v5"
type Guest struct {
- Id int `json:"id"`
+ ID int `json:"id"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Attendance string `json:"attendance"`
diff --git a/server/guest/store.go b/server/guest/store.go
index 1a07161..a5b9374 100644
--- a/server/guest/store.go
+++ b/server/guest/store.go
@@ -4,135 +4,185 @@ import (
"context"
"errors"
+ "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
-type MemStore struct {
- db *pgxpool.Pool
+type Store struct {
+ database *pgxpool.Pool
}
-func NewMemStore(db *pgxpool.Pool) *MemStore {
- return &MemStore{
- db,
+func NewStore(database *pgxpool.Pool) *Store {
+ return &Store{
+ database,
}
}
-func (m MemStore) Find(creds Credentials) (Guest, error) {
- rows, err := m.db.Query(context.Background(), "select * from guest")
- var guest Guest
+func (store Store) Find(credentials Credentials) (Guest, error) {
+ guestRows, err := store.database.Query(context.Background(),
+ "select * from guest")
if err != nil {
- return guest, err
+ return Guest{}, err
}
- defer rows.Close()
+ defer guestRows.Close()
+ guest, found := createGuest(credentials, guestRows)
- found := false
- for rows.Next() {
- err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize)
- if err != nil {
- return guest, err
- }
- if guest.FirstName == creds.FirstName && guest.LastName == creds.LastName {
- found = true
- break
- }
+ partyRows, err := store.database.Query(context.Background(),
+ "select * from party")
+ if err != nil {
+ return Guest{}, err
}
+ defer partyRows.Close()
- rows, err = m.db.Query(context.Background(), "select * from party")
+ guest, err = addParty(guest, partyRows)
if err != nil {
- return guest, err
+ return Guest{}, err
}
- defer rows.Close()
- for rows.Next() {
- var guestId int
- var partyGuest PartyGuest
- err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName)
+ if found {
+ return guest, nil
+ }
+ return Guest{}, errors.New("guest not found")
+}
+
+func createGuest(credentials Credentials, guestRows pgx.Rows) (Guest, bool) {
+ var guest Guest
+ for guestRows.Next() {
+ err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName,
+ &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize)
if err != nil {
- return guest, err
+ return Guest{}, false
}
- if guestId == guest.Id {
- guest.PartyList = append(guest.PartyList, partyGuest)
+ if guest.FirstName == credentials.FirstName &&
+ guest.LastName == credentials.LastName {
+ return guest, true
}
}
+ return Guest{}, false
+}
- if found {
- return guest, nil
+func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) {
+ guestWithParty := guestWithoutParty
+ for partyRows.Next() {
+ var guestID int
+ var partyGuest PartyGuest
+ err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName)
+ if err != nil {
+ return Guest{}, err
+ }
+ if guestID == guestWithParty.ID {
+ guestWithParty.PartyList = append(guestWithParty.PartyList, partyGuest)
+ }
}
- return guest, errors.New("Guest not found")
+ return guestWithParty, nil
}
-func (m MemStore) Get() ([]Guest, error) {
- rows, err := m.db.Query(context.Background(), "select * from guest")
+func (store Store) Get() ([]Guest, error) {
+ guestRows, err := store.database.Query(context.Background(),
+ "select * from guest")
if err != nil {
return nil, err
}
- defer rows.Close()
+ defer guestRows.Close()
- guestSlice := []Guest{}
- for rows.Next() {
- var guest Guest
- err := rows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize)
- if err != nil {
- return guestSlice, err
- }
- guestSlice = append(guestSlice, guest)
+ guestsWithoutParty, err := store.createGuestSlice(guestRows)
+ if err != nil {
+ return []Guest{}, err
}
- rows, err = m.db.Query(context.Background(), "select * from party")
+ partyRows, err := store.database.Query(context.Background(),
+ "select * from party")
if err != nil {
- return guestSlice, err
+ return []Guest{}, err
}
- defer rows.Close()
+ defer partyRows.Close()
- for rows.Next() {
- var guestId int
+ guestsWithParty, err := addPartySlice(guestsWithoutParty, partyRows)
+ if err != nil {
+ return []Guest{}, err
+ }
+ return guestsWithParty, nil
+}
+
+func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) {
+ guests := []Guest{}
+ for guestRows.Next() {
+ var guest Guest
+ err := guestRows.Scan(&guest.ID, &guest.FirstName, &guest.LastName,
+ &guest.Attendance, &guest.Email, &guest.Message, &guest.PartySize)
+ if err != nil {
+ return []Guest{}, err
+ }
+ guests = append(guests, guest)
+ }
+ return guests, nil
+}
+
+func addPartySlice(guestsWithoutParty []Guest,
+ partyRows pgx.Rows) ([]Guest, error) {
+ guestsWithParty := guestsWithoutParty
+ for partyRows.Next() {
+ var guestID int
var partyGuest PartyGuest
- err := rows.Scan(&guestId, &partyGuest.FirstName, &partyGuest.LastName)
+ err := partyRows.Scan(&guestID, &partyGuest.FirstName, &partyGuest.LastName)
if err != nil {
- return guestSlice, err
+ return []Guest{}, err
}
- for i, g := range guestSlice {
- if guestId == g.Id {
- guestSlice[i].PartyList = append(g.PartyList, partyGuest)
+ for i, guest := range guestsWithParty {
+ if guestID == guest.ID {
+ guestsWithParty[i].PartyList = append(guest.PartyList, partyGuest)
}
}
}
- return guestSlice, nil
+ return guestsWithParty, nil
}
-func (m MemStore) Add(guest Guest) error {
- statement := "insert into guest (id, first_name, last_name, attendance, email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)"
- _, err := m.db.Exec(context.Background(), statement, guest.Id, guest.FirstName, guest.LastName, guest.Attendance, guest.Email, guest.Message, guest.PartySize)
- if err != nil {
+func (store Store) Add(guest Guest) error {
+ if err := store.insertGuest(guest); err != nil {
return err
}
+ return store.insertParty(guest)
+}
- statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)"
- for _, pg := range guest.PartyList {
- _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName)
- if err != nil {
- return err
- }
- }
- return nil
+func (store Store) insertGuest(guest Guest) error {
+ statement := `insert into guest (id, first_name, last_name, attendance,
+ email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)`
+ _, err := store.database.Exec(context.Background(), statement, guest.ID,
+ guest.FirstName, guest.LastName, guest.Attendance, guest.Email,
+ guest.Message, guest.PartySize)
+ return err
}
-func (m MemStore) Update(guest Guest) error {
- statement := "update guest set attendance = $1, email = $2, message = $3, party_size = $4 where id = $5"
- _, err := m.db.Exec(context.Background(), statement, guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.Id)
- if err != nil {
+func (store Store) Update(guest Guest) error {
+ if err := store.updateGuest(guest); err != nil {
return err
}
-
- statement = "delete from party where guest_id = $1"
- _, err = m.db.Exec(context.Background(), statement, guest.Id)
- if err != nil {
+ if err := store.deleteOldParty(guest.ID); err != nil {
return err
}
+ return store.insertParty(guest)
+}
+
+func (store Store) updateGuest(guest Guest) error {
+ statement := `update guest set attendance = $1, email = $2, message = $3,
+ party_size = $4 where id = $5`
+ _, err := store.database.Exec(context.Background(), statement,
+ guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.ID)
+ return err
+}
+
+func (store Store) deleteOldParty(guestID int) error {
+ statement := "delete from party where guest_id = $1"
+ _, err := store.database.Exec(context.Background(), statement, guestID)
+ return err
+}
- statement = "insert into party (guest_id, first_name, last_name) values ($1, $2, $3)"
+func (store Store) insertParty(guest Guest) error {
+ statement := `insert into party (guest_id, first_name, last_name)
+ values ($1, $2, $3)`
for _, pg := range guest.PartyList {
- _, err = m.db.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName)
+ _, err := store.database.Exec(context.Background(), statement, guest.ID,
+ pg.FirstName, pg.LastName)
if err != nil {
return err
}
diff --git a/server/schema.sql b/server/schema.sql
index fdc589a..809d8c9 100644
--- a/server/schema.sql
+++ b/server/schema.sql
@@ -12,4 +12,10 @@ create table party (
guest_id integer NOT NULL references guest(id) ON DELETE CASCADE,
first_name varchar(64),
last_name varchar(64)
-); \ No newline at end of file
+);
+
+create table admin (
+ id serial PRIMARY KEY,
+ username varchar(64),
+ password varchar(64)
+) \ No newline at end of file
diff --git a/server/test/guest_test.go b/server/test/guest_test.go
index c675ddb..206837a 100644
--- a/server/test/guest_test.go
+++ b/server/test/guest_test.go
@@ -11,53 +11,59 @@ import (
"strings"
"testing"
+ "git.huntm.net/wedding/server/admin"
"git.huntm.net/wedding/server/guest"
"github.com/jackc/pgx/v5/pgxpool"
)
var (
user = os.Getenv("USER")
- pass = os.Getenv("PASS")
+ password = os.Getenv("PASS")
host = "localhost"
port = "5432"
database = "postgres"
)
func TestUpdateRSVP(test *testing.T) {
- db, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, database))
+ databasePool, err := pgxpool.New(context.Background(),
+ fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, host, port, database))
if err != nil {
test.Error(err)
}
- defer db.Close()
- store := guest.NewMemStore(db)
- guestHandler := guest.NewGuestHandler(store)
+ defer databasePool.Close()
+ guestStore := guest.NewStore(databasePool)
+ guestHandler := guest.NewGuestHandler(guestStore)
+ adminStore := admin.NewStore(databasePool)
+ adminHandler := admin.NewAdminHandler(adminStore, guestStore)
- login := logIn(guestHandler, test)
- token := login.Token
+ guestLogin := logInGuest(guestHandler, test)
+ guestToken := guestLogin.Token
+ addPartyGuest(guestHandler, guestToken, test)
- addPartyGuest(guestHandler, token, test)
- guestSlice := getGuests(guestHandler, token, test)
- guest := guestSlice[0]
+ adminLogin := logInAdmin(adminHandler, test)
+ adminToken := adminLogin.Token
+ guests := getGuests(guestHandler, adminToken, test)
+ guest := findGuest(guests)
assertEquals(test, guest.Attendance, "accept")
- assertEquals(test, guest.Email, "mhunteman@yahoo.com")
+ assertEquals(test, guest.Email, "mhunteman@cox.net")
assertEquals(test, guest.Message, "We'll be there!")
assertEquals(test, guest.PartySize, 2)
assertEquals(test, guest.PartyList[0].FirstName, "Madison")
assertEquals(test, guest.PartyList[0].LastName, "Rossitto")
- deletePartyGuest(guestHandler, token, test)
- guestSlice = getGuests(guestHandler, token, test)
- guest = guestSlice[0]
- assertEquals(test, guest.Attendance, "decline")
+ deletePartyGuest(guestHandler, guestToken, test)
+ guests = getGuests(guestHandler, adminToken, test)
+ guest = findGuest(guests)
+ assertEquals(test, guest.Attendance, "")
assertEquals(test, guest.Email, "")
assertEquals(test, guest.Message, "")
assertEquals(test, guest.PartySize, 1)
}
-func logIn(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginResponse {
+func logInGuest(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginResponse {
response := httptest.NewRecorder()
- loginRequest, err := http.NewRequest(http.MethodPost, "http://localhost:8080/guest/login",
- strings.NewReader(getCredentials()))
+ loginRequest, err := http.NewRequest(http.MethodPost,
+ "http://localhost:8080/guest/login", strings.NewReader(getName()))
if err != nil {
test.Error(err)
}
@@ -73,7 +79,8 @@ func logIn(guestHandler *guest.GuestHandler, test *testing.T) guest.LoginRespons
func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
response := httptest.NewRecorder()
guestWithParty := getGuestWithParty()
- putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(guestWithParty))
+ putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1",
+ strings.NewReader(guestWithParty))
if err != nil {
test.Error(err)
}
@@ -82,22 +89,39 @@ func addPartyGuest(guestHandler *guest.GuestHandler, token string, test *testing
assertEquals(test, response.Result().StatusCode, 200)
}
+func logInAdmin(adminHandler *admin.AdminHandler, test *testing.T) admin.LoginResponse {
+ response := httptest.NewRecorder()
+ loginRequest, err := http.NewRequest(http.MethodPost,
+ "http://localhost:8080/admin/login", strings.NewReader(getCredentials()))
+ if err != nil {
+ test.Error(err)
+ }
+ adminHandler.ServeHTTP(response, loginRequest)
+ assertEquals(test, response.Result().StatusCode, 200)
+ var login admin.LoginResponse
+ if err = json.NewDecoder(response.Body).Decode(&login); err != nil {
+ log.Fatal(err)
+ }
+ return login
+}
+
func getGuests(guestHandler *guest.GuestHandler, token string, test *testing.T) []guest.Guest {
response := httptest.NewRecorder()
getRequest, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/guest/", nil)
getRequest.Header.Set("Authorization", token)
guestHandler.ServeHTTP(response, getRequest)
assertEquals(test, response.Result().StatusCode, 200)
- var guestSlice []guest.Guest
- if err := json.NewDecoder(response.Body).Decode(&guestSlice); err != nil {
+ var guests []guest.Guest
+ if err := json.NewDecoder(response.Body).Decode(&guests); err != nil {
test.Error(err)
}
- return guestSlice
+ return guests
}
func deletePartyGuest(guestHandler *guest.GuestHandler, token string, test *testing.T) {
response := httptest.NewRecorder()
- putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1", strings.NewReader(getGuestWithoutParty()))
+ putRequest, err := http.NewRequest(http.MethodPut, "http://localhost:8080/guest/1",
+ strings.NewReader(getGuestWithoutParty()))
if err != nil {
test.Error(err)
}
@@ -106,6 +130,15 @@ func deletePartyGuest(guestHandler *guest.GuestHandler, token string, test *test
assertEquals(test, response.Result().StatusCode, 200)
}
+func findGuest(guests []guest.Guest) *guest.Guest {
+ for _, guest := range guests {
+ if guest.ID == 1 {
+ return &guest
+ }
+ }
+ return nil
+}
+
func assertEquals(test testing.TB, actual any, expected any) {
test.Helper()
if actual != expected {
@@ -113,17 +146,21 @@ func assertEquals(test testing.TB, actual any, expected any) {
}
}
-func getCredentials() string {
+func getName() string {
return "{ \"firstName\": \"Michael\", \"lastName\": \"Hunteman\"}"
}
+func getCredentials() string {
+ return "{ \"username\": \"mhunteman\", \"password\": \"password\"}"
+}
+
func getGuestWithParty() string {
return `{"id":1,"firstName":"Michael","lastName":"Hunteman", "attendance":"accept",
- "email":"mhunteman@yahoo.com","message":"We'll be there!","partySize":2,
+ "email":"mhunteman@cox.net","message":"We'll be there!","partySize":2,
"partyList":[{"firstName":"Madison","lastName":"Rossitto"}]}`
}
func getGuestWithoutParty() string {
- return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"decline",
+ return `{"id":1,"firstName":"Michael","lastName":"Hunteman","attendance":"",
"email":"","message":"","partySize":1,"partyList":[]}`
}