Skip to content

fix(zod): zod create/update schemas should exclude discriminator fields #1609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 20 additions & 34 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
getDataModels,
getLiteral,
isDelegateModel,
isDiscriminatorField,
type PluginOptions,
} from '@zenstackhq/sdk';
import {
Expand Down Expand Up @@ -495,33 +496,34 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return source;
}

private readonly ModelCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;

private removeDiscriminatorFromConcreteInput(
typeAlias: TypeAliasDeclaration,
delegateInfo: DelegateInfo,
_delegateInfo: DelegateInfo,
source: string
) {
// remove discriminator field from the create/update input of concrete models because
// discriminator cannot be set directly
// remove discriminator field from the create/update input because discriminator cannot be set directly
const typeName = typeAlias.getName();
const concreteModelNames = delegateInfo.map(([, concretes]) => concretes.map((c) => c.name)).flatMap((c) => c);
const concreteCreateUpdateInputRegex = new RegExp(
`(${concreteModelNames.join('|')})(Unchecked)?(Create|Update).*Input`
);

const match = typeName.match(concreteCreateUpdateInputRegex);
const match = typeName.match(this.ModelCreateUpdateInputRegex);
if (match) {
const modelName = match[1];
const record = delegateInfo.find(([, concretes]) => concretes.some((c) => c.name === modelName));
if (record) {
// remove all discriminator fields recursively
const delegateOfConcrete = record[0];
const discriminators = this.getDiscriminatorFieldsRecursively(delegateOfConcrete);
discriminators.forEach((discriminatorDecl) => {
const discriminatorNode = this.findNamedProperty(typeAlias, discriminatorDecl.name);
if (discriminatorNode) {
source = source.replace(discriminatorNode.getText(), '');
const dataModel = this.model.declarations.find(
(d): d is DataModel => isDataModel(d) && d.name === modelName
);

if (!dataModel) {
return source;
}

for (const field of dataModel.fields) {
if (isDiscriminatorField(field)) {
const fieldDef = this.findNamedProperty(typeAlias, field.name);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
}
});
}
}
}
return source;
Expand Down Expand Up @@ -618,22 +620,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataModelField[] = []) {
if (isDelegateModel(delegate)) {
const discriminator = this.getDiscriminatorField(delegate);
if (discriminator) {
result.push(discriminator);
}

for (const superType of delegate.superTypes) {
if (superType.ref) {
result.push(...this.getDiscriminatorFieldsRecursively(superType.ref, result));
}
}
}
return result;
}

private async saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
await sf.save();
Expand Down
15 changes: 12 additions & 3 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ensureEmptyDir,
getDataModels,
hasAttribute,
isDiscriminatorField,
isEnumFieldReference,
isForeignKeyField,
isFromStdlib,
Expand Down Expand Up @@ -368,6 +369,13 @@ export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T
);
}

// delegate discriminator fields are to be excluded from mutation schemas
const delegateFields = model.fields.filter((field) => isDiscriminatorField(field));
const omitDiscriminators =
delegateFields.length > 0
? `.omit({ ${delegateFields.map((f) => `${f.name}: true`).join(', ')} })`
: '';

////////////////////////////////////////////////
// 1. Model schema
////////////////////////////////////////////////
Expand Down Expand Up @@ -429,7 +437,7 @@ export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};
////////////////////////////////////////////////

// schema for validating prisma create input (all fields optional)
let prismaCreateSchema = this.makePassthrough(this.makePartial('baseSchema'));
let prismaCreateSchema = this.makePassthrough(this.makePartial(`baseSchema${omitDiscriminators}`));
if (refineFuncName) {
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
}
Expand All @@ -445,6 +453,7 @@ export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSch
// note numeric fields can be simple update or atomic operations
let prismaUpdateSchema = `z.object({
${scalarFields
.filter((f) => !isDiscriminatorField(f))
.map((field) => {
let fieldSchema = makeFieldSchema(field);
if (field.type.type === 'Int' || field.type.type === 'Float') {
Expand Down Expand Up @@ -472,7 +481,7 @@ export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSch
// 3. Create schema
////////////////////////////////////////////////

let createSchema = 'baseSchema';
let createSchema = `baseSchema${omitDiscriminators}`;
const fieldsWithDefault = scalarFields.filter(
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
);
Expand Down Expand Up @@ -524,7 +533,7 @@ export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};
////////////////////////////////////////////////

// for update all fields are optional
let updateSchema = this.makePartial('baseSchema');
let updateSchema = this.makePartial(`baseSchema${omitDiscriminators}`);

// export schema with only scalar fields: `[Model]UpdateScalarSchema`
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
Expand Down
29 changes: 26 additions & 3 deletions packages/schema/src/plugins/zod/transformer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { indentString, type PluginOptions } from '@zenstackhq/sdk';
import type { Model } from '@zenstackhq/sdk/ast';
import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk';
import { DataModel, isDataModel, type Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import path from 'path';
Expand Down Expand Up @@ -90,8 +90,31 @@ export default class Transformer {
return `${this.name}.schema`;
}

private delegateCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;

generateObjectSchemaFields(generateUnchecked: boolean) {
const zodObjectSchemaFields = this.fields
let fields = this.fields;

// exclude discriminator fields from create/update input schemas
const createUpdateMatch = this.delegateCreateUpdateInputRegex.exec(this.name);
if (createUpdateMatch) {
const modelName = createUpdateMatch[1];
const dataModel = this.zmodel.declarations.find(
(d): d is DataModel => isDataModel(d) && d.name === modelName
);
if (dataModel) {
const discriminatorFields = dataModel.fields.filter(isDiscriminatorField);
if (discriminatorFields.length > 0) {
fields = fields.filter((field) => {
return !discriminatorFields.some(
(discriminatorField) => discriminatorField.name === field.name
);
});
}
}
}

const zodObjectSchemaFields = fields
.map((field) => this.generateObjectSchemaField(field, generateUnchecked))
.flatMap((item) => item)
.map((item) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,94 @@ describe('Polymorphic Plugin Interaction Test', () => {
extraDependencies: ['@trpc/client', '@trpc/server', '@trpc/react-query'],
});
});

it('zod', async () => {
const { zodSchemas } = await loadSchema(POLYMORPHIC_SCHEMA, { fullZod: true });

// model schema
expect(
zodSchemas.models.AssetSchema.parse({
id: 1,
assetType: 'video',
createdAt: new Date(),
viewCount: 100,
})
).toBeTruthy();

expect(
zodSchemas.models.AssetSchema.parse({
id: 1,
assetType: 'video',
createdAt: new Date(),
viewCount: 100,
videoType: 'ratedVideo', // should be stripped
}).videoType
).toBeUndefined();

expect(
zodSchemas.models.VideoSchema.parse({
id: 1,
assetType: 'video',
videoType: 'ratedVideo',
duration: 100,
url: 'http://example.com',
createdAt: new Date(),
viewCount: 100,
})
).toBeTruthy();

expect(() =>
zodSchemas.models.VideoSchema.parse({
id: 1,
assetType: 'video',
videoType: 'ratedVideo',
url: 'http://example.com',
createdAt: new Date(),
viewCount: 100,
})
).toThrow('duration');

// create schema
expect(
zodSchemas.models.VideoCreateSchema.parse({
duration: 100,
url: 'http://example.com',
}).assetType // discriminator should not be set
).toBeUndefined();

// update schema
expect(
zodSchemas.models.VideoUpdateSchema.parse({
duration: 100,
url: 'http://example.com',
}).assetType // discriminator should not be set
).toBeUndefined();

// prisma create schema
expect(
zodSchemas.models.VideoPrismaCreateSchema.strip().parse({
assetType: 'video',
}).assetType // discriminator should not be set
).toBeUndefined();

// input object schema
expect(
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
duration: 100,
viewCount: 200,
url: 'http://www.example.com',
rating: 5,
})
).toBeTruthy();

expect(() =>
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
duration: 100,
viewCount: 200,
url: 'http://www.example.com',
rating: 5,
videoType: 'ratedVideo',
})
).toThrow('videoType');
});
});
Loading