diff options
Diffstat (limited to 'packages/db/src/core/integration/vite-plugin-db.ts')
-rw-r--r-- | packages/db/src/core/integration/vite-plugin-db.ts | 80 |
1 files changed, 54 insertions, 26 deletions
diff --git a/packages/db/src/core/integration/vite-plugin-db.ts b/packages/db/src/core/integration/vite-plugin-db.ts index ec512962d..c7e922e7b 100644 --- a/packages/db/src/core/integration/vite-plugin-db.ts +++ b/packages/db/src/core/integration/vite-plugin-db.ts @@ -1,15 +1,24 @@ import { fileURLToPath } from 'node:url'; import { normalizePath } from 'vite'; -import { SEED_DEV_FILE_NAME } from '../../runtime/queries.js'; +import { + SEED_DEV_FILE_NAME, + getCreateIndexQueries, + getCreateTableQuery, +} from '../../runtime/queries.js'; import { DB_PATH, RUNTIME_CONFIG_IMPORT, RUNTIME_IMPORT, VIRTUAL_MODULE_ID } from '../consts.js'; import type { DBTables } from '../types.js'; import { type VitePlugin, getDbDirectoryUrl, getRemoteDatabaseUrl } from '../utils.js'; +import { createLocalDatabaseClient } from '../../runtime/db-client.js'; +import { type SQL, sql } from 'drizzle-orm'; +import type { SqliteDB } from '../../runtime/index.js'; +import { SQLiteAsyncDialect } from 'drizzle-orm/sqlite-core'; -const LOCAL_DB_VIRTUAL_MODULE_ID = 'astro:local'; +const WITH_SEED_VIRTUAL_MODULE_ID = 'astro:db:seed'; -const resolvedVirtualModuleId = '\0' + VIRTUAL_MODULE_ID; -const resolvedLocalDbVirtualModuleId = LOCAL_DB_VIRTUAL_MODULE_ID + '/local-db'; -const resolvedSeedVirtualModuleId = '\0' + VIRTUAL_MODULE_ID + '?shouldSeed'; +const resolved = { + virtual: '\0' + VIRTUAL_MODULE_ID, + seedVirtual: '\0' + WITH_SEED_VIRTUAL_MODULE_ID, +}; export type LateTables = { get: () => DBTables; @@ -32,34 +41,36 @@ type VitePluginDBParams = export function vitePluginDb(params: VitePluginDBParams): VitePlugin { const srcDirPath = normalizePath(fileURLToPath(params.srcDir)); + const seedFilePaths = SEED_DEV_FILE_NAME.map((name) => + normalizePath(fileURLToPath(new URL(name, getDbDirectoryUrl(params.root)))) + ); return { name: 'astro:db', enforce: 'pre', async resolveId(id, rawImporter) { - if (id === LOCAL_DB_VIRTUAL_MODULE_ID) return resolvedLocalDbVirtualModuleId; if (id !== VIRTUAL_MODULE_ID) return; - if (params.connectToStudio) return resolvedVirtualModuleId; + if (params.connectToStudio) return resolved.virtual; const importer = rawImporter ? await this.resolve(rawImporter) : null; - if (!importer) return resolvedVirtualModuleId; + if (!importer) return resolved.virtual; if (importer.id.startsWith(srcDirPath)) { // Seed only if the importer is in the src directory. // Otherwise, we may get recursive seed calls (ex. import from db/seed.ts). - return resolvedSeedVirtualModuleId; + return resolved.seedVirtual; } - return resolvedVirtualModuleId; + return resolved.virtual; }, - load(id) { - if (id === resolvedLocalDbVirtualModuleId) { - const dbUrl = new URL(DB_PATH, params.root); - return `import { createLocalDatabaseClient } from ${RUNTIME_IMPORT}; - const dbUrl = ${JSON.stringify(dbUrl)}; - - export const db = createLocalDatabaseClient({ dbUrl });`; + async load(id) { + // Recreate tables whenever a seed file is loaded. + if (seedFilePaths.some((f) => id === f)) { + await recreateTables({ + db: createLocalDatabaseClient({ dbUrl: new URL(DB_PATH, params.root).href }), + tables: params.tables.get(), + }); } - if (id !== resolvedVirtualModuleId && id !== resolvedSeedVirtualModuleId) return; + if (id !== resolved.virtual && id !== resolved.seedVirtual) return; if (params.connectToStudio) { return getStudioVirtualModContents({ @@ -70,7 +81,7 @@ export function vitePluginDb(params: VitePluginDBParams): VitePlugin { return getLocalVirtualModContents({ root: params.root, tables: params.tables.get(), - shouldSeed: id === resolvedSeedVirtualModuleId, + shouldSeed: id === resolved.seedVirtual, }); }, }; @@ -82,6 +93,7 @@ export function getConfigVirtualModContents() { export function getLocalVirtualModContents({ tables, + root, shouldSeed, }: { tables: DBTables; @@ -94,19 +106,19 @@ export function getLocalVirtualModContents({ (name) => new URL(name, getDbDirectoryUrl('file:///')).pathname ); + const dbUrl = new URL(DB_PATH, root); return ` -import { asDrizzleTable, seedLocal } from ${RUNTIME_IMPORT}; -import { db as _db } from ${JSON.stringify(LOCAL_DB_VIRTUAL_MODULE_ID)}; +import { asDrizzleTable, createLocalDatabaseClient } from ${RUNTIME_IMPORT}; +${shouldSeed ? `import { seedLocal } from ${RUNTIME_IMPORT};` : ''} -export const db = _db; +const dbUrl = ${JSON.stringify(dbUrl)}; +export const db = createLocalDatabaseClient({ dbUrl }); ${ shouldSeed ? `await seedLocal({ - db: _db, - tables: ${JSON.stringify(tables)}, - fileGlob: import.meta.glob(${JSON.stringify(seedFilePaths)}), -})` + fileGlob: import.meta.glob(${JSON.stringify(seedFilePaths)}, { eager: true }), +});` : '' } @@ -146,3 +158,19 @@ function getStringifiedCollectionExports(tables: DBTables) { ) .join('\n'); } + +const sqlite = new SQLiteAsyncDialect(); + +async function recreateTables({ db, tables }: { db: SqliteDB; tables: DBTables }) { + const setupQueries: SQL[] = []; + for (const [name, table] of Object.entries(tables)) { + const dropQuery = sql.raw(`DROP TABLE IF EXISTS ${sqlite.escapeName(name)}`); + const createQuery = sql.raw(getCreateTableQuery(name, table)); + const indexQueries = getCreateIndexQueries(name, table); + setupQueries.push(dropQuery, createQuery, ...indexQueries.map((s) => sql.raw(s))); + } + await db.batch([ + db.run(sql`pragma defer_foreign_keys=true;`), + ...setupQueries.map((q) => db.run(q)), + ]); +} |