summaryrefslogtreecommitdiff
path: root/server/guest
diff options
context:
space:
mode:
Diffstat (limited to 'server/guest')
-rw-r--r--server/guest/handler.go268
-rw-r--r--server/guest/store.go2
2 files changed, 269 insertions, 1 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go
new file mode 100644
index 0000000..46b8a45
--- /dev/null
+++ b/server/guest/handler.go
@@ -0,0 +1,268 @@
+package guest
+
+import (
+ "encoding/json"
+ "errors"
+ "net/http"
+ "os"
+ "regexp"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+var (
+ guestRe = regexp.MustCompile(`^/guest/*$`)
+ guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`)
+)
+
+type GuestHandler struct {
+ store guestStore
+}
+
+type guestStore interface {
+ Find(credentials Credentials) (Guest, error)
+ Get() ([]Guest, error)
+ Add(guest Guest) error
+ Update(guest Guest) error
+}
+
+type appError struct {
+ Error error
+ Message string
+ Code int
+}
+
+func NewGuestHandler(s guestStore) *GuestHandler {
+ return &GuestHandler{
+ store: s,
+ }
+}
+
+func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
+ switch {
+ case request.Method == http.MethodOptions:
+ responseWriter.WriteHeader(http.StatusOK)
+ case request.Method == http.MethodPost && request.URL.Path == "/guest/login":
+ guestHandler.handleLogIn(responseWriter, request)
+ case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path):
+ guestHandler.handlePut(responseWriter, request)
+ case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path):
+ guestHandler.handleGet(responseWriter, request)
+ case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path):
+ guestHandler.handlePost(responseWriter, request)
+ default:
+ responseWriter.WriteHeader(http.StatusNotFound)
+ }
+}
+
+func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) {
+ token, err := guestHandler.logInGuest(request)
+ if err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.Write(token)
+ }
+}
+
+func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := guestHandler.putGuest(request); err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.WriteHeader(http.StatusOK)
+ }
+}
+
+func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) {
+ guests, err := guestHandler.getGuests(request)
+ if err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.Write(guests)
+ }
+}
+
+func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) {
+ if err := guestHandler.postGuest(request); err != nil {
+ http.Error(responseWriter, err.Message, err.Code)
+ } else {
+ responseWriter.WriteHeader(http.StatusOK)
+ }
+}
+
+func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) {
+ credentials, err := guestHandler.decodeCredentials(request)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest}
+ }
+ guest, err := guestHandler.store.Find(credentials)
+ if err != nil {
+ return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized}
+ }
+ expirationTime := guestHandler.setExpirationTime()
+ claims := guestHandler.createClaims(credentials, expirationTime)
+ key, err := guestHandler.readKey()
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+ }
+ token, err := guestHandler.createToken(claims, key)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError}
+ }
+ jsonBytes, err := guestHandler.marshalResponse(guest, token)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError}
+ }
+ return jsonBytes, nil
+}
+
+func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) {
+ var credentials Credentials
+ err := json.NewDecoder(request.Body).Decode(&credentials)
+ defer request.Body.Close()
+ return credentials, err
+}
+
+func (guestHandler *GuestHandler) setExpirationTime() time.Time {
+ return time.Now().Add(15 * time.Minute)
+}
+
+func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims {
+ return &Claims{
+ Credentials: credentials,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expirationTime),
+ },
+ }
+}
+
+func (guestHandler *GuestHandler) readKey() ([]byte, error) {
+ // TODO: use properties file
+ return os.ReadFile("C:\\Users\\mhunt\\skey.pem")
+}
+
+func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString(key)
+}
+
+func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) {
+ loginResponse := guestHandler.createLoginResponse(guest, token)
+ return json.Marshal(loginResponse)
+}
+
+func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse {
+ return &LoginResponse{
+ Guest: weddingGuest,
+ Token: token,
+ }
+}
+
+func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError {
+ if err := guestHandler.validateToken(request); err != nil {
+ return err
+ }
+ if guestHandler.findId(request) {
+ return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound}
+ }
+ guest, err := guestHandler.decodeGuest(request)
+ if err != nil {
+ return &appError{err, "Invalid guest", http.StatusBadRequest}
+ }
+ if err := guestHandler.store.Update(guest); err != nil {
+ return &appError{err, "Failed to update guest", http.StatusInternalServerError}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError {
+ authorizationHeader := guestHandler.getToken(request)
+ claims := guestHandler.initializeClaims()
+ key, err := guestHandler.readKey()
+ if err != nil {
+ return &appError{err, "Failed to read secret key", http.StatusInternalServerError}
+ }
+ token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key)
+ if err != nil {
+ if err == jwt.ErrSignatureInvalid {
+ return &appError{err, "Invalid signature", http.StatusUnauthorized}
+ }
+ return &appError{err, "Failed to parse claims", http.StatusBadRequest}
+ }
+ if !token.Valid {
+ return &appError{err, "Invalid token", http.StatusUnauthorized}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) getToken(request *http.Request) string {
+ return request.Header.Get("Authorization")
+}
+
+func (guestHandler *GuestHandler) initializeClaims() *Claims {
+ return &Claims{}
+}
+
+func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) {
+ return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) {
+ return key, nil
+ })
+}
+
+func (guestHandler *GuestHandler) findId(request *http.Request) bool {
+ matches := guestIdRe.FindStringSubmatch(request.URL.Path)
+ return len(matches) < 2
+}
+
+func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) {
+ var guest Guest
+ err := json.NewDecoder(request.Body).Decode(&guest)
+ defer request.Body.Close()
+ return guest, err
+}
+
+func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) {
+ // TODO: check with admin token
+ if err := guestHandler.validateToken(request); err != nil {
+ return []byte{}, err
+ }
+ guests, err := guestHandler.store.Get()
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ }
+ jsonBytes, err := json.Marshal(guests)
+ if err != nil {
+ return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError}
+ }
+ return jsonBytes, nil
+}
+
+func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError {
+ if err := guestHandler.validateToken(request); err != nil {
+ return err
+ }
+ guest, err := guestHandler.decodeGuest(request)
+ if err != nil {
+ return &appError{err, "Invalid guest", http.StatusBadRequest}
+ }
+ guests, err := guestHandler.store.Get()
+ if err != nil {
+ return &appError{err, "Failed to get guests", http.StatusInternalServerError}
+ }
+ if err := guestHandler.checkExistingGuests(guests, guest); err != nil {
+ return &appError{err, "ID already exists", http.StatusConflict}
+ }
+ if err := guestHandler.store.Add(guest); err != nil {
+ return &appError{err, "Failed to add guest", http.StatusInternalServerError}
+ }
+ return nil
+}
+
+func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error {
+ for _, guest := range guests {
+ if guest.Id == newGuest.Id {
+ return errors.New("ID already exists")
+ }
+ }
+ return nil
+}
diff --git a/server/guest/store.go b/server/guest/store.go
index db9fc3d..1a07161 100644
--- a/server/guest/store.go
+++ b/server/guest/store.go
@@ -17,7 +17,7 @@ func NewMemStore(db *pgxpool.Pool) *MemStore {
}
}
-func (m MemStore) FindGuest(creds Credentials) (Guest, error) {
+func (m MemStore) Find(creds Credentials) (Guest, error) {
rows, err := m.db.Query(context.Background(), "select * from guest")
var guest Guest
if err != nil {