Skip to content

Commit 1b7448f

Browse files
authored
feat: Add @encrypted enhancer (#1922)
1 parent a0e2b53 commit 1b7448f

File tree

6 files changed

+319
-3
lines changed

6 files changed

+319
-3
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../node/encrypted.ts

packages/runtime/src/enhancements/node/create-enhancement.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ import { withJsonProcessor } from './json-processor';
1414
import { Logger } from './logger';
1515
import { withOmit } from './omit';
1616
import { withPassword } from './password';
17+
import { withEncrypted } from './encrypted';
1718
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
1819
import type { PolicyDef } from './types';
1920

2021
/**
2122
* All enhancement kinds
2223
*/
23-
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate'];
24+
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted'];
2425

2526
/**
2627
* Options for {@link createEnhancement}
@@ -100,6 +101,7 @@ export function createEnhancement<DbClient extends object>(
100101
}
101102

102103
const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
104+
const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted'));
103105
const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
104106
const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider);
105107
const hasTypeDefField = allFields.some((field) => field.isTypeDef);
@@ -120,13 +122,22 @@ export function createEnhancement<DbClient extends object>(
120122
}
121123
}
122124

123-
// password enhancement must be applied prior to policy because it changes then length of the field
125+
// password and encrypted enhancement must be applied prior to policy because it changes then length of the field
124126
// and can break validation rules like `@length`
125127
if (hasPassword && kinds.includes('password')) {
126128
// @password proxy
127129
result = withPassword(result, options);
128130
}
129131

132+
if (hasEncrypted && kinds.includes('encrypted')) {
133+
if (!options.encryption) {
134+
throw new Error('Encryption options are required for @encrypted enhancement');
135+
}
136+
137+
// @encrypted proxy
138+
result = withEncrypted(result, options);
139+
}
140+
130141
// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
131142
if (kinds.includes('policy') || kinds.includes('validation')) {
132143
result = withPolicy(result, options, context);
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/* eslint-disable @typescript-eslint/no-explicit-any */
2+
/* eslint-disable @typescript-eslint/no-unused-vars */
3+
4+
import {
5+
FieldInfo,
6+
NestedWriteVisitor,
7+
enumerate,
8+
getModelFields,
9+
resolveField,
10+
type PrismaWriteActionType,
11+
} from '../../cross';
12+
import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types';
13+
import { InternalEnhancementOptions } from './create-enhancement';
14+
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
15+
import { QueryUtils } from './query-utils';
16+
17+
/**
18+
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
19+
*
20+
* @private
21+
*/
22+
export function withEncrypted<DbClient extends object = any>(
23+
prisma: DbClient,
24+
options: InternalEnhancementOptions
25+
): DbClient {
26+
return makeProxy(
27+
prisma,
28+
options.modelMeta,
29+
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
30+
'encrypted'
31+
);
32+
}
33+
34+
class EncryptedHandler extends DefaultPrismaProxyHandler {
35+
private queryUtils: QueryUtils;
36+
private encoder = new TextEncoder();
37+
private decoder = new TextDecoder();
38+
39+
constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
40+
super(prisma, model, options);
41+
42+
this.queryUtils = new QueryUtils(prisma, options);
43+
44+
if (!options.encryption) throw new Error('Encryption options must be provided');
45+
46+
if (this.isCustomEncryption(options.encryption!)) {
47+
if (!options.encryption.encrypt || !options.encryption.decrypt)
48+
throw new Error('Custom encryption must provide encrypt and decrypt functions');
49+
} else {
50+
if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided');
51+
if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes');
52+
}
53+
}
54+
55+
private async getKey(secret: Uint8Array): Promise<CryptoKey> {
56+
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
57+
}
58+
59+
private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
60+
return 'encrypt' in encryption && 'decrypt' in encryption;
61+
}
62+
63+
private async encrypt(field: FieldInfo, data: string): Promise<string> {
64+
if (this.isCustomEncryption(this.options.encryption!)) {
65+
return this.options.encryption.encrypt(this.model, field, data);
66+
}
67+
68+
const key = await this.getKey(this.options.encryption!.encryptionKey);
69+
const iv = crypto.getRandomValues(new Uint8Array(12));
70+
71+
const encrypted = await crypto.subtle.encrypt(
72+
{
73+
name: 'AES-GCM',
74+
iv,
75+
},
76+
key,
77+
this.encoder.encode(data)
78+
);
79+
80+
// Combine IV and encrypted data into a single array of bytes
81+
const bytes = [...iv, ...new Uint8Array(encrypted)];
82+
83+
// Convert bytes to base64 string
84+
return btoa(String.fromCharCode(...bytes));
85+
}
86+
87+
private async decrypt(field: FieldInfo, data: string): Promise<string> {
88+
if (this.isCustomEncryption(this.options.encryption!)) {
89+
return this.options.encryption.decrypt(this.model, field, data);
90+
}
91+
92+
const key = await this.getKey(this.options.encryption!.encryptionKey);
93+
94+
// Convert base64 back to bytes
95+
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
96+
97+
// First 12 bytes are IV, rest is encrypted data
98+
const decrypted = await crypto.subtle.decrypt(
99+
{
100+
name: 'AES-GCM',
101+
iv: bytes.slice(0, 12),
102+
},
103+
key,
104+
bytes.slice(12)
105+
);
106+
107+
return this.decoder.decode(decrypted);
108+
}
109+
110+
// base override
111+
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
112+
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
113+
if (args && args.data && actionsOfInterest.includes(action)) {
114+
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
115+
}
116+
return args;
117+
}
118+
119+
// base override
120+
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
121+
if (!data || typeof data !== 'object') {
122+
return data;
123+
}
124+
125+
for (const value of enumerate(data)) {
126+
await this.doPostProcess(value, this.model);
127+
}
128+
129+
return data;
130+
}
131+
132+
private async doPostProcess(entityData: any, model: string) {
133+
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);
134+
135+
for (const field of getModelFields(entityData)) {
136+
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);
137+
138+
if (!fieldInfo) {
139+
continue;
140+
}
141+
142+
const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
143+
if (shouldDecrypt) {
144+
// Don't decrypt null, undefined or empty string values
145+
if (!entityData[field]) continue;
146+
147+
try {
148+
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
149+
} catch (error) {
150+
console.warn('Decryption failed, keeping original value:', error);
151+
}
152+
}
153+
}
154+
}
155+
156+
private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
157+
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
158+
field: async (field, _action, data, context) => {
159+
// Don't encrypt null, undefined or empty string values
160+
if (!data) return;
161+
162+
const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
163+
if (encAttr && field.type === 'String') {
164+
try {
165+
context.parent[field.name] = await this.encrypt(field, data);
166+
} catch (error) {
167+
throw new Error(`Encryption failed for field ${field.name}: ${error}`);
168+
}
169+
}
170+
},
171+
});
172+
173+
await visitor.visit(model, action, args);
174+
}
175+
}

packages/runtime/src/types.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22

33
import type { z } from 'zod';
4+
import { FieldInfo } from './cross';
45

56
export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;
67

@@ -133,6 +134,11 @@ export type EnhancementOptions = {
133134
* The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack.
134135
*/
135136
transactionIsolationLevel?: TransactionIsolationLevel;
137+
138+
/**
139+
* The encryption options for using the `encrypted` enhancement.
140+
*/
141+
encryption?: SimpleEncryption | CustomEncryption;
136142
};
137143

138144
/**
@@ -145,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = {
145151
/**
146152
* Kinds of enhancements to `PrismaClient`
147153
*/
148-
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate';
154+
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted';
149155

150156
/**
151157
* Function for transforming errors.
@@ -166,3 +172,10 @@ export type ZodSchemas = {
166172
*/
167173
input?: Record<string, Record<string, z.ZodSchema>>;
168174
};
175+
176+
export type CustomEncryption = {
177+
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;
178+
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
179+
};
180+
181+
export type SimpleEncryption = { encryptionKey: Uint8Array };

packages/schema/src/res/stdlib.zmodel

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,14 @@ attribute @@auth() @@@supportTypeDef
575575
*/
576576
attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField])
577577

578+
579+
/**
580+
* Indicates that the field is encrypted when storing in the DB and should be decrypted when read
581+
*
582+
* ZenStack uses the Web Crypto API to encrypt and decrypt the field.
583+
*/
584+
attribute @encrypted() @@@targetField([StringField])
585+
578586
/**
579587
* Indicates that the field should be omitted when read from the generated services.
580588
*/
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import { FieldInfo } from '@zenstackhq/runtime';
2+
import { loadSchema } from '@zenstackhq/testtools';
3+
import path from 'path';
4+
5+
describe('Encrypted test', () => {
6+
let origDir: string;
7+
8+
beforeAll(async () => {
9+
origDir = path.resolve('.');
10+
});
11+
12+
afterEach(async () => {
13+
process.chdir(origDir);
14+
});
15+
16+
it('Simple encryption test', async () => {
17+
const { enhance } = await loadSchema(`
18+
model User {
19+
id String @id @default(cuid())
20+
encrypted_value String @encrypted()
21+
22+
@@allow('all', true)
23+
}`);
24+
25+
const sudoDb = enhance(undefined, { kinds: [] });
26+
const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64'));
27+
28+
const db = enhance(undefined, {
29+
kinds: ['encrypted'],
30+
encryption: { encryptionKey },
31+
});
32+
33+
const create = await db.user.create({
34+
data: {
35+
id: '1',
36+
encrypted_value: 'abc123',
37+
},
38+
});
39+
40+
const read = await db.user.findUnique({
41+
where: {
42+
id: '1',
43+
},
44+
});
45+
46+
const sudoRead = await sudoDb.user.findUnique({
47+
where: {
48+
id: '1',
49+
},
50+
});
51+
52+
expect(create.encrypted_value).toBe('abc123');
53+
expect(read.encrypted_value).toBe('abc123');
54+
expect(sudoRead.encrypted_value).not.toBe('abc123');
55+
});
56+
57+
it('Custom encryption test', async () => {
58+
const { enhance } = await loadSchema(`
59+
model User {
60+
id String @id @default(cuid())
61+
encrypted_value String @encrypted()
62+
63+
@@allow('all', true)
64+
}`);
65+
66+
const sudoDb = enhance(undefined, { kinds: [] });
67+
const db = enhance(undefined, {
68+
kinds: ['encrypted'],
69+
encryption: {
70+
encrypt: async (model: string, field: FieldInfo, data: string) => {
71+
// Add _enc to the end of the input
72+
return data + '_enc';
73+
},
74+
decrypt: async (model: string, field: FieldInfo, cipher: string) => {
75+
// Remove _enc from the end of the input explicitly
76+
if (cipher.endsWith('_enc')) {
77+
return cipher.slice(0, -4); // Remove last 4 characters (_enc)
78+
}
79+
80+
return cipher;
81+
},
82+
},
83+
});
84+
85+
const create = await db.user.create({
86+
data: {
87+
id: '1',
88+
encrypted_value: 'abc123',
89+
},
90+
});
91+
92+
const read = await db.user.findUnique({
93+
where: {
94+
id: '1',
95+
},
96+
});
97+
98+
const sudoRead = await sudoDb.user.findUnique({
99+
where: {
100+
id: '1',
101+
},
102+
});
103+
104+
expect(create.encrypted_value).toBe('abc123');
105+
expect(read.encrypted_value).toBe('abc123');
106+
expect(sudoRead.encrypted_value).toBe('abc123_enc');
107+
});
108+
});

0 commit comments

Comments
 (0)