aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/database.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database/database.go')
-rw-r--r--backend/internal/database/database.go178
1 files changed, 178 insertions, 0 deletions
diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go
new file mode 100644
index 0000000..4022dde
--- /dev/null
+++ b/backend/internal/database/database.go
@@ -0,0 +1,178 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "sync"
+ "time"
+
+ "ibd-trader/db"
+ "ibd-trader/internal/keys"
+
+ "github.com/golang-migrate/migrate/v4"
+ _ "github.com/golang-migrate/migrate/v4/database/postgres"
+ "github.com/golang-migrate/migrate/v4/source/iofs"
+ _ "github.com/lib/pq"
+)
+
+type Database interface {
+ io.Closer
+ UserStore
+ CookieStore
+ KeyStore
+ SessionStore
+ StockStore
+
+ driver.Pinger
+
+ Migrate(ctx context.Context) error
+ Maintenance(ctx context.Context)
+}
+
+type database struct {
+ logger *slog.Logger
+
+ db *sql.DB
+ url string
+
+ kms keys.KeyManagementService
+ keyName string
+}
+
+func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) {
+ sqlDB, err := sql.Open("postgres", url)
+ if err != nil {
+ return nil, err
+ }
+
+ err = sqlDB.PingContext(ctx)
+ if err != nil {
+ // Ping failed. Don't error, but give a warning.
+ logger.WarnContext(ctx, "Unable to ping database", "error", err)
+ }
+
+ return &database{
+ logger: logger,
+ db: sqlDB,
+ url: url,
+ kms: kms,
+ keyName: keyName,
+ }, nil
+}
+
+func (d *database) Close() error {
+ return d.db.Close()
+}
+
+func (d *database) Migrate(ctx context.Context) error {
+ fs, err := iofs.New(db.Migrations, "migrations")
+ if err != nil {
+ return err
+ }
+
+ m, err := migrate.NewWithSourceInstance("iofs", fs, d.url)
+ if err != nil {
+ return err
+ }
+
+ d.logger.InfoContext(ctx, "Running DB migration")
+ err = m.Up()
+ if err != nil && !errors.Is(err, migrate.ErrNoChange) {
+ d.logger.ErrorContext(ctx, "DB migration failed", "error", err)
+ return err
+ }
+
+ return nil
+}
+
+func (d *database) Maintenance(ctx context.Context) {
+ ticker := time.NewTicker(15 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ func() {
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ go d.cleanupSessions(ctx, &wg)
+
+ wg.Wait()
+ }()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (d *database) Ping(ctx context.Context) error {
+ return d.db.PingContext(ctx)
+}
+
+func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) {
+ query, err := db.GetQuery(queryName)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get query: %w", err)
+ }
+ d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query)
+
+ now := time.Now()
+
+ // Execute the query
+ result, err := fn(query)
+ if err != nil {
+ return nil, fmt.Errorf("unable to execute query: %w", err)
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now))
+
+ return result, nil
+}
+
+func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.ExecContext(ctx, query, args...)
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(sql.Result), nil
+ }
+}
+
+func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.QueryContext(ctx, query, args...)
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(*sql.Rows), nil
+ }
+}
+
+func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.QueryRowContext(ctx, query, args...), nil
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(*sql.Row), nil
+ }
+}
+
+type executor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+}