Skip to content

feat(hooks): add "portable" generation mode #1850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions packages/plugins/swr/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ function generateModelHooks(
const fileName = paramCase(model.name);
const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true });

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(outDir, options);
sf.addImportDeclaration({
namedImports: ['Prisma'],
Expand Down Expand Up @@ -261,6 +259,7 @@ function generateIndex(project: Project, outDir: string, models: DataModel[]) {
const sf = project.createSourceFile(path.join(outDir, 'index.ts'), undefined, { overwrite: true });
sf.addStatements(models.map((d) => `export * from './${paramCase(d.name)}';`));
sf.addStatements(`export { Provider } from '@zenstackhq/swr/runtime';`);
sf.addStatements(`export { default as metadata } from './__model_meta';`);
}

function generateQueryHook(
Expand Down
36 changes: 33 additions & 3 deletions packages/plugins/tanstack-query/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import { DataModel, DataModelFieldType, Model, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { paramCase } from 'change-case';
import fs from 'fs';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import { Project, SourceFile, VariableDeclarationKind } from 'ts-morph';
Expand Down Expand Up @@ -45,6 +46,14 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
outDir = resolvePath(outDir, options);
ensureEmptyDir(outDir);

if (options.portable && typeof options.portable !== 'boolean') {
throw new PluginError(
name,
`Invalid value for "portable" option: ${options.portable}, a boolean value is expected`
);
}
Comment on lines +49 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure 'portable' option validation handles all non-boolean values

The current validation if (options.portable && typeof options.portable !== 'boolean') may not catch falsey non-boolean values like 0 or ''. To accurately validate the portable option, consider checking if it is defined and not a boolean.

Apply this diff to fix the condition:

-if (options.portable && typeof options.portable !== 'boolean') {
+if (typeof options.portable !== 'undefined' && typeof options.portable !== 'boolean') {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (options.portable && typeof options.portable !== 'boolean') {
throw new PluginError(
name,
`Invalid value for "portable" option: ${options.portable}, a boolean value is expected`
);
}
if (typeof options.portable !== 'undefined' && typeof options.portable !== 'boolean') {
throw new PluginError(
name,
`Invalid value for "portable" option: ${options.portable}, a boolean value is expected`
);
}

const portable = options.portable ?? false;

await generateModelMeta(project, models, typeDefs, {
output: path.join(outDir, '__model_meta.ts'),
generateAttributes: false,
Expand All @@ -61,6 +70,10 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
generateModelHooks(target, version, project, outDir, dataModel, mapping, options);
});

if (portable) {
generateBundledTypes(project, outDir, options);
}

await saveProject(project);
return { warnings };
}
Expand Down Expand Up @@ -333,9 +346,7 @@ function generateModelHooks(
const fileName = paramCase(model.name);
const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true });

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(outDir, options);
const prismaImport = options.portable ? './__types' : getPrismaClientImportSpec(outDir, options);
sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
Expand Down Expand Up @@ -584,6 +595,7 @@ function generateIndex(
sf.addStatements(`export { SvelteQueryContextKey, setHooksContext } from '${runtimeImportBase}/svelte';`);
break;
}
sf.addStatements(`export { default as metadata } from './__model_meta';`);
}

function makeGetContext(target: TargetFramework) {
Expand Down Expand Up @@ -724,3 +736,21 @@ function makeMutationOptions(target: string, returnType: string, argsType: strin
function makeRuntimeImportBase(version: TanStackVersion) {
return `@zenstackhq/tanstack-query/runtime${version === 'v5' ? '-v5' : ''}`;
}

function generateBundledTypes(project: Project, outDir: string, options: PluginOptions) {
if (!options.prismaClientDtsPath) {
throw new PluginError(name, `Unable to determine the location of PrismaClient types`);
}

// copy PrismaClient index.d.ts
const content = fs.readFileSync(options.prismaClientDtsPath, 'utf-8');
project.createSourceFile(path.join(outDir, '__types.d.ts'), content, { overwrite: true });

// "runtime/library.d.ts" is referenced by Prisma's DTS, and it's generated into Prisma's output
// folder if a custom output is specified; if not, it's referenced from '@prisma/client'
const libraryDts = path.join(path.dirname(options.prismaClientDtsPath), 'runtime', 'library.d.ts');
if (fs.existsSync(libraryDts)) {
const content = fs.readFileSync(libraryDts, 'utf-8');
project.createSourceFile(path.join(outDir, 'runtime', 'library.d.ts'), content, { overwrite: true });
}
}
153 changes: 153 additions & 0 deletions packages/plugins/tanstack-query/tests/portable.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/// <reference types="@types/jest" />

import { loadSchema, normalizePath } from '@zenstackhq/testtools';
import path from 'path';
import tmp from 'tmp';

describe('Tanstack Query Plugin Portable Tests', () => {
it('supports portable for standard prisma client', async () => {
await loadSchema(
`
plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}

model User {
id Int @id @default(autoincrement())
email String
posts Post[]
}

model Post {
id Int @id @default(autoincrement())
title String
author User @relation(fields: [authorId], references: [id])
authorId Int
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['[email protected]', '@types/[email protected]', '@tanstack/[email protected]'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 }, include: { posts: true } });
console.log(data?.email);
console.log(data?.posts[0].title);
`,
},
],
}
);
});

it('supports portable for custom prisma client output', async () => {
const t = tmp.dirSync({ unsafeCleanup: true });
const projectDir = t.name;

await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}

generator client {
provider = 'prisma-client-js'
output = '$projectRoot/myprisma'
}

plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}

model User {
id Int @id @default(autoincrement())
email String
posts Post[]
}

model Post {
id Int @id @default(autoincrement())
title String
author User @relation(fields: [authorId], references: [id])
authorId Int
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['[email protected]', '@types/[email protected]', '@tanstack/[email protected]'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
addPrelude: false,
projectDir,
prismaLoadPath: `${projectDir}/myprisma`,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 }, include: { posts: true } });
console.log(data?.email);
console.log(data?.posts[0].title);
`,
},
],
}
);
});

it('supports portable for logical client', async () => {
await loadSchema(
`
plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}

model Base {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
type String
@@delegate(type)
}

model User extends Base {
email String
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['[email protected]', '@types/[email protected]', '@tanstack/[email protected]'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 } });
console.log(data?.email);
console.log(data?.createdAt);
`,
},
],
}
);
});
});
2 changes: 0 additions & 2 deletions packages/plugins/trpc/src/client-helper/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ export function generateClientTypingForModel(
}
);

sf.addStatements([`/* eslint-disable */`]);

generateImports(clientType, sf, options, version);

// generate a `ClientType` interface that contains typing for query/mutation operations
Expand Down
5 changes: 0 additions & 5 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ function createAppRouter(
overwrite: true,
});

appRouter.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(path.dirname(indexFile), options);

if (version === 'v10') {
Expand Down Expand Up @@ -274,8 +272,6 @@ function generateModelCreateRouter(
overwrite: true,
});

modelRouter.addStatements('/* eslint-disable */');

if (version === 'v10') {
modelRouter.addImportDeclarations([
{
Expand Down Expand Up @@ -386,7 +382,6 @@ function createHelper(outDir: string) {
overwrite: true,
});

sf.addStatements('/* eslint-disable */');
sf.addStatements(`import { TRPCError } from '@trpc/server';`);
sf.addStatements(`import { isPrismaClientKnownRequestError } from '${RUNTIME_PACKAGE}';`);

Expand Down
7 changes: 5 additions & 2 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ export class PluginRunner {
let dmmf: DMMF.Document | undefined = undefined;
let shortNameMap: Map<string, string> | undefined;
let prismaClientPath = '@prisma/client';
let prismaClientDtsPath: string | undefined = undefined;

const project = createProject();
for (const { name, description, run, options: pluginOptions } of corePlugins) {
const options = { ...pluginOptions, prismaClientPath };
Expand Down Expand Up @@ -165,6 +167,7 @@ export class PluginRunner {
if (r.prismaClientPath) {
// use the prisma client path returned by the plugin
prismaClientPath = r.prismaClientPath;
prismaClientDtsPath = r.prismaClientDtsPath;
}
}

Expand All @@ -173,13 +176,13 @@ export class PluginRunner {

// run user plugins
for (const { name, description, run, options: pluginOptions } of userPlugins) {
const options = { ...pluginOptions, prismaClientPath };
const options = { ...pluginOptions, prismaClientPath, prismaClientDtsPath };
const r = await this.runPlugin(
name,
description,
run,
runnerOptions,
options,
options as PluginOptions,
dmmf,
shortNameMap,
project,
Expand Down
9 changes: 7 additions & 2 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export class EnhancerGenerator {
private readonly outDir: string
) {}

async generate(): Promise<{ dmmf: DMMF.Document | undefined }> {
async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> {
let dmmf: DMMF.Document | undefined;

const prismaImport = getPrismaClientImportSpec(this.outDir, this.options);
Expand Down Expand Up @@ -128,7 +128,12 @@ ${
await this.saveSourceFile(enhanceTs);
}

return { dmmf };
return {
dmmf,
newPrismaClientDtsPath: prismaTypesFixed
? path.resolve(this.outDir, LOGICAL_CLIENT_GENERATION_PATH, 'index-fixed.d.ts')
: undefined,
};
}

private getZodImport() {
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/plugins/enhancer/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {

await generateModelMeta(model, options, project, outDir);
await generatePolicy(model, options, project, outDir);
const { dmmf } = await new EnhancerGenerator(model, options, project, outDir).generate();
const { dmmf, newPrismaClientDtsPath } = await new EnhancerGenerator(model, options, project, outDir).generate();

let prismaClientPath: string | undefined;
if (dmmf) {
Expand All @@ -44,7 +44,7 @@ const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {
}
}

return { dmmf, warnings: [], prismaClientPath };
return { dmmf, warnings: [], prismaClientPath, prismaClientDtsPath: newPrismaClientDtsPath };
};

export default run;
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ export class PolicyGenerator {

async generate(project: Project, model: Model, output: string) {
const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true });
sf.addStatements('/* eslint-disable */');

this.writeImports(model, output, sf);

Expand Down
Loading
Loading