From 6a86150771858eebedaff32a653b23427c97da7b Mon Sep 17 00:00:00 2001 From: Justin Farrell Date: Wed, 18 Dec 2024 19:47:41 -0500 Subject: [PATCH 1/3] Add session refresh functionality with tests --- internal/auth/sessions.go | 35 +++++++++++ internal/auth/sessions_test.go | 110 +++++++++++++++++++++++++++++++-- 2 files changed, 141 insertions(+), 4 deletions(-) diff --git a/internal/auth/sessions.go b/internal/auth/sessions.go index 04eafa9..4549609 100644 --- a/internal/auth/sessions.go +++ b/internal/auth/sessions.go @@ -141,3 +141,38 @@ func generateSessionID() (string, error) { func (s *Session) IsExpired() bool { return time.Now().After(s.ExpiresAt) } + +// RefreshSession refreshes the session token for a given account ID +// @Summary Refresh a session +// @Description Refresh the session token for a given account ID. Extends the expiration by 12 hours. +// @Tags sessions +// @Accept json +// @Produce json +// @Param account_id path int true "Account ID" +// @Success 200 {object} Session +// @Failure 404 {object} error +// @Router /sessions/refresh/{account_id} [put] +func (s *SessionStore) RefreshSession(accountID int) (*Session, error) { + var session Session + query := `SELECT * FROM sessions WHERE account_id = $1 ORDER BY created_at DESC LIMIT 1` + err := s.DB.Get(&session, query, accountID) + if err != nil { + if err == sql.ErrNoRows { + log.Printf("No session found for account ID %d", accountID) + return nil, errors.New("No session found") + } + log.Printf("Error retrieving session for account ID %d: %v", accountID, err) + return nil, err + } + + session.ExpiresAt = time.Now().Add(12 * time.Hour) + updateQuery := `UPDATE sessions SET expires_at = :expires_at WHERE id = :id` + _, err = s.DB.NamedExec(updateQuery, session) + if err != nil { + log.Printf("Error refreshing session for account ID %d: %v", accountID, err) + return nil, err + } + + log.Printf("Session refreshed: %v", session) + return &session, nil +} diff --git a/internal/auth/sessions_test.go b/internal/auth/sessions_test.go index 933cdcc..98a4373 100644 --- a/internal/auth/sessions_test.go +++ b/internal/auth/sessions_test.go @@ -49,11 +49,11 @@ func TestCreateSession_Error(t *testing.T) { store := NewSessionStore(sqlxDB) accountID := 1 - mock.ExpectExec("INSERT INTO sessions"). - WithArgs(sqlmock.AnyArg(), accountID, sqlmock.AnyArg(), sqlmock.AnyArg()). - WillReturnError(sql.ErrConnDone) + mock.ExpectQuery("SELECT \\* FROM sessions WHERE account_id = \\$1 ORDER BY created_at DESC LIMIT 1"). + WithArgs(accountID). + WillReturnError(sql.ErrNoRows) - _, err = store.CreateSession(accountID) + _, err = store.RefreshSession(accountID) if err == nil { t.Errorf("expected error, got nil") } @@ -216,3 +216,105 @@ func TestIsExpired(t *testing.T) { }) } } +func TestRefreshSession_Success(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewSessionStore(sqlxDB) + + accountID := 1 + sessionID := "test-session-id" + createdAt := time.Now().Add(-1 * time.Hour) + expiresAt := createdAt.Add(12 * time.Hour) + + rows := sqlmock.NewRows([]string{"id", "account_id", "created_at", "expires_at"}). + AddRow(sessionID, accountID, createdAt, expiresAt) + mock.ExpectQuery("SELECT \\* FROM sessions WHERE account_id = \\$1 ORDER BY created_at DESC LIMIT 1"). + WithArgs(accountID). + WillReturnRows(rows) + + mock.ExpectExec("UPDATE sessions SET expires_at = \\? WHERE id = \\?"). + WithArgs(sqlmock.AnyArg(), sessionID). + WillReturnResult(sqlmock.NewResult(1, 1)) + + session, err := store.RefreshSession(accountID) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if session.AccountID != accountID { + t.Errorf("expected account ID %d, got %d", accountID, session.AccountID) + } + + if session.ID != sessionID { + t.Errorf("expected session ID %s, got %s", sessionID, session.ID) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestRefreshSession_NotFound(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewSessionStore(sqlxDB) + + accountID := 1 + mock.ExpectQuery("SELECT \\* FROM sessions WHERE account_id = \\$1 ORDER BY created_at DESC LIMIT 1"). + WithArgs(accountID). + WillReturnError(sql.ErrNoRows) + + _, err = store.RefreshSession(accountID) + if err == nil { + t.Errorf("expected error, got nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestRefreshSession_Error(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewSessionStore(sqlxDB) + + accountID := 1 + sessionID := "test-session-id" + createdAt := time.Now().Add(-1 * time.Hour) + expiresAt := createdAt.Add(12 * time.Hour) + + rows := sqlmock.NewRows([]string{"id", "account_id", "created_at", "expires_at"}). + AddRow(sessionID, accountID, createdAt, expiresAt) + mock.ExpectQuery("SELECT \\* FROM sessions WHERE account_id = \\$1 ORDER BY created_at DESC LIMIT 1"). + WithArgs(accountID). + WillReturnRows(rows) + + mock.ExpectExec("UPDATE sessions SET expires_at = \\? WHERE id = \\?"). + WithArgs(sqlmock.AnyArg(), sessionID). + WillReturnError(sql.ErrConnDone) + + _, err = store.RefreshSession(accountID) + if err == nil { + t.Errorf("expected error, got nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} From 9161d0bfdc19bb789fcc08c1f1b924eaa6e65cd8 Mon Sep 17 00:00:00 2001 From: Justin Farrell Date: Wed, 18 Dec 2024 19:47:52 -0500 Subject: [PATCH 2/3] Add endpoint to refresh session token for a given account ID to docs --- docs/docs.go | 36 ++++++++++++++++++++++++++++++++++++ docs/swagger.json | 36 ++++++++++++++++++++++++++++++++++++ docs/swagger.yaml | 25 +++++++++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/docs/docs.go b/docs/docs.go index 7698600..ac5c302 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -498,6 +498,42 @@ const docTemplate = `{ } } }, + "/sessions/refresh/{account_id}": { + "put": { + "description": "Refresh the session token for a given account ID. Extends the expiration by 12 hours.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "sessions" + ], + "summary": "Refresh a session", + "parameters": [ + { + "type": "integer", + "description": "Account ID", + "name": "account_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/auth.Session" + } + }, + "404": { + "description": "Not Found", + "schema": {} + } + } + } + }, "/sessions/{id}": { "get": { "description": "Get a session by its ID", diff --git a/docs/swagger.json b/docs/swagger.json index bd51f9e..b2c3cbf 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -489,6 +489,42 @@ } } }, + "/sessions/refresh/{account_id}": { + "put": { + "description": "Refresh the session token for a given account ID. Extends the expiration by 12 hours.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "sessions" + ], + "summary": "Refresh a session", + "parameters": [ + { + "type": "integer", + "description": "Account ID", + "name": "account_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/auth.Session" + } + }, + "404": { + "description": "Not Found", + "schema": {} + } + } + } + }, "/sessions/{id}": { "get": { "description": "Get a session by its ID", diff --git a/docs/swagger.yaml b/docs/swagger.yaml index a597264..2e5dc6a 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -595,4 +595,29 @@ paths: summary: Get a session tags: - sessions + /sessions/refresh/{account_id}: + put: + consumes: + - application/json + description: Refresh the session token for a given account ID. Extends the expiration + by 12 hours. + parameters: + - description: Account ID + in: path + name: account_id + required: true + type: integer + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/auth.Session' + "404": + description: Not Found + schema: {} + summary: Refresh a session + tags: + - sessions swagger: "2.0" From 39be4927bff383770b8e79fd54fd6a383809e492 Mon Sep 17 00:00:00 2001 From: Justin Farrell Date: Wed, 18 Dec 2024 20:10:52 -0500 Subject: [PATCH 3/3] Add session refresh logic in GetAccount function --- internal/account/get_account.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/account/get_account.go b/internal/account/get_account.go index 9cce84b..0345c6f 100644 --- a/internal/account/get_account.go +++ b/internal/account/get_account.go @@ -80,6 +80,13 @@ func GetAccount(w http.ResponseWriter, r *http.Request, db *sqlx.DB, store *auth return errors.New("session has expired") } + // Refresh the session + _, err = store.RefreshSession(session.AccountID) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return errors.New("an error occurred while refreshing the session: " + err.Error()) + } + var ( name string info string