diff options
author | Michael Hunteman <huntemanmt@gmail.com> | 2025-02-02 18:57:30 -0600 |
---|---|---|
committer | Michael Hunteman <huntemanmt@gmail.com> | 2025-02-02 19:04:24 -0600 |
commit | 5fffbba3b851f6cebfd0e616bef2ff6f0c520c3d (patch) | |
tree | fe8b7a5ba77f83f7b82753d5cc58cba51596da2b | |
parent | 23bcef02052c45089358d22d0645ceac858de3bb (diff) |
-rw-r--r-- | server/Containerfile | 1 | ||||
-rw-r--r-- | server/admin/handler.go | 41 | ||||
-rw-r--r-- | server/admin/models.go | 15 | ||||
-rw-r--r-- | server/admin/store.go | 15 | ||||
-rw-r--r-- | server/cmd/main.go | 38 | ||||
-rw-r--r-- | server/errors/models.go | 6 | ||||
-rw-r--r-- | server/go.mod | 2 | ||||
-rw-r--r-- | server/go.sum | 12 | ||||
-rw-r--r-- | server/guest/handler.go | 55 | ||||
-rw-r--r-- | server/guest/models.go | 22 | ||||
-rw-r--r-- | server/guest/store.go | 73 | ||||
-rw-r--r-- | server/middleware/cors.go | 36 | ||||
-rw-r--r-- | server/middleware/log.go (renamed from server/middleware/logging.go) | 8 |
13 files changed, 150 insertions, 174 deletions
diff --git a/server/Containerfile b/server/Containerfile index dbd8102..39b4ae8 100644 --- a/server/Containerfile +++ b/server/Containerfile @@ -6,7 +6,6 @@ RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o server cmd/main.go # Stage 2 FROM alpine:latest COPY --from=builder /app/server /app/server -# COPY server /app/server WORKDIR /app EXPOSE 8080 CMD ["./server"]
\ No newline at end of file diff --git a/server/admin/handler.go b/server/admin/handler.go index b8f1d7f..2ae0b0d 100644 --- a/server/admin/handler.go +++ b/server/admin/handler.go @@ -12,15 +12,15 @@ import ( ) type AdminHandler struct { - adminStore adminStore + adminStore AdminStore guestStore guest.GuestStore } -type adminStore interface { - Find(admin Admin) (Admin, error) +type AdminStore interface { + Find(Admin) (Admin, error) } -func NewAdminHandler(a adminStore, g guest.GuestStore) *AdminHandler { +func NewAdminHandler(a AdminStore, g guest.GuestStore) *AdminHandler { return &AdminHandler{a, g} } @@ -45,21 +45,19 @@ func (a *AdminHandler) handleLogIn(w http.ResponseWriter, r *http.Request) { } func (a *AdminHandler) logIn(r *http.Request) ([]byte, *errors.AppError) { - requestAdmin, err := a.decodeCredentials(r) + admin, err := a.decodeCredentials(r) if err != nil { return nil, errors.NewAppError(http.StatusBadRequest, err.Error()) } - _, err = a.adminStore.Find(requestAdmin) + _, err = a.adminStore.Find(admin) if err != nil { return nil, errors.NewAppError(http.StatusUnauthorized, err.Error()) } - expirationTime := a.setExpirationTime() - claims := a.createClaims(requestAdmin, expirationTime) key, err := a.readKey() if err != nil { return nil, errors.NewAppError(http.StatusInternalServerError, err.Error()) } - token, err := a.createToken(claims, key) + token, err := a.newToken(NewClaims(admin, a.setExpirationTime()), key) if err != nil { return nil, errors.NewAppError(http.StatusInternalServerError, err.Error()) } @@ -85,31 +83,14 @@ func (a *AdminHandler) setExpirationTime() time.Time { return time.Now().Add(15 * time.Minute) } -func (a *AdminHandler) createClaims(admin Admin, expirationTime time.Time) *Claims { - return &Claims{ - admin, - jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expirationTime), - }, - } -} - func (a *AdminHandler) readKey() ([]byte, error) { return os.ReadFile(os.Getenv("ADMIN_KEY")) } -func (a *AdminHandler) createToken(claims *Claims, key []byte) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(key) -} - -func (a *AdminHandler) marshalResponse(guests []guest.Guest, - token string) ([]byte, error) { - loginResponse := a.createLoginResponse(guests, token) - return json.Marshal(loginResponse) +func (a *AdminHandler) newToken(claims *Claims, key []byte) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key) } -func (a *AdminHandler) createLoginResponse(guests []guest.Guest, - token string) *Login { - return &Login{guests, token} +func (a *AdminHandler) marshalResponse(guests []guest.Guest, token string) ([]byte, error) { + return json.Marshal(NewLogin(guests, token)) } diff --git a/server/admin/models.go b/server/admin/models.go index 275f617..524099b 100644 --- a/server/admin/models.go +++ b/server/admin/models.go @@ -1,6 +1,8 @@ package admin import ( + "time" + "git.huntm.net/wedding/server/guest" "github.com/golang-jwt/jwt/v5" ) @@ -20,3 +22,16 @@ type Login struct { Guests []guest.Guest `json:"guests"` Token string `json:"token"` } + +func NewClaims(admin Admin, expirationTime time.Time) *Claims { + return &Claims{ + admin, + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } +} + +func NewLogin(guests []guest.Guest, token string) *Login { + return &Login{guests, token} +} diff --git a/server/admin/store.go b/server/admin/store.go index 3322b35..9a7f639 100644 --- a/server/admin/store.go +++ b/server/admin/store.go @@ -19,14 +19,13 @@ func NewStore(database *pgxpool.Pool) *Store { } } -func (store Store) Find(requestAdmin Admin) (Admin, error) { - adminRows, err := store.database.Query(context.Background(), - "select * from admin") +func (s *Store) Find(requestAdmin Admin) (Admin, error) { + adminRows, err := s.database.Query(context.Background(), "select * from admin") if err != nil { return Admin{}, err } defer adminRows.Close() - admin, found := createAdmin(requestAdmin, adminRows) + admin, found := s.newAdmin(requestAdmin, adminRows) if found { return admin, nil @@ -34,7 +33,7 @@ func (store Store) Find(requestAdmin Admin) (Admin, error) { return Admin{}, errors.New("invalid username or password") } -func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) { +func (s *Store) newAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) { var databaseAdmin Admin for adminRows.Next() { err := adminRows.Scan(&databaseAdmin.Id, &databaseAdmin.Username, &databaseAdmin.Password) @@ -42,13 +41,13 @@ func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) { return Admin{}, false } if databaseAdmin.Username == requestAdmin.Username && - verifyPassword(databaseAdmin.Password, requestAdmin.Password) { + s.verifyPassword(databaseAdmin.Password, requestAdmin.Password) { return databaseAdmin, true } } return Admin{}, false } -func verifyPassword(hashedPassword string, password string) bool { - return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) == nil +func (s *Store) verifyPassword(hash string, password string) bool { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil } diff --git a/server/cmd/main.go b/server/cmd/main.go index 7b2310a..823d8d4 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -6,7 +6,6 @@ import ( "log" "net/http" "os" - "slices" "git.huntm.net/wedding/server/admin" "git.huntm.net/wedding/server/guest" @@ -38,40 +37,5 @@ func main() { mux := http.NewServeMux() mux.Handle("/api/guests/", guestHandler) mux.Handle("/api/admin/", adminHandler) - log.Fatal(http.ListenAndServe(":8080", serveHTTP(middleware.LoggingMiddleware(mux)))) -} - -func serveHTTP(handler http.Handler) http.Handler { - return http.HandlerFunc(func(responseWriter http.ResponseWriter, - request *http.Request) { - writeMethods(responseWriter, request) - writeOrigins(responseWriter, request) - writeHeaders(responseWriter) - handler.ServeHTTP(responseWriter, request) - }) -} - -func writeMethods(responseWriter http.ResponseWriter, request *http.Request) { - allowedMethods := []string{"OPTIONS", "POST", "PUT", "GET", "DELETE"} - method := request.Header.Get("Access-Control-Request-Method") - if isPreflight(request) && slices.Contains(allowedMethods, method) { - responseWriter.Header().Add("Access-Control-Allow-Methods", method) - } -} - -func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) { - origin := request.Header.Get("Origin") - if origin == "http://localhost:5173" { - responseWriter.Header().Add("Access-Control-Allow-Origin", origin) - } -} - -func writeHeaders(responseWriter http.ResponseWriter) { - responseWriter.Header().Add("Access-Control-Allow-Headers", "*") -} - -func isPreflight(request *http.Request) bool { - return request.Method == "OPTIONS" && - request.Header.Get("Origin") != "" && - request.Header.Get("Access-Control-Request-Method") != "" + log.Fatal(http.ListenAndServe(":8080", middleware.Log(middleware.CORS(mux)))) } diff --git a/server/errors/models.go b/server/errors/models.go index 2b90d21..1104c4f 100644 --- a/server/errors/models.go +++ b/server/errors/models.go @@ -10,9 +10,9 @@ type AppError struct { } func NewAppError(status int, msg string) *AppError { - msgJson, err := json.Marshal(map[string]string{"message": msg}) + msgJSON, err := json.Marshal(map[string]string{"message": msg}) if err != nil { - msgJson, _ = json.Marshal(map[string]string{"message": err.Error()}) + msgJSON, _ = json.Marshal(map[string]string{"message": err.Error()}) } - return &AppError{status, msgJson} + return &AppError{status, msgJSON} } diff --git a/server/go.mod b/server/go.mod index 6e62f40..3011aa8 100644 --- a/server/go.mod +++ b/server/go.mod @@ -5,6 +5,7 @@ go 1.22.2 require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/jackc/pgx/v5 v5.7.2 + golang.org/x/crypto v0.32.0 ) require ( @@ -12,7 +13,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/stretchr/testify v1.8.4 // indirect - golang.org/x/crypto v0.32.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/text v0.21.0 // indirect ) diff --git a/server/go.sum b/server/go.sum index 3465d95..89cb14c 100644 --- a/server/go.sum +++ b/server/go.sum @@ -5,16 +5,10 @@ github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17w github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= -github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -24,16 +18,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/server/guest/handler.go b/server/guest/handler.go index f596b05..86ffa84 100644 --- a/server/guest/handler.go +++ b/server/guest/handler.go @@ -21,11 +21,11 @@ type GuestHandler struct { } type GuestStore interface { - Find(name Name) (Guest, error) + Find(Name) (Guest, error) Get() ([]Guest, error) - Add(guest Guest) error - Update(guest Guest) error - Delete(id string) error + Add(Guest) error + Update(Guest) error + Delete(string) error } func NewGuestHandler(s GuestStore) *GuestHandler { @@ -87,9 +87,8 @@ func (g *GuestHandler) handlePost(w http.ResponseWriter, r *http.Request) { } } -func (g *GuestHandler) handleDelete(w http.ResponseWriter, - request *http.Request) { - if err := g.deleteGuest(request); err != nil { +func (g *GuestHandler) handleDelete(w http.ResponseWriter, r *http.Request) { + if err := g.deleteGuest(r); err != nil { http.Error(w, string(err.Message), err.Status) } else { w.WriteHeader(http.StatusOK) @@ -105,13 +104,11 @@ func (g *GuestHandler) logIn(r *http.Request) ([]byte, *errors.AppError) { if err != nil { return nil, errors.NewAppError(http.StatusUnauthorized, err.Error()) } - expirationTime := g.setExpirationTime() - claims := g.createClaims(name, expirationTime) key, err := g.readGuestKey() if err != nil { return nil, errors.NewAppError(http.StatusInternalServerError, err.Error()) } - token, err := g.createToken(claims, key) + token, err := g.newToken(NewClaims(name, g.setExpirationTime()), key) if err != nil { return nil, errors.NewAppError(http.StatusInternalServerError, err.Error()) } @@ -133,15 +130,6 @@ func (g *GuestHandler) setExpirationTime() time.Time { return time.Now().Add(15 * time.Minute) } -func (g *GuestHandler) createClaims(name Name, time time.Time) *Claims { - return &Claims{ - Name: name, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time), - }, - } -} - func (g *GuestHandler) readGuestKey() ([]byte, error) { return os.ReadFile(os.Getenv("GUEST_KEY")) } @@ -150,21 +138,12 @@ func (g *GuestHandler) readAdminKey() ([]byte, error) { return os.ReadFile(os.Getenv("ADMIN_KEY")) } -func (g *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(key) +func (g *GuestHandler) newToken(claims *Claims, key []byte) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key) } func (g *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { - loginResponse := g.createLoginResponse(guest, token) - return json.Marshal(loginResponse) -} - -func (g *GuestHandler) createLoginResponse(guest Guest, token string) *Login { - return &Login{ - Guest: guest, - Token: token, - } + return json.Marshal(NewLogin(guest, token)) } func (g *GuestHandler) putGuest(r *http.Request) *errors.AppError { @@ -189,9 +168,7 @@ func (g *GuestHandler) putGuest(r *http.Request) *errors.AppError { } func (g *GuestHandler) validateToken(r *http.Request, key []byte) *errors.AppError { - authorizationHeader := g.getToken(r) - claims := g.newClaims() - token, err := g.parseWithClaims(authorizationHeader, claims, key) + token, err := g.parseWithClaims(g.getToken(r), g.newClaims(), key) if err != nil { if err == jwt.ErrSignatureInvalid { return errors.NewAppError(http.StatusUnauthorized, err.Error()) @@ -212,16 +189,14 @@ func (g *GuestHandler) newClaims() *Claims { return &Claims{} } -func (g *GuestHandler) parseWithClaims(token string, claims *Claims, - key []byte) (*jwt.Token, error) { +func (g *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { return key, nil }) } func (g *GuestHandler) findId(r *http.Request) bool { - matches := guestIdRegex.FindStringSubmatch(r.URL.Path) - return len(matches) < 2 + return len(guestIdRegex.FindStringSubmatch(r.URL.Path)) < 2 } func (g *GuestHandler) decodeGuest(r *http.Request) (Guest, error) { @@ -295,9 +270,7 @@ func (g *GuestHandler) deleteGuest(r *http.Request) *errors.AppError { if g.findId(r) { return errors.NewAppError(http.StatusNotFound, "cannot delete guest that does not exist") } - guestId := getId(r) - err = g.store.Delete(guestId) - if err != nil { + if err := g.store.Delete(getId(r)); err != nil { return errors.NewAppError(http.StatusInternalServerError, err.Error()) } return nil diff --git a/server/guest/models.go b/server/guest/models.go index 280ee97..e484569 100644 --- a/server/guest/models.go +++ b/server/guest/models.go @@ -1,6 +1,10 @@ package guest -import "github.com/golang-jwt/jwt/v5" +import ( + "time" + + "github.com/golang-jwt/jwt/v5" +) type Guest struct { Id string `json:"id"` @@ -27,3 +31,19 @@ type Login struct { Guest Guest `json:"guest"` Token string `json:"token"` } + +func NewClaims(name Name, time time.Time) *Claims { + return &Claims{ + Name: name, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time), + }, + } +} + +func NewLogin(guest Guest, token string) *Login { + return &Login{ + Guest: guest, + Token: token, + } +} diff --git a/server/guest/store.go b/server/guest/store.go index 56a8c2c..e111ea0 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -18,23 +18,21 @@ func NewStore(database *pgxpool.Pool) *Store { } } -func (store Store) Find(name Name) (Guest, error) { - guestRows, err := store.database.Query(context.Background(), - "select * from guest") +func (s *Store) Find(name Name) (Guest, error) { + guestRows, err := s.database.Query(context.Background(), "select * from guest") if err != nil { return Guest{}, err } defer guestRows.Close() - guest, found := createGuest(name, guestRows) + guest, found := s.createGuest(name, guestRows) - partyRows, err := store.database.Query(context.Background(), - "select * from party") + partyRows, err := s.database.Query(context.Background(), "select * from party") if err != nil { return Guest{}, err } defer partyRows.Close() - guest, err = addParty(guest, partyRows) + guest, err = s.addParty(guest, partyRows) if err != nil { return Guest{}, err } @@ -45,7 +43,7 @@ func (store Store) Find(name Name) (Guest, error) { return Guest{}, errors.New("guest not found") } -func createGuest(name Name, guestRows pgx.Rows) (Guest, bool) { +func (s *Store) createGuest(name Name, guestRows pgx.Rows) (Guest, bool) { var guest Guest for guestRows.Next() { err := guestRows.Scan(&guest.Id, &guest.FirstName, &guest.LastName, @@ -61,7 +59,7 @@ func createGuest(name Name, guestRows pgx.Rows) (Guest, bool) { return Guest{}, false } -func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) { +func (s *Store) addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) { guestWithParty := guestWithoutParty for partyRows.Next() { var guestId string @@ -77,34 +75,32 @@ func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) { return guestWithParty, nil } -func (store Store) Get() ([]Guest, error) { - guestRows, err := store.database.Query(context.Background(), - "select * from guest") +func (s *Store) Get() ([]Guest, error) { + guestRows, err := s.database.Query(context.Background(), "select * from guest") if err != nil { return nil, err } defer guestRows.Close() - guestsWithoutParty, err := store.createGuestSlice(guestRows) + guestsWithoutParty, err := s.createGuestSlice(guestRows) if err != nil { return []Guest{}, err } - partyRows, err := store.database.Query(context.Background(), - "select * from party") + partyRows, err := s.database.Query(context.Background(), "select * from party") if err != nil { return []Guest{}, err } defer partyRows.Close() - guestsWithParty, err := addPartySlice(guestsWithoutParty, partyRows) + guestsWithParty, err := s.addPartySlice(guestsWithoutParty, partyRows) if err != nil { return []Guest{}, err } return guestsWithParty, nil } -func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) { +func (s *Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) { guests := []Guest{} for guestRows.Next() { var guest Guest @@ -118,8 +114,7 @@ func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) { return guests, nil } -func addPartySlice(guestsWithoutParty []Guest, - partyRows pgx.Rows) ([]Guest, error) { +func (s *Store) addPartySlice(guestsWithoutParty []Guest, partyRows pgx.Rows) ([]Guest, error) { guestsWithParty := guestsWithoutParty for partyRows.Next() { var guestId string @@ -137,51 +132,51 @@ func addPartySlice(guestsWithoutParty []Guest, return guestsWithParty, nil } -func (store Store) Add(guest Guest) error { - if err := store.insertGuest(guest); err != nil { +func (s *Store) Add(guest Guest) error { + if err := s.insertGuest(guest); err != nil { return err } - return store.insertParty(guest) + return s.insertParty(guest) } -func (store Store) insertGuest(guest Guest) error { +func (s *Store) insertGuest(guest Guest) error { statement := `insert into guest (id, first_name, last_name, attendance, email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)` - _, err := store.database.Exec(context.Background(), statement, guest.Id, + _, err := s.database.Exec(context.Background(), statement, guest.Id, guest.FirstName, guest.LastName, guest.Attendance, guest.Email, guest.Message, guest.PartySize) return err } -func (store Store) Update(guest Guest) error { - if err := store.updateGuest(guest); err != nil { +func (s *Store) Update(guest Guest) error { + if err := s.updateGuest(guest); err != nil { return err } - if err := store.deleteParty(guest.Id); err != nil { + if err := s.deleteParty(guest.Id); err != nil { return err } - return store.insertParty(guest) + return s.insertParty(guest) } -func (store Store) updateGuest(guest Guest) error { +func (s *Store) updateGuest(guest Guest) error { statement := `update guest set attendance = $1, email = $2, message = $3, party_size = $4 where id = $5` - _, err := store.database.Exec(context.Background(), statement, + _, err := s.database.Exec(context.Background(), statement, guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.Id) return err } -func (store Store) deleteParty(guestId string) error { +func (s *Store) deleteParty(guestId string) error { statement := "delete from party where guest_id = $1" - _, err := store.database.Exec(context.Background(), statement, guestId) + _, err := s.database.Exec(context.Background(), statement, guestId) return err } -func (store Store) insertParty(guest Guest) error { +func (s *Store) insertParty(guest Guest) error { statement := `insert into party (guest_id, first_name, last_name) values ($1, $2, $3)` for _, pg := range guest.PartyList { - _, err := store.database.Exec(context.Background(), statement, guest.Id, + _, err := s.database.Exec(context.Background(), statement, guest.Id, pg.FirstName, pg.LastName) if err != nil { return err @@ -190,18 +185,18 @@ func (store Store) insertParty(guest Guest) error { return nil } -func (store Store) Delete(guestId string) error { - if err := store.deleteGuest(guestId); err != nil { +func (s *Store) Delete(guestId string) error { + if err := s.deleteGuest(guestId); err != nil { return err } - if err := store.deleteParty(guestId); err != nil { + if err := s.deleteParty(guestId); err != nil { return err } return nil } -func (store Store) deleteGuest(guestId string) error { +func (s *Store) deleteGuest(guestId string) error { statement := "delete from guest where id = $1" - _, err := store.database.Exec(context.Background(), statement, guestId) + _, err := s.database.Exec(context.Background(), statement, guestId) return err } diff --git a/server/middleware/cors.go b/server/middleware/cors.go new file mode 100644 index 0000000..641113a --- /dev/null +++ b/server/middleware/cors.go @@ -0,0 +1,36 @@ +package middleware + +import "net/http" + +func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeMethods(w, r) + writeOrigins(w, r) + writeHeaders(w) + next.ServeHTTP(w, r) + }) +} + +func writeMethods(responseWriter http.ResponseWriter, request *http.Request) { + method := request.Header.Get("Access-Control-Request-Method") + if isPreflight(request) { + responseWriter.Header().Add("Access-Control-Allow-Methods", method) + } +} + +func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) { + origin := request.Header.Get("Origin") + if origin == "http://localhost:5173" { + responseWriter.Header().Add("Access-Control-Allow-Origin", origin) + } +} + +func writeHeaders(responseWriter http.ResponseWriter) { + responseWriter.Header().Add("Access-Control-Allow-Headers", "*") +} + +func isPreflight(request *http.Request) bool { + return request.Method == "OPTIONS" && + request.Header.Get("Origin") != "" && + request.Header.Get("Access-Control-Request-Method") != "" +} diff --git a/server/middleware/logging.go b/server/middleware/log.go index d91a5c8..a872daa 100644 --- a/server/middleware/logging.go +++ b/server/middleware/log.go @@ -25,7 +25,13 @@ func (w *LoggingResponseWriter) Write(b []byte) (int, error) { return w.ResponseWriter.Write(b) } -func LoggingMiddleware(next http.Handler) http.Handler { +func (w *LoggingResponseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func Log(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() |