package database
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"log/slog"
"sync"
"time"
"ibd-trader/db"
"ibd-trader/internal/keys"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/lib/pq"
)
// Database is the complete persistence API exposed by this package. It
// combines the embedded store interfaces with lifecycle operations on the
// underlying connection.
type Database interface {
	// Connection lifecycle.
	io.Closer
	driver.Pinger

	// Store interfaces.
	UserStore
	CookieStore
	KeyStore
	SessionStore
	StockStore

	// Migrate applies any pending schema migrations.
	Migrate(ctx context.Context) error
	// Maintenance runs periodic housekeeping until ctx is cancelled.
	Maintenance(ctx context.Context)
}
// database is the implementation of Database returned by New. It is backed by
// a "postgres" database/sql connection pool.
type database struct {
	// logger receives diagnostic output (query debug logs, migration status).
	logger *slog.Logger
	// db is the connection pool opened by New via sql.Open.
	db *sql.DB
	// url is the connection string given to New. Migrate also passes it to
	// golang-migrate.
	url string
	// kms is the key management service supplied to New.
	kms keys.KeyManagementService
	// keyName is the key name supplied to New.
	// NOTE(review): presumably used together with kms by the key store; confirm.
	keyName string
}
// New opens a "postgres" connection pool for url and returns a Database backed
// by it. logger, url, kms and keyName are stored on the returned value.
//
// sql.Open only validates its arguments, so New pings the server once. A failed
// ping is logged as a warning rather than returned, which means the caller
// still receives a Database while the server is unreachable. Only an error from
// sql.Open is returned.
func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) {
	pool, err := sql.Open("postgres", url)
	if err != nil {
		return nil, err
	}

	// Ping failed. Don't error, but give a warning.
	if pingErr := pool.PingContext(ctx); pingErr != nil {
		logger.WarnContext(ctx, "Unable to ping database", "error", pingErr)
	}

	d := &database{
		db:      pool,
		logger:  logger,
		kms:     kms,
		keyName: keyName,
		url:     url,
	}
	return d, nil
}
// Close closes the underlying *sql.DB connection pool and reports any error
// from doing so. It satisfies io.Closer.
func (d *database) Close() error {
	pool := d.db
	return pool.Close()
}
// Migrate applies every pending migration found in db.Migrations (under the
// "migrations" directory) to the database at d.url. Having nothing to apply
// (migrate.ErrNoChange) is treated as success.
//
// golang-migrate opens its own connection from d.url. The migrator is always
// closed before Migrate returns, so repeated calls do not leak connections or
// source handles.
func (d *database) Migrate(ctx context.Context) error {
	src, err := iofs.New(db.Migrations, "migrations")
	if err != nil {
		return fmt.Errorf("loading migrations: %w", err)
	}

	m, err := migrate.NewWithSourceInstance("iofs", src, d.url)
	if err != nil {
		// No migrator took ownership of src, so release it here.
		return errors.Join(fmt.Errorf("creating migrator: %w", err), src.Close())
	}
	defer func() {
		// Close releases both the source and the database connection held by
		// the migrator. A close failure does not invalidate the migration
		// outcome, so it is logged rather than returned.
		if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
			d.logger.WarnContext(ctx, "Unable to close DB migrator", "error", errors.Join(srcErr, dbErr))
		}
	}()

	d.logger.InfoContext(ctx, "Running DB migration")
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		d.logger.ErrorContext(ctx, "DB migration failed", "error", err)
		return fmt.Errorf("applying migrations: %w", err)
	}
	return nil
}
// Maintenance runs periodic housekeeping until ctx is cancelled. Every
// maintenanceInterval it runs one cleanup pass (cleanupSessions), bounded by
// maintenanceTimeout. The first pass happens one interval after the call,
// not immediately. Passes never overlap, because each pass is awaited before
// the loop continues.
//
// Each pass's context is derived from ctx. Cancelling ctx therefore also
// cancels an in-flight pass, instead of leaving Maintenance blocked for up to
// maintenanceTimeout after shutdown has been requested. This assumes
// cleanupSessions honors its context; confirm in its implementation.
func (d *database) Maintenance(ctx context.Context) {
	const (
		maintenanceInterval = 15 * time.Minute
		maintenanceTimeout  = 10 * time.Minute
	)

	ticker := time.NewTicker(maintenanceInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			func() {
				passCtx, cancel := context.WithTimeout(ctx, maintenanceTimeout)
				defer cancel()

				// cleanupSessions takes a WaitGroup; presumably it calls wg.Done
				// when finished (verify). Waiting here keeps passes sequential.
				var wg sync.WaitGroup
				wg.Add(1)
				go d.cleanupSessions(passCtx, &wg)
				wg.Wait()
			}()
		case <-ctx.Done():
			return
		}
	}
}
// Ping checks that the database can be reached, honoring ctx's deadline and
// cancellation. It satisfies driver.Pinger.
func (d *database) Ping(ctx context.Context) error {
	pool := d.db
	return pool.PingContext(ctx)
}
// execInternal is the shared driver for exec, query and queryRow. It:
//   - looks up the SQL text registered under queryName,
//   - logs the query at debug level,
//   - runs it through fn,
//   - logs the elapsed time on success.
//
// fn receives the query text and returns whatever the underlying database/sql
// call produced. execInternal returns that value unchanged.
//
// Errors from the lookup and from fn are wrapped with %w, so errors.Is and
// errors.As still work. Each wrapped error also carries the query name, so a
// failure can be traced back to the statement that caused it.
func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) {
	query, err := db.GetQuery(queryName)
	if err != nil {
		return nil, fmt.Errorf("unable to get query %q: %w", queryName, err)
	}

	d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query)
	start := time.Now()

	result, err := fn(query)
	if err != nil {
		return nil, fmt.Errorf("unable to execute query %q: %w", queryName, err)
	}

	d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(start))
	return result, nil
}
// exec runs the named statement with args on exec and returns its sql.Result.
// exec may be any executor, for example a *sql.DB or a *sql.Tx.
func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) {
	run := func(stmt string) (any, error) {
		return exec.ExecContext(ctx, stmt, args...)
	}

	out, err := d.execInternal(ctx, queryName, run)
	if err != nil {
		return nil, err
	}
	return out.(sql.Result), nil
}
// query runs the named row-returning statement with args on exec and returns
// the resulting *sql.Rows. As with any *sql.Rows, the caller must Close them.
func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) {
	run := func(stmt string) (any, error) {
		return exec.QueryContext(ctx, stmt, args...)
	}

	out, err := d.execInternal(ctx, queryName, run)
	if err != nil {
		return nil, err
	}
	return out.(*sql.Rows), nil
}
// queryRow runs the named single-row statement with args on exec and returns
// the resulting *sql.Row.
//
// QueryRowContext reports no error of its own: database/sql defers query
// failures to (*sql.Row).Scan. The error returned here therefore only reflects
// a failed query-name lookup.
func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) {
	run := func(stmt string) (any, error) {
		row := exec.QueryRowContext(ctx, stmt, args...)
		return row, nil
	}

	out, err := d.execInternal(ctx, queryName, run)
	if err != nil {
		return nil, err
	}
	return out.(*sql.Row), nil
}
// executor is the set of context-aware query methods shared by *sql.DB,
// *sql.Tx and *sql.Conn. Accepting it lets exec, query and queryRow run either
// inside or outside a transaction.
type executor interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// Compile-time checks that the standard database/sql handles satisfy executor.
var (
	_ executor = (*sql.DB)(nil)
	_ executor = (*sql.Tx)(nil)
	_ executor = (*sql.Conn)(nil)
)
option value='examples/framework-svelte'>examples/framework-svelte
Unnamed repository; edit this file 'description' to name the repository. | |