package database import ( "context" "database/sql" "database/sql/driver" "errors" "io" "log/slog" "sync" "time" "github.com/ansg191/ibd-trader-backend/db" "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/source/iofs" _ "github.com/lib/pq" ) type Database interface { io.Closer TransactionExecutor driver.Pinger Migrate(ctx context.Context) error Maintenance(ctx context.Context) } type database struct { logger *slog.Logger db *sql.DB url string kms keys.KeyManagementService keyName string } func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) { sqlDB, err := sql.Open("postgres", url) if err != nil { return nil, err } err = sqlDB.PingContext(ctx) if err != nil { // Ping failed. Don't error, but give a warning. logger.WarnContext(ctx, "Unable to ping database", "error", err) } return &database{ logger: logger, db: sqlDB, url: url, kms: kms, keyName: keyName, }, nil } func (d *database) Close() error { return d.db.Close() } func (d *database) Migrate(ctx context.Context) error { return Migrate(ctx, d.url) } func (d *database) Maintenance(ctx context.Context) { ticker := time.NewTicker(15 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: func() { var wg sync.WaitGroup wg.Add(1) _, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() wg.Wait() }() case <-ctx.Done(): return } } } func Migrate(ctx context.Context, url string) error { fs, err := iofs.New(db.Migrations, "migrations") if err != nil { return err } m, err := migrate.NewWithSourceInstance("iofs", fs, url) if err != nil { return err } slog.InfoContext(ctx, "Running DB migration") err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { slog.ErrorContext(ctx, "DB migration failed", "error", err) return err } return nil } func (d *database) Ping(ctx context.Context) error { return d.db.PingContext(ctx) } type Executor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } type TransactionExecutor interface { Executor BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { d.logger.DebugContext(ctx, "Executing query", "query", query) now := time.Now() ret, err := d.db.ExecContext(ctx, query, args...) if err != nil { return nil, err } d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) return ret, nil } func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { d.logger.DebugContext(ctx, "Executing query", "query", query) now := time.Now() ret, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) return ret, nil } func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { d.logger.DebugContext(ctx, "Executing query", "query", query) now := time.Now() ret := d.db.QueryRowContext(ctx, query, args...) d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) return ret } func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return d.db.BeginTx(ctx, opts) }