diff options
Diffstat (limited to 'backend/internal/database/session.go')
-rw-r--r-- | backend/internal/database/session.go | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go new file mode 100644 index 0000000..36867b3 --- /dev/null +++ b/backend/internal/database/session.go @@ -0,0 +1,122 @@ +package database + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "errors" + "io" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type SessionStore interface { + CreateState(ctx context.Context) (string, error) + CheckState(ctx context.Context, state string) (bool, error) + CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error) + GetSession(ctx context.Context, sessionToken string) (*Session, error) +} + +func (d *database) CreateState(ctx context.Context) (string, error) { + // Generate a random CSRF state token + tokenBytes := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil { + return "", err + } + token := base64.URLEncoding.EncodeToString(tokenBytes) + + // Insert the state into the database + _, err := d.exec(ctx, d.db, "sessions/create_state", token) + if err != nil { + return "", err + } + + return token, nil +} + +func (d *database) CheckState(ctx context.Context, state string) (bool, error) { + var exists bool + row, err := d.queryRow(ctx, d.db, "sessions/check_state", state) + if err != nil { + return false, err + } + err = row.Scan(&exists) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return false, err + } + return exists, nil +} + +func (d *database) CreateSession( + ctx context.Context, + token *oauth2.Token, + idToken *oidc.IDToken, +) (sessionToken string, err error) { + // Generate a random session token + tokenBytes := make([]byte, 32) + if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil { + return + } + sessionToken = base64.URLEncoding.EncodeToString(tokenBytes) + + // Insert the session into the database + _, err = d.exec( + ctx, + d.db, + "sessions/create_session", + sessionToken, + idToken.Subject, + token.AccessToken, + token.Expiry, + ) + return +} + +func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) { + row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken) + if err != nil { + d.logger.ErrorContext(ctx, "Failed to get session", "error", err) + return nil, err + } + + var session Session + err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + d.logger.ErrorContext(ctx, "Failed to scan session", "error", err) + return nil, err + } + + return &session, nil +} + +func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions") + if err != nil { + d.logger.Error("Failed to clean up sessions", "error", err) + return + } + + rows, err := result.RowsAffected() + if err != nil { + d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err) + return + } + d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows) +} + +type Session struct { + Token string + Subject string + OAuthToken oauth2.Token +} |