summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/Containerfile1
-rw-r--r--server/admin/handler.go41
-rw-r--r--server/admin/models.go15
-rw-r--r--server/admin/store.go15
-rw-r--r--server/cmd/main.go38
-rw-r--r--server/errors/models.go6
-rw-r--r--server/go.mod2
-rw-r--r--server/go.sum12
-rw-r--r--server/guest/handler.go55
-rw-r--r--server/guest/models.go22
-rw-r--r--server/guest/store.go73
-rw-r--r--server/middleware/cors.go36
-rw-r--r--server/middleware/log.go (renamed from server/middleware/logging.go)8
13 files changed, 150 insertions, 174 deletions
diff --git a/server/Containerfile b/server/Containerfile
index dbd8102..39b4ae8 100644
--- a/server/Containerfile
+++ b/server/Containerfile
@@ -6,7 +6,6 @@ RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o server cmd/main.go
# Stage 2
FROM alpine:latest
COPY --from=builder /app/server /app/server
-# COPY server /app/server
WORKDIR /app
EXPOSE 8080
CMD ["./server"] \ No newline at end of file
diff --git a/server/admin/handler.go b/server/admin/handler.go
index b8f1d7f..2ae0b0d 100644
--- a/server/admin/handler.go
+++ b/server/admin/handler.go
@@ -12,15 +12,15 @@ import (
)
type AdminHandler struct {
- adminStore adminStore
+ adminStore AdminStore
guestStore guest.GuestStore
}
-type adminStore interface {
- Find(admin Admin) (Admin, error)
+type AdminStore interface {
+ Find(Admin) (Admin, error)
}
-func NewAdminHandler(a adminStore, g guest.GuestStore) *AdminHandler {
+func NewAdminHandler(a AdminStore, g guest.GuestStore) *AdminHandler {
return &AdminHandler{a, g}
}
@@ -45,21 +45,19 @@ func (a *AdminHandler) handleLogIn(w http.ResponseWriter, r *http.Request) {
}
func (a *AdminHandler) logIn(r *http.Request) ([]byte, *errors.AppError) {
- requestAdmin, err := a.decodeCredentials(r)
+ admin, err := a.decodeCredentials(r)
if err != nil {
return nil, errors.NewAppError(http.StatusBadRequest, err.Error())
}
- _, err = a.adminStore.Find(requestAdmin)
+ _, err = a.adminStore.Find(admin)
if err != nil {
return nil, errors.NewAppError(http.StatusUnauthorized, err.Error())
}
- expirationTime := a.setExpirationTime()
- claims := a.createClaims(requestAdmin, expirationTime)
key, err := a.readKey()
if err != nil {
return nil, errors.NewAppError(http.StatusInternalServerError, err.Error())
}
- token, err := a.createToken(claims, key)
+ token, err := a.newToken(NewClaims(admin, a.setExpirationTime()), key)
if err != nil {
return nil, errors.NewAppError(http.StatusInternalServerError, err.Error())
}
@@ -85,31 +83,14 @@ func (a *AdminHandler) setExpirationTime() time.Time {
return time.Now().Add(15 * time.Minute)
}
-func (a *AdminHandler) createClaims(admin Admin, expirationTime time.Time) *Claims {
- return &Claims{
- admin,
- jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(expirationTime),
- },
- }
-}
-
func (a *AdminHandler) readKey() ([]byte, error) {
return os.ReadFile(os.Getenv("ADMIN_KEY"))
}
-func (a *AdminHandler) createToken(claims *Claims, key []byte) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(key)
-}
-
-func (a *AdminHandler) marshalResponse(guests []guest.Guest,
- token string) ([]byte, error) {
- loginResponse := a.createLoginResponse(guests, token)
- return json.Marshal(loginResponse)
+func (a *AdminHandler) newToken(claims *Claims, key []byte) (string, error) {
+ return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
}
-func (a *AdminHandler) createLoginResponse(guests []guest.Guest,
- token string) *Login {
- return &Login{guests, token}
+func (a *AdminHandler) marshalResponse(guests []guest.Guest, token string) ([]byte, error) {
+ return json.Marshal(NewLogin(guests, token))
}
diff --git a/server/admin/models.go b/server/admin/models.go
index 275f617..524099b 100644
--- a/server/admin/models.go
+++ b/server/admin/models.go
@@ -1,6 +1,8 @@
package admin
import (
+ "time"
+
"git.huntm.net/wedding/server/guest"
"github.com/golang-jwt/jwt/v5"
)
@@ -20,3 +22,16 @@ type Login struct {
Guests []guest.Guest `json:"guests"`
Token string `json:"token"`
}
+
+func NewClaims(admin Admin, expirationTime time.Time) *Claims {
+ return &Claims{
+ admin,
+ jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expirationTime),
+ },
+ }
+}
+
+func NewLogin(guests []guest.Guest, token string) *Login {
+ return &Login{guests, token}
+}
diff --git a/server/admin/store.go b/server/admin/store.go
index 3322b35..9a7f639 100644
--- a/server/admin/store.go
+++ b/server/admin/store.go
@@ -19,14 +19,13 @@ func NewStore(database *pgxpool.Pool) *Store {
}
}
-func (store Store) Find(requestAdmin Admin) (Admin, error) {
- adminRows, err := store.database.Query(context.Background(),
- "select * from admin")
+func (s *Store) Find(requestAdmin Admin) (Admin, error) {
+ adminRows, err := s.database.Query(context.Background(), "select * from admin")
if err != nil {
return Admin{}, err
}
defer adminRows.Close()
- admin, found := createAdmin(requestAdmin, adminRows)
+ admin, found := s.newAdmin(requestAdmin, adminRows)
if found {
return admin, nil
@@ -34,7 +33,7 @@ func (store Store) Find(requestAdmin Admin) (Admin, error) {
return Admin{}, errors.New("invalid username or password")
}
-func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) {
+func (s *Store) newAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) {
var databaseAdmin Admin
for adminRows.Next() {
err := adminRows.Scan(&databaseAdmin.Id, &databaseAdmin.Username, &databaseAdmin.Password)
@@ -42,13 +41,13 @@ func createAdmin(requestAdmin Admin, adminRows pgx.Rows) (Admin, bool) {
return Admin{}, false
}
if databaseAdmin.Username == requestAdmin.Username &&
- verifyPassword(databaseAdmin.Password, requestAdmin.Password) {
+ s.verifyPassword(databaseAdmin.Password, requestAdmin.Password) {
return databaseAdmin, true
}
}
return Admin{}, false
}
-func verifyPassword(hashedPassword string, password string) bool {
- return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) == nil
+func (s *Store) verifyPassword(hash string, password string) bool {
+ return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}
diff --git a/server/cmd/main.go b/server/cmd/main.go
index 7b2310a..823d8d4 100644
--- a/server/cmd/main.go
+++ b/server/cmd/main.go
@@ -6,7 +6,6 @@ import (
"log"
"net/http"
"os"
- "slices"
"git.huntm.net/wedding/server/admin"
"git.huntm.net/wedding/server/guest"
@@ -38,40 +37,5 @@ func main() {
mux := http.NewServeMux()
mux.Handle("/api/guests/", guestHandler)
mux.Handle("/api/admin/", adminHandler)
- log.Fatal(http.ListenAndServe(":8080", serveHTTP(middleware.LoggingMiddleware(mux))))
-}
-
-func serveHTTP(handler http.Handler) http.Handler {
- return http.HandlerFunc(func(responseWriter http.ResponseWriter,
- request *http.Request) {
- writeMethods(responseWriter, request)
- writeOrigins(responseWriter, request)
- writeHeaders(responseWriter)
- handler.ServeHTTP(responseWriter, request)
- })
-}
-
-func writeMethods(responseWriter http.ResponseWriter, request *http.Request) {
- allowedMethods := []string{"OPTIONS", "POST", "PUT", "GET", "DELETE"}
- method := request.Header.Get("Access-Control-Request-Method")
- if isPreflight(request) && slices.Contains(allowedMethods, method) {
- responseWriter.Header().Add("Access-Control-Allow-Methods", method)
- }
-}
-
-func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) {
- origin := request.Header.Get("Origin")
- if origin == "http://localhost:5173" {
- responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
- }
-}
-
-func writeHeaders(responseWriter http.ResponseWriter) {
- responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
-}
-
-func isPreflight(request *http.Request) bool {
- return request.Method == "OPTIONS" &&
- request.Header.Get("Origin") != "" &&
- request.Header.Get("Access-Control-Request-Method") != ""
+ log.Fatal(http.ListenAndServe(":8080", middleware.Log(middleware.CORS(mux))))
}
diff --git a/server/errors/models.go b/server/errors/models.go
index 2b90d21..1104c4f 100644
--- a/server/errors/models.go
+++ b/server/errors/models.go
@@ -10,9 +10,9 @@ type AppError struct {
}
func NewAppError(status int, msg string) *AppError {
- msgJson, err := json.Marshal(map[string]string{"message": msg})
+ msgJSON, err := json.Marshal(map[string]string{"message": msg})
if err != nil {
- msgJson, _ = json.Marshal(map[string]string{"message": err.Error()})
+ msgJSON, _ = json.Marshal(map[string]string{"message": err.Error()})
}
- return &AppError{status, msgJson}
+ return &AppError{status, msgJSON}
}
diff --git a/server/go.mod b/server/go.mod
index 6e62f40..3011aa8 100644
--- a/server/go.mod
+++ b/server/go.mod
@@ -5,6 +5,7 @@ go 1.22.2
require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/jackc/pgx/v5 v5.7.2
+ golang.org/x/crypto v0.32.0
)
require (
@@ -12,7 +13,6 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/stretchr/testify v1.8.4 // indirect
- golang.org/x/crypto v0.32.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/text v0.21.0 // indirect
)
diff --git a/server/go.sum b/server/go.sum
index 3465d95..89cb14c 100644
--- a/server/go.sum
+++ b/server/go.sum
@@ -5,16 +5,10 @@ github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17w
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
-github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
-github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
-github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
-github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
-github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -24,16 +18,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
-golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
-golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/server/guest/handler.go b/server/guest/handler.go
index f596b05..86ffa84 100644
--- a/server/guest/handler.go
+++ b/server/guest/handler.go
@@ -21,11 +21,11 @@ type GuestHandler struct {
}
type GuestStore interface {
- Find(name Name) (Guest, error)
+ Find(Name) (Guest, error)
Get() ([]Guest, error)
- Add(guest Guest) error
- Update(guest Guest) error
- Delete(id string) error
+ Add(Guest) error
+ Update(Guest) error
+ Delete(string) error
}
func NewGuestHandler(s GuestStore) *GuestHandler {
@@ -87,9 +87,8 @@ func (g *GuestHandler) handlePost(w http.ResponseWriter, r *http.Request) {
}
}
-func (g *GuestHandler) handleDelete(w http.ResponseWriter,
- request *http.Request) {
- if err := g.deleteGuest(request); err != nil {
+func (g *GuestHandler) handleDelete(w http.ResponseWriter, r *http.Request) {
+ if err := g.deleteGuest(r); err != nil {
http.Error(w, string(err.Message), err.Status)
} else {
w.WriteHeader(http.StatusOK)
@@ -105,13 +104,11 @@ func (g *GuestHandler) logIn(r *http.Request) ([]byte, *errors.AppError) {
if err != nil {
return nil, errors.NewAppError(http.StatusUnauthorized, err.Error())
}
- expirationTime := g.setExpirationTime()
- claims := g.createClaims(name, expirationTime)
key, err := g.readGuestKey()
if err != nil {
return nil, errors.NewAppError(http.StatusInternalServerError, err.Error())
}
- token, err := g.createToken(claims, key)
+ token, err := g.newToken(NewClaims(name, g.setExpirationTime()), key)
if err != nil {
return nil, errors.NewAppError(http.StatusInternalServerError, err.Error())
}
@@ -133,15 +130,6 @@ func (g *GuestHandler) setExpirationTime() time.Time {
return time.Now().Add(15 * time.Minute)
}
-func (g *GuestHandler) createClaims(name Name, time time.Time) *Claims {
- return &Claims{
- Name: name,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(time),
- },
- }
-}
-
func (g *GuestHandler) readGuestKey() ([]byte, error) {
return os.ReadFile(os.Getenv("GUEST_KEY"))
}
@@ -150,21 +138,12 @@ func (g *GuestHandler) readAdminKey() ([]byte, error) {
return os.ReadFile(os.Getenv("ADMIN_KEY"))
}
-func (g *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(key)
+func (g *GuestHandler) newToken(claims *Claims, key []byte) (string, error) {
+ return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
}
func (g *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) {
- loginResponse := g.createLoginResponse(guest, token)
- return json.Marshal(loginResponse)
-}
-
-func (g *GuestHandler) createLoginResponse(guest Guest, token string) *Login {
- return &Login{
- Guest: guest,
- Token: token,
- }
+ return json.Marshal(NewLogin(guest, token))
}
func (g *GuestHandler) putGuest(r *http.Request) *errors.AppError {
@@ -189,9 +168,7 @@ func (g *GuestHandler) putGuest(r *http.Request) *errors.AppError {
}
func (g *GuestHandler) validateToken(r *http.Request, key []byte) *errors.AppError {
- authorizationHeader := g.getToken(r)
- claims := g.newClaims()
- token, err := g.parseWithClaims(authorizationHeader, claims, key)
+ token, err := g.parseWithClaims(g.getToken(r), g.newClaims(), key)
if err != nil {
if err == jwt.ErrSignatureInvalid {
return errors.NewAppError(http.StatusUnauthorized, err.Error())
@@ -212,16 +189,14 @@ func (g *GuestHandler) newClaims() *Claims {
return &Claims{}
}
-func (g *GuestHandler) parseWithClaims(token string, claims *Claims,
- key []byte) (*jwt.Token, error) {
+func (g *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) {
return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
return key, nil
})
}
func (g *GuestHandler) findId(r *http.Request) bool {
- matches := guestIdRegex.FindStringSubmatch(r.URL.Path)
- return len(matches) < 2
+ return len(guestIdRegex.FindStringSubmatch(r.URL.Path)) < 2
}
func (g *GuestHandler) decodeGuest(r *http.Request) (Guest, error) {
@@ -295,9 +270,7 @@ func (g *GuestHandler) deleteGuest(r *http.Request) *errors.AppError {
if g.findId(r) {
return errors.NewAppError(http.StatusNotFound, "cannot delete guest that does not exist")
}
- guestId := getId(r)
- err = g.store.Delete(guestId)
- if err != nil {
+ if err := g.store.Delete(getId(r)); err != nil {
return errors.NewAppError(http.StatusInternalServerError, err.Error())
}
return nil
diff --git a/server/guest/models.go b/server/guest/models.go
index 280ee97..e484569 100644
--- a/server/guest/models.go
+++ b/server/guest/models.go
@@ -1,6 +1,10 @@
package guest
-import "github.com/golang-jwt/jwt/v5"
+import (
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
type Guest struct {
Id string `json:"id"`
@@ -27,3 +31,19 @@ type Login struct {
Guest Guest `json:"guest"`
Token string `json:"token"`
}
+
+func NewClaims(name Name, time time.Time) *Claims {
+ return &Claims{
+ Name: name,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(time),
+ },
+ }
+}
+
+func NewLogin(guest Guest, token string) *Login {
+ return &Login{
+ Guest: guest,
+ Token: token,
+ }
+}
diff --git a/server/guest/store.go b/server/guest/store.go
index 56a8c2c..e111ea0 100644
--- a/server/guest/store.go
+++ b/server/guest/store.go
@@ -18,23 +18,21 @@ func NewStore(database *pgxpool.Pool) *Store {
}
}
-func (store Store) Find(name Name) (Guest, error) {
- guestRows, err := store.database.Query(context.Background(),
- "select * from guest")
+func (s *Store) Find(name Name) (Guest, error) {
+ guestRows, err := s.database.Query(context.Background(), "select * from guest")
if err != nil {
return Guest{}, err
}
defer guestRows.Close()
- guest, found := createGuest(name, guestRows)
+ guest, found := s.createGuest(name, guestRows)
- partyRows, err := store.database.Query(context.Background(),
- "select * from party")
+ partyRows, err := s.database.Query(context.Background(), "select * from party")
if err != nil {
return Guest{}, err
}
defer partyRows.Close()
- guest, err = addParty(guest, partyRows)
+ guest, err = s.addParty(guest, partyRows)
if err != nil {
return Guest{}, err
}
@@ -45,7 +43,7 @@ func (store Store) Find(name Name) (Guest, error) {
return Guest{}, errors.New("guest not found")
}
-func createGuest(name Name, guestRows pgx.Rows) (Guest, bool) {
+func (s *Store) createGuest(name Name, guestRows pgx.Rows) (Guest, bool) {
var guest Guest
for guestRows.Next() {
err := guestRows.Scan(&guest.Id, &guest.FirstName, &guest.LastName,
@@ -61,7 +59,7 @@ func createGuest(name Name, guestRows pgx.Rows) (Guest, bool) {
return Guest{}, false
}
-func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) {
+func (s *Store) addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) {
guestWithParty := guestWithoutParty
for partyRows.Next() {
var guestId string
@@ -77,34 +75,32 @@ func addParty(guestWithoutParty Guest, partyRows pgx.Rows) (Guest, error) {
return guestWithParty, nil
}
-func (store Store) Get() ([]Guest, error) {
- guestRows, err := store.database.Query(context.Background(),
- "select * from guest")
+func (s *Store) Get() ([]Guest, error) {
+ guestRows, err := s.database.Query(context.Background(), "select * from guest")
if err != nil {
return nil, err
}
defer guestRows.Close()
- guestsWithoutParty, err := store.createGuestSlice(guestRows)
+ guestsWithoutParty, err := s.createGuestSlice(guestRows)
if err != nil {
return []Guest{}, err
}
- partyRows, err := store.database.Query(context.Background(),
- "select * from party")
+ partyRows, err := s.database.Query(context.Background(), "select * from party")
if err != nil {
return []Guest{}, err
}
defer partyRows.Close()
- guestsWithParty, err := addPartySlice(guestsWithoutParty, partyRows)
+ guestsWithParty, err := s.addPartySlice(guestsWithoutParty, partyRows)
if err != nil {
return []Guest{}, err
}
return guestsWithParty, nil
}
-func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) {
+func (s *Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) {
guests := []Guest{}
for guestRows.Next() {
var guest Guest
@@ -118,8 +114,7 @@ func (store Store) createGuestSlice(guestRows pgx.Rows) ([]Guest, error) {
return guests, nil
}
-func addPartySlice(guestsWithoutParty []Guest,
- partyRows pgx.Rows) ([]Guest, error) {
+func (s *Store) addPartySlice(guestsWithoutParty []Guest, partyRows pgx.Rows) ([]Guest, error) {
guestsWithParty := guestsWithoutParty
for partyRows.Next() {
var guestId string
@@ -137,51 +132,51 @@ func addPartySlice(guestsWithoutParty []Guest,
return guestsWithParty, nil
}
-func (store Store) Add(guest Guest) error {
- if err := store.insertGuest(guest); err != nil {
+func (s *Store) Add(guest Guest) error {
+ if err := s.insertGuest(guest); err != nil {
return err
}
- return store.insertParty(guest)
+ return s.insertParty(guest)
}
-func (store Store) insertGuest(guest Guest) error {
+func (s *Store) insertGuest(guest Guest) error {
statement := `insert into guest (id, first_name, last_name, attendance,
email, message, party_size) values ($1, $2, $3, $4, $5, $6, $7)`
- _, err := store.database.Exec(context.Background(), statement, guest.Id,
+ _, err := s.database.Exec(context.Background(), statement, guest.Id,
guest.FirstName, guest.LastName, guest.Attendance, guest.Email,
guest.Message, guest.PartySize)
return err
}
-func (store Store) Update(guest Guest) error {
- if err := store.updateGuest(guest); err != nil {
+func (s *Store) Update(guest Guest) error {
+ if err := s.updateGuest(guest); err != nil {
return err
}
- if err := store.deleteParty(guest.Id); err != nil {
+ if err := s.deleteParty(guest.Id); err != nil {
return err
}
- return store.insertParty(guest)
+ return s.insertParty(guest)
}
-func (store Store) updateGuest(guest Guest) error {
+func (s *Store) updateGuest(guest Guest) error {
statement := `update guest set attendance = $1, email = $2, message = $3,
party_size = $4 where id = $5`
- _, err := store.database.Exec(context.Background(), statement,
+ _, err := s.database.Exec(context.Background(), statement,
guest.Attendance, guest.Email, guest.Message, guest.PartySize, guest.Id)
return err
}
-func (store Store) deleteParty(guestId string) error {
+func (s *Store) deleteParty(guestId string) error {
statement := "delete from party where guest_id = $1"
- _, err := store.database.Exec(context.Background(), statement, guestId)
+ _, err := s.database.Exec(context.Background(), statement, guestId)
return err
}
-func (store Store) insertParty(guest Guest) error {
+func (s *Store) insertParty(guest Guest) error {
statement := `insert into party (guest_id, first_name, last_name)
values ($1, $2, $3)`
for _, pg := range guest.PartyList {
- _, err := store.database.Exec(context.Background(), statement, guest.Id,
+ _, err := s.database.Exec(context.Background(), statement, guest.Id,
pg.FirstName, pg.LastName)
if err != nil {
return err
@@ -190,18 +185,18 @@ func (store Store) insertParty(guest Guest) error {
return nil
}
-func (store Store) Delete(guestId string) error {
- if err := store.deleteGuest(guestId); err != nil {
+func (s *Store) Delete(guestId string) error {
+ if err := s.deleteGuest(guestId); err != nil {
return err
}
- if err := store.deleteParty(guestId); err != nil {
+ if err := s.deleteParty(guestId); err != nil {
return err
}
return nil
}
-func (store Store) deleteGuest(guestId string) error {
+func (s *Store) deleteGuest(guestId string) error {
statement := "delete from guest where id = $1"
- _, err := store.database.Exec(context.Background(), statement, guestId)
+ _, err := s.database.Exec(context.Background(), statement, guestId)
return err
}
diff --git a/server/middleware/cors.go b/server/middleware/cors.go
new file mode 100644
index 0000000..641113a
--- /dev/null
+++ b/server/middleware/cors.go
@@ -0,0 +1,36 @@
+package middleware
+
+import "net/http"
+
+func CORS(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ writeMethods(w, r)
+ writeOrigins(w, r)
+ writeHeaders(w)
+ next.ServeHTTP(w, r)
+ })
+}
+
+func writeMethods(responseWriter http.ResponseWriter, request *http.Request) {
+ method := request.Header.Get("Access-Control-Request-Method")
+ if isPreflight(request) {
+ responseWriter.Header().Add("Access-Control-Allow-Methods", method)
+ }
+}
+
+func writeOrigins(responseWriter http.ResponseWriter, request *http.Request) {
+ origin := request.Header.Get("Origin")
+ if origin == "http://localhost:5173" {
+ responseWriter.Header().Add("Access-Control-Allow-Origin", origin)
+ }
+}
+
+func writeHeaders(responseWriter http.ResponseWriter) {
+ responseWriter.Header().Add("Access-Control-Allow-Headers", "*")
+}
+
+func isPreflight(request *http.Request) bool {
+ return request.Method == "OPTIONS" &&
+ request.Header.Get("Origin") != "" &&
+ request.Header.Get("Access-Control-Request-Method") != ""
+}
diff --git a/server/middleware/logging.go b/server/middleware/log.go
index d91a5c8..a872daa 100644
--- a/server/middleware/logging.go
+++ b/server/middleware/log.go
@@ -25,7 +25,13 @@ func (w *LoggingResponseWriter) Write(b []byte) (int, error) {
return w.ResponseWriter.Write(b)
}
-func LoggingMiddleware(next http.Handler) http.Handler {
+func (w *LoggingResponseWriter) Flush() {
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+func Log(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()