aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/session.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database/session.go')
-rw-r--r--backend/internal/database/session.go122
1 files changed, 122 insertions, 0 deletions
diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go
new file mode 100644
index 0000000..36867b3
--- /dev/null
+++ b/backend/internal/database/session.go
@@ -0,0 +1,122 @@
+package database
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "encoding/base64"
+ "errors"
+ "io"
+ "sync"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "golang.org/x/oauth2"
+)
+
+type SessionStore interface {
+ CreateState(ctx context.Context) (string, error)
+ CheckState(ctx context.Context, state string) (bool, error)
+ CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error)
+ GetSession(ctx context.Context, sessionToken string) (*Session, error)
+}
+
+func (d *database) CreateState(ctx context.Context) (string, error) {
+ // Generate a random CSRF state token
+ tokenBytes := make([]byte, 32)
+ if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil {
+ return "", err
+ }
+ token := base64.URLEncoding.EncodeToString(tokenBytes)
+
+ // Insert the state into the database
+ _, err := d.exec(ctx, d.db, "sessions/create_state", token)
+ if err != nil {
+ return "", err
+ }
+
+ return token, nil
+}
+
+func (d *database) CheckState(ctx context.Context, state string) (bool, error) {
+ var exists bool
+ row, err := d.queryRow(ctx, d.db, "sessions/check_state", state)
+ if err != nil {
+ return false, err
+ }
+ err = row.Scan(&exists)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+ return false, err
+ }
+ return exists, nil
+}
+
+func (d *database) CreateSession(
+ ctx context.Context,
+ token *oauth2.Token,
+ idToken *oidc.IDToken,
+) (sessionToken string, err error) {
+ // Generate a random session token
+ tokenBytes := make([]byte, 32)
+ if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil {
+ return
+ }
+ sessionToken = base64.URLEncoding.EncodeToString(tokenBytes)
+
+ // Insert the session into the database
+ _, err = d.exec(
+ ctx,
+ d.db,
+ "sessions/create_session",
+ sessionToken,
+ idToken.Subject,
+ token.AccessToken,
+ token.Expiry,
+ )
+ return
+}
+
+func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) {
+ row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken)
+ if err != nil {
+ d.logger.ErrorContext(ctx, "Failed to get session", "error", err)
+ return nil, err
+ }
+
+ var session Session
+ err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ d.logger.ErrorContext(ctx, "Failed to scan session", "error", err)
+ return nil, err
+ }
+
+ return &session, nil
+}
+
+func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+
+ result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions")
+ if err != nil {
+ d.logger.Error("Failed to clean up sessions", "error", err)
+ return
+ }
+
+ rows, err := result.RowsAffected()
+ if err != nil {
+ d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err)
+ return
+ }
+ d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows)
+}
+
+type Session struct {
+ Token string
+ Subject string
+ OAuthToken oauth2.Token
+}