aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/session.go
blob: 36867b305c02ce19646192323c34e1493f065204 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package database

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"io"
	"sync"

	"github.com/coreos/go-oidc/v3/oidc"
	"golang.org/x/oauth2"
)

// SessionStore persists the CSRF state values used during the OAuth2/OIDC
// login redirect and the login sessions created from the resulting tokens.
type SessionStore interface {
	// CreateState generates a new random CSRF state value, stores it, and
	// returns it.
	CreateState(ctx context.Context) (string, error)
	// CheckState reports whether state exists in the store.
	// NOTE(review): whether a successful check also consumes the state
	// (making it single-use) depends on the backing query — confirm.
	CheckState(ctx context.Context, state string) (bool, error)
	// CreateSession stores a new session for the subject of idToken together
	// with the OAuth2 access token and its expiry, and returns the new
	// session token.
	CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error)
	// GetSession returns the session identified by sessionToken. The
	// *database implementation returns (nil, nil) when no session is found.
	GetSession(ctx context.Context, sessionToken string) (*Session, error)
}

// CreateState mints a fresh CSRF state value, persists it with the
// "sessions/create_state" statement, and returns it.
//
// The value is 32 bytes read from crypto/rand, encoded as padded URL-safe
// base64. Nothing is returned (other than the error) if either the random
// read or the insert fails.
func (d *database) CreateState(ctx context.Context) (string, error) {
	var raw [32]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		return "", err
	}
	state := base64.URLEncoding.EncodeToString(raw[:])

	if _, err := d.exec(ctx, d.db, "sessions/create_state", state); err != nil {
		return "", err
	}
	return state, nil
}

// CheckState runs the "sessions/check_state" query for state and returns the
// boolean it yields.
//
// A query that produces no row is reported as (false, nil) rather than as an
// error. Any other query or scan failure is returned unchanged.
// NOTE(review): whether the query also deletes/consumes the state is not
// visible here — confirm against the SQL.
func (d *database) CheckState(ctx context.Context, state string) (bool, error) {
	row, err := d.queryRow(ctx, d.db, "sessions/check_state", state)
	if err != nil {
		return false, err
	}

	var found bool
	switch scanErr := row.Scan(&found); {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return false, nil
	default:
		return false, scanErr
	}
}

// CreateSession generates a new session token and persists a session row via
// the "sessions/create_session" statement. The row holds the session token,
// idToken.Subject, token.AccessToken and token.Expiry.
//
// The session token is 32 bytes from crypto/rand, encoded as padded URL-safe
// base64. On any error the returned token is "". This way a caller can never
// hold a token that was not actually stored. A nil token or idToken is
// reported as an error rather than causing a nil-pointer panic.
func (d *database) CreateSession(
	ctx context.Context,
	token *oauth2.Token,
	idToken *oidc.IDToken,
) (string, error) {
	if token == nil {
		return "", errors.New("create session: nil oauth2 token")
	}
	if idToken == nil {
		return "", errors.New("create session: nil ID token")
	}

	tokenBytes := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil {
		return "", err
	}
	sessionToken := base64.URLEncoding.EncodeToString(tokenBytes)

	_, err := d.exec(
		ctx,
		d.db,
		"sessions/create_session",
		sessionToken,
		idToken.Subject,
		token.AccessToken,
		token.Expiry,
	)
	if err != nil {
		// The insert failed, so the token is unusable; don't hand it back.
		return "", err
	}
	return sessionToken, nil
}

// GetSession loads the session identified by sessionToken using the
// "sessions/get_session" query.
//
// It returns (nil, nil) when the query yields no row, so callers must check
// the returned pointer as well as the error. Other query or scan failures are
// logged and then returned. Of Session.OAuthToken, only AccessToken and
// Expiry are populated.
func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) {
	row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken)
	if err != nil {
		d.logger.ErrorContext(ctx, "Failed to get session", "error", err)
		return nil, err
	}

	s := new(Session)
	switch scanErr := row.Scan(
		&s.Token,
		&s.Subject,
		&s.OAuthToken.AccessToken,
		&s.OAuthToken.Expiry,
	); {
	case scanErr == nil:
		return s, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, nil
	default:
		d.logger.ErrorContext(ctx, "Failed to scan session", "error", scanErr)
		return nil, scanErr
	}
}

// cleanupSessions executes the "sessions/cleanup_sessions" statement once and
// logs how many rows it affected.
// NOTE(review): presumably the statement deletes expired sessions/states —
// confirm against the SQL.
//
// wg.Done is always called on return. Failures are logged, not returned,
// because the function has no error result.
func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions")
	if err != nil {
		// Use the context-aware variant like every other log call in this
		// file, so a ctx-aware log handler receives ctx here too.
		d.logger.ErrorContext(ctx, "Failed to clean up sessions", "error", err)
		return
	}

	rows, err := result.RowsAffected()
	if err != nil {
		d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err)
		return
	}
	d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows)
}

// Session is a stored login session, as written by CreateSession and read
// back by GetSession.
type Session struct {
	// Token is the opaque session token generated by CreateSession.
	Token string

	// Subject is the Subject of the OIDC ID token the session was created from.
	Subject string

	// OAuthToken holds the provider's OAuth2 token. Only AccessToken and
	// Expiry are stored and populated by this package.
	OAuthToken oauth2.Token
}