aboutsummaryrefslogtreecommitdiff
path: root/packages/db/src/core/integration/vite-plugin-db.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/db/src/core/integration/vite-plugin-db.ts')
-rw-r--r--packages/db/src/core/integration/vite-plugin-db.ts80
1 files changed, 54 insertions, 26 deletions
diff --git a/packages/db/src/core/integration/vite-plugin-db.ts b/packages/db/src/core/integration/vite-plugin-db.ts
index ec512962d..c7e922e7b 100644
--- a/packages/db/src/core/integration/vite-plugin-db.ts
+++ b/packages/db/src/core/integration/vite-plugin-db.ts
@@ -1,15 +1,24 @@
import { fileURLToPath } from 'node:url';
import { normalizePath } from 'vite';
-import { SEED_DEV_FILE_NAME } from '../../runtime/queries.js';
+import {
+ SEED_DEV_FILE_NAME,
+ getCreateIndexQueries,
+ getCreateTableQuery,
+} from '../../runtime/queries.js';
import { DB_PATH, RUNTIME_CONFIG_IMPORT, RUNTIME_IMPORT, VIRTUAL_MODULE_ID } from '../consts.js';
import type { DBTables } from '../types.js';
import { type VitePlugin, getDbDirectoryUrl, getRemoteDatabaseUrl } from '../utils.js';
+import { createLocalDatabaseClient } from '../../runtime/db-client.js';
+import { type SQL, sql } from 'drizzle-orm';
+import type { SqliteDB } from '../../runtime/index.js';
+import { SQLiteAsyncDialect } from 'drizzle-orm/sqlite-core';
-const LOCAL_DB_VIRTUAL_MODULE_ID = 'astro:local';
+const WITH_SEED_VIRTUAL_MODULE_ID = 'astro:db:seed';
-const resolvedVirtualModuleId = '\0' + VIRTUAL_MODULE_ID;
-const resolvedLocalDbVirtualModuleId = LOCAL_DB_VIRTUAL_MODULE_ID + '/local-db';
-const resolvedSeedVirtualModuleId = '\0' + VIRTUAL_MODULE_ID + '?shouldSeed';
+const resolved = {
+ virtual: '\0' + VIRTUAL_MODULE_ID,
+ seedVirtual: '\0' + WITH_SEED_VIRTUAL_MODULE_ID,
+};
export type LateTables = {
get: () => DBTables;
@@ -32,34 +41,36 @@ type VitePluginDBParams =
export function vitePluginDb(params: VitePluginDBParams): VitePlugin {
const srcDirPath = normalizePath(fileURLToPath(params.srcDir));
+ const seedFilePaths = SEED_DEV_FILE_NAME.map((name) =>
+ normalizePath(fileURLToPath(new URL(name, getDbDirectoryUrl(params.root))))
+ );
return {
name: 'astro:db',
enforce: 'pre',
async resolveId(id, rawImporter) {
- if (id === LOCAL_DB_VIRTUAL_MODULE_ID) return resolvedLocalDbVirtualModuleId;
if (id !== VIRTUAL_MODULE_ID) return;
- if (params.connectToStudio) return resolvedVirtualModuleId;
+ if (params.connectToStudio) return resolved.virtual;
const importer = rawImporter ? await this.resolve(rawImporter) : null;
- if (!importer) return resolvedVirtualModuleId;
+ if (!importer) return resolved.virtual;
if (importer.id.startsWith(srcDirPath)) {
// Seed only if the importer is in the src directory.
// Otherwise, we may get recursive seed calls (ex. import from db/seed.ts).
- return resolvedSeedVirtualModuleId;
+ return resolved.seedVirtual;
}
- return resolvedVirtualModuleId;
+ return resolved.virtual;
},
- load(id) {
- if (id === resolvedLocalDbVirtualModuleId) {
- const dbUrl = new URL(DB_PATH, params.root);
- return `import { createLocalDatabaseClient } from ${RUNTIME_IMPORT};
- const dbUrl = ${JSON.stringify(dbUrl)};
-
- export const db = createLocalDatabaseClient({ dbUrl });`;
+ async load(id) {
+ // Recreate tables whenever a seed file is loaded.
+ if (seedFilePaths.some((f) => id === f)) {
+ await recreateTables({
+ db: createLocalDatabaseClient({ dbUrl: new URL(DB_PATH, params.root).href }),
+ tables: params.tables.get(),
+ });
}
- if (id !== resolvedVirtualModuleId && id !== resolvedSeedVirtualModuleId) return;
+ if (id !== resolved.virtual && id !== resolved.seedVirtual) return;
if (params.connectToStudio) {
return getStudioVirtualModContents({
@@ -70,7 +81,7 @@ export function vitePluginDb(params: VitePluginDBParams): VitePlugin {
return getLocalVirtualModContents({
root: params.root,
tables: params.tables.get(),
- shouldSeed: id === resolvedSeedVirtualModuleId,
+ shouldSeed: id === resolved.seedVirtual,
});
},
};
@@ -82,6 +93,7 @@ export function getConfigVirtualModContents() {
export function getLocalVirtualModContents({
tables,
+ root,
shouldSeed,
}: {
tables: DBTables;
@@ -94,19 +106,19 @@ export function getLocalVirtualModContents({
(name) => new URL(name, getDbDirectoryUrl('file:///')).pathname
);
+ const dbUrl = new URL(DB_PATH, root);
return `
-import { asDrizzleTable, seedLocal } from ${RUNTIME_IMPORT};
-import { db as _db } from ${JSON.stringify(LOCAL_DB_VIRTUAL_MODULE_ID)};
+import { asDrizzleTable, createLocalDatabaseClient } from ${RUNTIME_IMPORT};
+${shouldSeed ? `import { seedLocal } from ${RUNTIME_IMPORT};` : ''}
-export const db = _db;
+const dbUrl = ${JSON.stringify(dbUrl)};
+export const db = createLocalDatabaseClient({ dbUrl });
${
shouldSeed
? `await seedLocal({
- db: _db,
- tables: ${JSON.stringify(tables)},
- fileGlob: import.meta.glob(${JSON.stringify(seedFilePaths)}),
-})`
+ fileGlob: import.meta.glob(${JSON.stringify(seedFilePaths)}, { eager: true }),
+});`
: ''
}
@@ -146,3 +158,19 @@ function getStringifiedCollectionExports(tables: DBTables) {
)
.join('\n');
}
+
+const sqlite = new SQLiteAsyncDialect();
+
+async function recreateTables({ db, tables }: { db: SqliteDB; tables: DBTables }) {
+ const setupQueries: SQL[] = [];
+ for (const [name, table] of Object.entries(tables)) {
+ const dropQuery = sql.raw(`DROP TABLE IF EXISTS ${sqlite.escapeName(name)}`);
+ const createQuery = sql.raw(getCreateTableQuery(name, table));
+ const indexQueries = getCreateIndexQueries(name, table);
+ setupQueries.push(dropQuery, createQuery, ...indexQueries.map((s) => sql.raw(s)));
+ }
+ await db.batch([
+ db.run(sql`pragma defer_foreign_keys=true;`),
+ ...setupQueries.map((q) => db.run(q)),
+ ]);
+}