package database

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"io"
	"sync"

	"github.com/coreos/go-oidc/v3/oidc"
	"golang.org/x/oauth2"
)

// SessionStore is the persistence interface for login state tokens and
// user sessions.
type SessionStore interface {
	// CreateState generates and stores a new random state token.
	CreateState(ctx context.Context) (string, error)
	// CheckState reports the boolean produced by the state lookup query;
	// a lookup that yields no row is reported as false without an error.
	CheckState(ctx context.Context, state string) (bool, error)
	// CreateSession stores a new session for the given OAuth2 token and
	// ID token and returns the generated session token.
	CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error)
	// GetSession looks up a session by its token. It returns (nil, nil)
	// when no matching row exists.
	GetSession(ctx context.Context, sessionToken string) (*Session, error)
}

// newRandomToken reads 32 bytes from crypto/rand and returns them encoded
// with URL-safe base64. It is shared by CreateState and CreateSession so
// both kinds of token are produced the same way.
func newRandomToken() (string, error) {
	buf := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(buf), nil
}

// CreateState generates a random CSRF state token, stores it via the
// "sessions/create_state" statement, and returns it.
//
// An empty string is returned together with any error from the random
// source or the database.
func (d *database) CreateState(ctx context.Context) (string, error) {
	token, err := newRandomToken()
	if err != nil {
		return "", err
	}

	// Insert the state into the database.
	if _, err := d.exec(ctx, d.db, "sessions/create_state", token); err != nil {
		return "", err
	}
	return token, nil
}

// CheckState runs the "sessions/check_state" query for state and returns
// the single boolean column it yields.
//
// A result with no rows is treated as "state not present" and reported as
// (false, nil); any other query or scan failure is returned as an error.
// NOTE(review): whether the query also consumes/deletes the state is not
// visible from this file — confirm against the SQL.
func (d *database) CheckState(ctx context.Context, state string) (bool, error) {
	row, err := d.queryRow(ctx, d.db, "sessions/check_state", state)
	if err != nil {
		return false, err
	}

	var exists bool
	if err := row.Scan(&exists); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return false, nil
		}
		return false, err
	}
	return exists, nil
}

// CreateSession generates a random session token and stores it via the
// "sessions/create_session" statement together with the ID token's subject
// and the OAuth2 access token and expiry.
//
// token and idToken must be non-nil; a nil value yields an error instead of
// a nil-pointer panic. On any error the returned session token is empty, so
// callers never receive a token that was not persisted.
func (d *database) CreateSession(
	ctx context.Context,
	token *oauth2.Token,
	idToken *oidc.IDToken,
) (sessionToken string, err error) {
	if token == nil {
		return "", errors.New("oauth2 token is nil")
	}
	if idToken == nil {
		return "", errors.New("id token is nil")
	}

	sessionToken, err = newRandomToken()
	if err != nil {
		return "", err
	}

	// Insert the session into the database.
	_, err = d.exec(
		ctx,
		d.db,
		"sessions/create_session",
		sessionToken,
		idToken.Subject,
		token.AccessToken,
		token.Expiry,
	)
	if err != nil {
		return "", err
	}
	return sessionToken, nil
}

// GetSession loads the session identified by sessionToken using the
// "sessions/get_session" query.
//
// It returns (nil, nil) when the query yields no row. Query and scan
// failures are logged and returned. Only the AccessToken and Expiry fields
// of Session.OAuthToken are populated from the row.
func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) {
	row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken)
	if err != nil {
		d.logger.ErrorContext(ctx, "Failed to get session", "error", err)
		return nil, err
	}

	var session Session
	err = row.Scan(
		&session.Token,
		&session.Subject,
		&session.OAuthToken.AccessToken,
		&session.OAuthToken.Expiry,
	)
	if err != nil {
		// A missing session is an expected outcome, not a failure.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		d.logger.ErrorContext(ctx, "Failed to scan session", "error", err)
		return nil, err
	}
	return &session, nil
}

// cleanupSessions runs the "sessions/cleanup_sessions" statement and logs
// the number of affected rows at debug level. Errors are logged, not
// returned. wg.Done is always called on return so the caller can wait for
// the cleanup to finish.
// NOTE(review): presumably the statement removes expired sessions — the SQL
// is not visible here; confirm.
func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions")
	if err != nil {
		// Use the context-aware variant like every other log call here.
		d.logger.ErrorContext(ctx, "Failed to clean up sessions", "error", err)
		return
	}

	rows, err := result.RowsAffected()
	if err != nil {
		d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err)
		return
	}
	d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows)
}

// Session is a stored user session as returned by GetSession.
type Session struct {
	// Token is the session token that identifies the session.
	Token string
	// Subject is the subject taken from the ID token at creation time.
	Subject string
	// OAuthToken holds the OAuth2 token data; GetSession fills in only
	// AccessToken and Expiry.
	OAuthToken oauth2.Token
}