Skip to content

fix: several issues with using auth() in @default #1088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ export type FieldInfo = {
isForeignKey?: boolean;

/**
* Mapping from foreign key field names to relation field names
* If the field is a foreign key field, the field name of the corresponding relation field.
* Only available on foreign key fields.
*/
relationField?: string;

/**
* Mapping from foreign key field names to relation field names.
* Only available on relation fields.
*/
foreignKeyMapping?: Record<string, string>;

Expand Down
46 changes: 44 additions & 2 deletions packages/runtime/src/enhancements/default-auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import deepcopy from 'deepcopy';
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross';
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields, requireField } from '../cross';
import { DbClientContract } from '../types';
import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { isUnsafeMutate } from './utils';

/**
* Gets an enhanced Prisma client that supports `@default(auth())` attribute.
Expand Down Expand Up @@ -68,7 +69,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo);
if (authDefaultValue !== undefined) {
// set field value extracted from `auth()`
data[fieldInfo.name] = authDefaultValue;
this.setAuthDefaultValue(fieldInfo, model, data, authDefaultValue);
}
}
};
Expand All @@ -90,6 +91,47 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
return newArgs;
}

private setAuthDefaultValue(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) {
if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) {
// if the field is a fk, and the create payload is not unsafe, we need to translate
// the fk field setting to a `connect` of the corresponding relation field
const relFieldName = fieldInfo.relationField;
if (!relFieldName) {
throw new Error(
`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`
);
}
const relationField = requireField(this.options.modelMeta, model, relFieldName);

// construct a `{ connect: { ... } }` payload
let connect = data[relationField.name]?.connect;
if (!connect) {
connect = {};
data[relationField.name] = { connect };
}

// sets the opposite fk field to value `authDefaultValue`
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo);
if (!oppositeFkFieldName) {
throw new Error(
`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``
);
}
connect[oppositeFkFieldName] = authDefaultValue;
} else {
// set default value directly
data[fieldInfo.name] = authDefaultValue;
}
}

private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) {
if (!relationField.foreignKeyMapping) {
return undefined;
}
const entry = Object.entries(relationField.foreignKeyMapping).find(([, v]) => v === fieldInfo.name);
return entry?.[0];
}

private getDefaultValueFromAuth(fieldInfo: FieldInfo) {
if (!this.userContext) {
throw new Error(`Evaluating default value of field \`${fieldInfo.name}\` requires a user context`);
Expand Down
21 changes: 2 additions & 19 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import type { EnhancementContext, InternalEnhancementOptions } from '../create-e
import { Logger } from '../logger';
import { PrismaProxyHandler } from '../proxy';
import { QueryUtils } from '../query-utils';
import { formatObject, prismaClientValidationError } from '../utils';
import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
import { PolicyUtil } from './policy-utils';
import { createDeferredPromise } from './promise';

Expand Down Expand Up @@ -691,7 +691,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// operations. E.g.:
// - safe: { data: { user: { connect: { id: 1 }} } }
// - unsafe: { data: { userId: 1 } }
const unsafe = this.isUnsafeMutate(model, args);
const unsafe = isUnsafeMutate(model, args, this.modelMeta);

// handles the connection to upstream entity
const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe);
Expand Down Expand Up @@ -1083,23 +1083,6 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

private isUnsafeMutate(model: string, args: any) {
if (!args) {
return false;
}
for (const k of Object.keys(args)) {
const field = resolveField(this.modelMeta, model, k);
if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) {
return true;
}
}
return false;
}

private isAutoIncrementIdField(field: FieldInfo) {
return field.isId && field.isAutoIncrement;
}

async updateMany(args: any) {
if (!args) {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required');
Expand Down
19 changes: 19 additions & 0 deletions packages/runtime/src/enhancements/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as util from 'util';
import { FieldInfo, ModelMeta, resolveField } from '..';
import type { DbClientContract } from '../types';

/**
Expand All @@ -22,3 +23,21 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo
export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error {
throw new prismaModule.PrismaClientUnknownRequestError(...args);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta) {
if (!args) {
return false;
}
for (const k of Object.keys(args)) {
const field = resolveField(modelMeta, model, k);
if (field && (isAutoIncrementIdField(field) || field.isForeignKey)) {
return true;
}
}
return false;
}

export function isAutoIncrementIdField(field: FieldInfo) {
return field.isId && field.isAutoIncrement;
}
26 changes: 22 additions & 4 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { name } from '..';
import { execPackage } from '../../../utils/exec-utils';
import { trackPrismaSchemaError } from '../../prisma';
import { PrismaSchemaGenerator } from '../../prisma/schema-generator';
import { isDefaultWithAuth } from '../enhancer-utils';

// information of delegate models and their sub models
type DelegateInfo = [DataModel, DataModel[]][];
Comment on lines 27 to 33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [48-48]

Detected multiple instances where user input might be indirectly influencing file paths through variables like outDir. Ensure that outDir and similar variables are sanitized or validated to prevent path traversal vulnerabilities.

- path.join(outDir, 'prisma.d.ts')
+ path.join(sanitizePath(outDir), 'prisma.d.ts')

Note: Apply similar changes to all instances where outDir is used in path operations.

Also applies to: 56-56, 64-64, 110-110, 136-136, 148-148, 163-163


📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [329-331]

The use of dynamically constructed regular expressions in removeCreateFromDelegateInput and removeDiscriminatorFromConcreteInput functions could lead to Regular Expression Denial-of-Service (ReDoS) vulnerabilities if the input is controlled by an attacker. Consider using hardcoded regexes or performing input validation to mitigate this risk.

- new RegExp(`\\${delegateModelNames.join('|')}(Unchecked)?(Create|Update).*Input`)
+ // Consider replacing with hardcoded regexes or validating `delegateModelNames` to ensure they do not form a vulnerable regex pattern.

Also applies to: 352-354

Expand All @@ -35,7 +36,7 @@ export async function generate(model: Model, options: PluginOptions, project: Pr
let logicalPrismaClientDir: string | undefined;
let dmmf: DMMF.Document | undefined;

if (hasDelegateModel(model)) {
if (needsLogicalClient(model)) {
// schema contains delegate models, need to generate a logical prisma schema
const result = await generateLogicalPrisma(model, options, outDir);

Expand Down Expand Up @@ -86,13 +87,23 @@ export function enhance<DbClient extends object>(prisma: DbClient, context?: Enh
return { dmmf };
}

function needsLogicalClient(model: Model) {
return hasDelegateModel(model) || hasAuthInDefault(model);
}

function hasDelegateModel(model: Model) {
const dataModels = getDataModels(model);
return dataModels.some(
(dm) => isDelegateModel(dm) && dataModels.some((sub) => sub.superTypes.some((base) => base.ref === dm))
);
}

function hasAuthInDefault(model: Model) {
return getDataModels(model).some((dm) =>
dm.fields.some((f) => f.attributes.some((attr) => isDefaultWithAuth(attr)))
);
}

async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) {
const prismaGenerator = new PrismaSchemaGenerator(model);
const prismaClientOutDir = './.logical-prisma-client';
Expand Down Expand Up @@ -152,12 +163,19 @@ async function processClientTypes(model: Model, prismaClientDir: string) {
const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, {
overwrite: true,
});
transform(sf, sfNew, delegateInfo);
sfNew.formatText();

if (delegateInfo.length > 0) {
// transform types for delegated models
transformDelegate(sf, sfNew, delegateInfo);
sfNew.formatText();
} else {
// just copy
sfNew.replaceWithText(sf.getFullText());
}
await sfNew.save();
}

function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
function transformDelegate(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
// copy toplevel imports
sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure()));

Expand Down
20 changes: 20 additions & 0 deletions packages/schema/src/plugins/enhancer/enhancer-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { isAuthInvocation } from '@zenstackhq/sdk';
import type { DataModelFieldAttribute } from '@zenstackhq/sdk/ast';
import { streamAst } from 'langium';

/**
* Check if the given field attribute is a `@default` with `auth()` invocation
*/
export function isDefaultWithAuth(attr: DataModelFieldAttribute) {
if (attr.decl.ref?.name !== '@default') {
return false;
}

const expr = attr.args[0]?.value;
if (!expr) {
return false;
}

// find `auth()` in default value expression
return streamAst(expr).some(isAuthInvocation);
}
38 changes: 21 additions & 17 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,27 @@ import { getIdFields } from '../../utils/ast-utils';
import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
import {
getAttribute,
getForeignKeyFields,
getLiteral,
getPrismaVersion,
isAuthInvocation,
isDelegateModel,
isIdField,
isRelationshipField,
PluginError,
PluginOptions,
resolved,
ZModelCodeGenerator,
} from '@zenstackhq/sdk';
import fs from 'fs';
import { writeFile } from 'fs/promises';
import { streamAst } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import semver from 'semver';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
AttributeArgValue,
ModelFieldType,
Expand Down Expand Up @@ -494,10 +495,27 @@ export class PrismaSchemaGenerator {

const type = new ModelFieldType(fieldType, field.type.array, field.type.optional);

if (this.mode === 'logical') {
if (field.attributes.some((attr) => isDefaultWithAuth(attr))) {
// field has `@default` with `auth()`, it should be set optional, and the
// default value setting is handled outside Prisma
type.optional = true;
}

if (isRelationshipField(field)) {
// if foreign key field has `@default` with `auth()`, the relation
// field should be set optional
const foreignKeyFields = getForeignKeyFields(field);
if (foreignKeyFields.some((fkField) => fkField.attributes.some((attr) => isDefaultWithAuth(attr)))) {
type.optional = true;
}
}
}

const attributes = field.attributes
.filter((attr) => this.isPrismaAttribute(attr))
// `@default` with `auth()` is handled outside Prisma
.filter((attr) => !this.isDefaultWithAuth(attr))
.filter((attr) => !isDefaultWithAuth(attr))
.filter(
(attr) =>
// when building physical schema, exclude `@default` for id fields inherited from delegate base
Expand All @@ -524,20 +542,6 @@ export class PrismaSchemaGenerator {
return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom);
}

private isDefaultWithAuth(attr: DataModelFieldAttribute) {
if (attr.decl.ref?.name !== '@default') {
return false;
}

const expr = attr.args[0]?.value;
if (!expr) {
return false;
}

// find `auth()` in default value expression
return streamAst(expr).some(isAuthInvocation);
}

private makeFieldAttribute(attr: DataModelFieldAttribute) {
const attrName = resolved(attr.decl).name;
if (attrName === FIELD_PASSTHROUGH_ATTR) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ describe('Attribute tests', () => {
`);

await loadModel(`
${ prelude }
${prelude}
model A {
id String @id
x String
Expand Down Expand Up @@ -1051,21 +1051,6 @@ describe('Attribute tests', () => {
}
`);

// expect(
// await loadModelWithError(`
// ${prelude}

// model User {
// id String @id
// name String
// }
// model B {
// id String @id
// userData String @default(auth())
// }
// `)
// ).toContain("Value is not assignable to parameter");

expect(
await loadModelWithError(`
${prelude}
Expand Down Expand Up @@ -1185,15 +1170,6 @@ describe('Attribute tests', () => {
});

it('incorrect function expression context', async () => {
// expect(
// await loadModelWithError(`
// ${prelude}
// model M {
// id String @id @default(auth())
// }
// `)
// ).toContain('function "auth" is not allowed in the current context: DefaultValue');

expect(
await loadModelWithError(`
${prelude}
Expand Down
7 changes: 6 additions & 1 deletion packages/sdk/src/model-meta-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
isIdField,
resolved,
TypeScriptExpressionTransformer,
getRelationField,
} from '.';

/**
Expand Down Expand Up @@ -247,6 +248,11 @@ function writeFields(
if (isForeignKeyField(f)) {
writer.write(`
isForeignKey: true,`);
const relationField = getRelationField(f);
if (relationField) {
writer.write(`
relationField: '${relationField.name}',`);
}
}

if (fkMapping && Object.keys(fkMapping).length > 0) {
Expand Down Expand Up @@ -408,7 +414,6 @@ function generateForeignKeyMapping(field: DataModelField) {
const fieldNames = fields.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));
const referenceNames = references.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const result: Record<string, string> = {};
referenceNames.forEach((name, i) => {
if (name) {
Expand Down
Loading