Skip to content

fix(runtime): intercepts $extends to reproxy its result to make sure enhancements persist #1847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/runtime/src/enhancements/node/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,25 @@ export function makeProxy<T extends PrismaProxyHandler>(
}
}

if (prop === '$extends') {
// Prisma's `$extends` API returns a new client instance, we need to recreate
// a proxy around it
const $extends = Reflect.get(target, prop, receiver);
if ($extends && typeof $extends === 'function') {
return (...args: any[]) => {
const result = $extends.bind(target)(...args);
if (!result[PRISMA_PROXY_ENHANCER]) {
return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer);
} else {
// avoid double wrapping
return result;
}
};
} else {
return $extends;
}
}

if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) {
// skip non-model fields
return Reflect.get(target, prop, receiver);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.getAll()).resolves.toHaveLength(2);

// FIXME: extending an enhanced client doesn't work for this case
// const db1 = enhance(prisma).$extends(ext);
// await expect(db1.model.getAll()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
});

it('one model new method', async () => {
Expand Down Expand Up @@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.getAll()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
});

it('add client method', async () => {
Expand Down Expand Up @@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
xprisma.$log('abc');
enhanceRaw(prisma).$extends(ext).$log('abc');
expect(logged).toBeTruthy();

logged = false;
enhanceRaw(prisma.$extends(ext)).$log('abc');
expect(logged).toBeTruthy();
});

Expand All @@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => {
query: {
model: {
async findMany({ args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
return query(args);
},
Expand All @@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override all models', async () => {
Expand All @@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => {
query: {
$allModels: {
async findMany({ args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log('findMany args:', args);
return query(args);
Expand All @@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override all operations', async () => {
Expand All @@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => {
query: {
model: {
async $allOperations({ operation, args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log(`${operation} args:`, args);
return query(args);
Expand All @@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override everything', async () => {
Expand All @@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => {
name: 'prisma-extension-queryOverride',
query: {
async $allOperations({ operation, args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log(`${operation} args:`, args);
return query(args);
Expand All @@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('result mutation', async () => {
Expand Down Expand Up @@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
const r = await db.model.findMany();
expect(r).toHaveLength(1);
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })]));
const expected = [expect.objectContaining({ value: 2 })];
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
});

it('result custom fields', async () => {
Expand Down Expand Up @@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
const r = await db.model.findMany();
expect(r).toHaveLength(1);
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })]));
const expected = [expect.objectContaining({ doubleValue: 2 })];
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
});
});
90 changes: 90 additions & 0 deletions tests/regression/tests/issue-1859.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1859', () => {
it('extend enhanced client', async () => {
const { enhance, prisma } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean

@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const db = enhance();
await expect(db.post.findMany()).resolves.toHaveLength(1);

const extended = db.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await db.post.findMany(args) };
},
},
},
});

await expect(extended.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [{ id: 1, title: 'post1', published: true }],
});
await expect(extended.post.findMany()).resolves.toHaveLength(1);
});

it('enhance extended client', async () => {
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean

@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const ext = prismaModule.defineExtension((_prisma: any) => {
return _prisma.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await prisma.post.findMany(args) };
},
},
},
});
});

await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});

const enhanced = enhanceRaw(prisma.$extends(ext));
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
// findManyListView internally uses the un-enhanced client
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});
});
});
90 changes: 90 additions & 0 deletions tests/regression/tests/issue-prisma-extension.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue prisma extension', () => {
it('extend enhanced client', async () => {
const { enhance, prisma } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean

@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const db = enhance();
await expect(db.post.findMany()).resolves.toHaveLength(1);

const extended = db.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await db.post.findMany(args) };
},
},
},
});

await expect(extended.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [{ id: 1, title: 'post1', published: true }],
});
await expect(extended.post.findMany()).resolves.toHaveLength(1);
});

it('enhance extended client', async () => {
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean

@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const ext = prismaModule.defineExtension((_prisma: any) => {
return _prisma.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await prisma.post.findMany(args) };
},
},
},
});
});

await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});

const enhanced = enhanceRaw(prisma.$extends(ext));
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
// findManyListView internally uses the un-enhanced client
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});
});
});
Loading