diff options
author | Michael Hunteman <michael@huntm.net> | 2024-08-18 17:48:37 -0700 |
---|---|---|
committer | Michael Hunteman <michael@huntm.net> | 2024-08-18 17:52:35 -0700 |
commit | 42aa6dc74a736f8c96bd1b1bdf37ad3ff905f08c (patch) | |
tree | 0c5232e33fdef411bcc9dc81396c6a397e04c652 /server/guest | |
parent | f9b3c68e31a90e8091da69b9defcb6641697e9c8 (diff) |
Add integration test
Diffstat (limited to 'server/guest')
-rw-r--r-- | server/guest/handler.go | 268 | ||||
-rw-r--r-- | server/guest/store.go | 2 |
2 files changed, 269 insertions, 1 deletions
diff --git a/server/guest/handler.go b/server/guest/handler.go new file mode 100644 index 0000000..46b8a45 --- /dev/null +++ b/server/guest/handler.go @@ -0,0 +1,268 @@ +package guest + +import ( + "encoding/json" + "errors" + "net/http" + "os" + "regexp" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + guestRe = regexp.MustCompile(`^/guest/*$`) + guestIdRe = regexp.MustCompile(`^/guest/([0-9]+)$`) +) + +type GuestHandler struct { + store guestStore +} + +type guestStore interface { + Find(credentials Credentials) (Guest, error) + Get() ([]Guest, error) + Add(guest Guest) error + Update(guest Guest) error +} + +type appError struct { + Error error + Message string + Code int +} + +func NewGuestHandler(s guestStore) *GuestHandler { + return &GuestHandler{ + store: s, + } +} + +func (guestHandler *GuestHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + switch { + case request.Method == http.MethodOptions: + responseWriter.WriteHeader(http.StatusOK) + case request.Method == http.MethodPost && request.URL.Path == "/guest/login": + guestHandler.handleLogIn(responseWriter, request) + case request.Method == http.MethodPut && guestIdRe.MatchString(request.URL.Path): + guestHandler.handlePut(responseWriter, request) + case request.Method == http.MethodGet && guestRe.MatchString(request.URL.Path): + guestHandler.handleGet(responseWriter, request) + case request.Method == http.MethodPost && guestRe.MatchString(request.URL.Path): + guestHandler.handlePost(responseWriter, request) + default: + responseWriter.WriteHeader(http.StatusNotFound) + } +} + +func (guestHandler *GuestHandler) handleLogIn(responseWriter http.ResponseWriter, request *http.Request) { + token, err := guestHandler.logInGuest(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(token) + } +} + +func (guestHandler *GuestHandler) handlePut(responseWriter http.ResponseWriter, request *http.Request) { + if err := guestHandler.putGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (guestHandler *GuestHandler) handleGet(responseWriter http.ResponseWriter, request *http.Request) { + guests, err := guestHandler.getGuests(request) + if err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.Write(guests) + } +} + +func (guestHandler *GuestHandler) handlePost(responseWriter http.ResponseWriter, request *http.Request) { + if err := guestHandler.postGuest(request); err != nil { + http.Error(responseWriter, err.Message, err.Code) + } else { + responseWriter.WriteHeader(http.StatusOK) + } +} + +func (guestHandler *GuestHandler) logInGuest(request *http.Request) ([]byte, *appError) { + credentials, err := guestHandler.decodeCredentials(request) + if err != nil { + return []byte{}, &appError{err, "Failed to unmarshal credentials", http.StatusBadRequest} + } + guest, err := guestHandler.store.Find(credentials) + if err != nil { + return []byte{}, &appError{err, "Guest not found", http.StatusUnauthorized} + } + expirationTime := guestHandler.setExpirationTime() + claims := guestHandler.createClaims(credentials, expirationTime) + key, err := guestHandler.readKey() + if err != nil { + return []byte{}, &appError{err, "Failed to read secret key", http.StatusInternalServerError} + } + token, err := guestHandler.createToken(claims, key) + if err != nil { + return []byte{}, &appError{err, "Failed to create token", http.StatusInternalServerError} + } + jsonBytes, err := guestHandler.marshalResponse(guest, token) + if err != nil { + return []byte{}, &appError{err, "Failed to marshal response", http.StatusInternalServerError} + } + return jsonBytes, nil +} + +func (guestHandler *GuestHandler) decodeCredentials(request *http.Request) (Credentials, error) { + var credentials Credentials + err := json.NewDecoder(request.Body).Decode(&credentials) + defer request.Body.Close() + return credentials, err +} + +func (guestHandler *GuestHandler) setExpirationTime() time.Time { + return time.Now().Add(15 * time.Minute) +} + +func (guestHandler *GuestHandler) createClaims(credentials Credentials, expirationTime time.Time) *Claims { + return &Claims{ + Credentials: credentials, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } +} + +func (guestHandler *GuestHandler) readKey() ([]byte, error) { + // TODO: use properties file + return os.ReadFile("C:\\Users\\mhunt\\skey.pem") +} + +func (guestHandler *GuestHandler) createToken(claims *Claims, key []byte) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(key) +} + +func (guestHandler *GuestHandler) marshalResponse(guest Guest, token string) ([]byte, error) { + loginResponse := guestHandler.createLoginResponse(guest, token) + return json.Marshal(loginResponse) +} + +func (guestHandler *GuestHandler) createLoginResponse(weddingGuest Guest, token string) *LoginResponse { + return &LoginResponse{ + Guest: weddingGuest, + Token: token, + } +} + +func (guestHandler *GuestHandler) putGuest(request *http.Request) *appError { + if err := guestHandler.validateToken(request); err != nil { + return err + } + if guestHandler.findId(request) { + return &appError{errors.New("ID not found"), "ID not found", http.StatusNotFound} + } + guest, err := guestHandler.decodeGuest(request) + if err != nil { + return &appError{err, "Invalid guest", http.StatusBadRequest} + } + if err := guestHandler.store.Update(guest); err != nil { + return &appError{err, "Failed to update guest", http.StatusInternalServerError} + } + return nil +} + +func (guestHandler *GuestHandler) validateToken(request *http.Request) *appError { + authorizationHeader := guestHandler.getToken(request) + claims := guestHandler.initializeClaims() + key, err := guestHandler.readKey() + if err != nil { + return &appError{err, "Failed to read secret key", http.StatusInternalServerError} + } + token, err := guestHandler.parseWithClaims(authorizationHeader, claims, key) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return &appError{err, "Invalid signature", http.StatusUnauthorized} + } + return &appError{err, "Failed to parse claims", http.StatusBadRequest} + } + if !token.Valid { + return &appError{err, "Invalid token", http.StatusUnauthorized} + } + return nil +} + +func (guestHandler *GuestHandler) getToken(request *http.Request) string { + return request.Header.Get("Authorization") +} + +func (guestHandler *GuestHandler) initializeClaims() *Claims { + return &Claims{} +} + +func (guestHandler *GuestHandler) parseWithClaims(token string, claims *Claims, key []byte) (*jwt.Token, error) { + return jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (any, error) { + return key, nil + }) +} + +func (guestHandler *GuestHandler) findId(request *http.Request) bool { + matches := guestIdRe.FindStringSubmatch(request.URL.Path) + return len(matches) < 2 +} + +func (guestHandler *GuestHandler) decodeGuest(request *http.Request) (Guest, error) { + var guest Guest + err := json.NewDecoder(request.Body).Decode(&guest) + defer request.Body.Close() + return guest, err +} + +func (guestHandler *GuestHandler) getGuests(request *http.Request) ([]byte, *appError) { + // TODO: check with admin token + if err := guestHandler.validateToken(request); err != nil { + return []byte{}, err + } + guests, err := guestHandler.store.Get() + if err != nil { + return []byte{}, &appError{err, "Failed to get guests", http.StatusInternalServerError} + } + jsonBytes, err := json.Marshal(guests) + if err != nil { + return []byte{}, &appError{err, "Failed to marshal guests", http.StatusInternalServerError} + } + return jsonBytes, nil +} + +func (guestHandler *GuestHandler) postGuest(request *http.Request) *appError { + if err := guestHandler.validateToken(request); err != nil { + return err + } + guest, err := guestHandler.decodeGuest(request) + if err != nil { + return &appError{err, "Invalid guest", http.StatusBadRequest} + } + guests, err := guestHandler.store.Get() + if err != nil { + return &appError{err, "Failed to get guests", http.StatusInternalServerError} + } + if err := guestHandler.checkExistingGuests(guests, guest); err != nil { + return &appError{err, "ID already exists", http.StatusConflict} + } + if err := guestHandler.store.Add(guest); err != nil { + return &appError{err, "Failed to add guest", http.StatusInternalServerError} + } + return nil +} + +func (guestHandler *GuestHandler) checkExistingGuests(guests []Guest, newGuest Guest) error { + for _, guest := range guests { + if guest.Id == newGuest.Id { + return errors.New("ID already exists") + } + } + return nil +} diff --git a/server/guest/store.go b/server/guest/store.go index db9fc3d..1a07161 100644 --- a/server/guest/store.go +++ b/server/guest/store.go @@ -17,7 +17,7 @@ func NewMemStore(db *pgxpool.Pool) *MemStore { } } -func (m MemStore) FindGuest(creds Credentials) (Guest, error) { +func (m MemStore) Find(creds Credentials) (Guest, error) { rows, err := m.db.Query(context.Background(), "select * from guest") var guest Guest if err != nil { |