10
10
*/
11
11
12
12
#include <linux/module.h>
13
+ #include <linux/mpi.h>
13
14
#include <crypto/internal/rsa.h>
14
15
#include <crypto/internal/akcipher.h>
15
16
#include <crypto/akcipher.h>
16
17
#include <crypto/algapi.h>
17
18
19
+ struct rsa_mpi_key {
20
+ MPI n ;
21
+ MPI e ;
22
+ MPI d ;
23
+ };
24
+
18
25
/*
19
26
* RSAEP function [RFC3447 sec 5.1.1]
20
27
* c = m^e mod n;
21
28
*/
22
- static int _rsa_enc (const struct rsa_key * key , MPI c , MPI m )
29
+ static int _rsa_enc (const struct rsa_mpi_key * key , MPI c , MPI m )
23
30
{
24
31
/* (1) Validate 0 <= m < n */
25
32
if (mpi_cmp_ui (m , 0 ) < 0 || mpi_cmp (m , key -> n ) >= 0 )
@@ -33,7 +40,7 @@ static int _rsa_enc(const struct rsa_key *key, MPI c, MPI m)
33
40
* RSADP function [RFC3447 sec 5.1.2]
34
41
* m = c^d mod n;
35
42
*/
36
- static int _rsa_dec (const struct rsa_key * key , MPI m , MPI c )
43
+ static int _rsa_dec (const struct rsa_mpi_key * key , MPI m , MPI c )
37
44
{
38
45
/* (1) Validate 0 <= c < n */
39
46
if (mpi_cmp_ui (c , 0 ) < 0 || mpi_cmp (c , key -> n ) >= 0 )
@@ -47,7 +54,7 @@ static int _rsa_dec(const struct rsa_key *key, MPI m, MPI c)
47
54
* RSASP1 function [RFC3447 sec 5.2.1]
48
55
* s = m^d mod n
49
56
*/
50
- static int _rsa_sign (const struct rsa_key * key , MPI s , MPI m )
57
+ static int _rsa_sign (const struct rsa_mpi_key * key , MPI s , MPI m )
51
58
{
52
59
/* (1) Validate 0 <= m < n */
53
60
if (mpi_cmp_ui (m , 0 ) < 0 || mpi_cmp (m , key -> n ) >= 0 )
@@ -61,7 +68,7 @@ static int _rsa_sign(const struct rsa_key *key, MPI s, MPI m)
61
68
* RSAVP1 function [RFC3447 sec 5.2.2]
62
69
* m = s^e mod n;
63
70
*/
64
- static int _rsa_verify (const struct rsa_key * key , MPI m , MPI s )
71
+ static int _rsa_verify (const struct rsa_mpi_key * key , MPI m , MPI s )
65
72
{
66
73
/* (1) Validate 0 <= s < n */
67
74
if (mpi_cmp_ui (s , 0 ) < 0 || mpi_cmp (s , key -> n ) >= 0 )
@@ -71,15 +78,15 @@ static int _rsa_verify(const struct rsa_key *key, MPI m, MPI s)
71
78
return mpi_powm (m , s , key -> e , key -> n );
72
79
}
73
80
74
- static inline struct rsa_key * rsa_get_key (struct crypto_akcipher * tfm )
81
+ static inline struct rsa_mpi_key * rsa_get_key (struct crypto_akcipher * tfm )
75
82
{
76
83
return akcipher_tfm_ctx (tfm );
77
84
}
78
85
79
86
static int rsa_enc (struct akcipher_request * req )
80
87
{
81
88
struct crypto_akcipher * tfm = crypto_akcipher_reqtfm (req );
82
- const struct rsa_key * pkey = rsa_get_key (tfm );
89
+ const struct rsa_mpi_key * pkey = rsa_get_key (tfm );
83
90
MPI m , c = mpi_alloc (0 );
84
91
int ret = 0 ;
85
92
int sign ;
@@ -118,7 +125,7 @@ static int rsa_enc(struct akcipher_request *req)
118
125
static int rsa_dec (struct akcipher_request * req )
119
126
{
120
127
struct crypto_akcipher * tfm = crypto_akcipher_reqtfm (req );
121
- const struct rsa_key * pkey = rsa_get_key (tfm );
128
+ const struct rsa_mpi_key * pkey = rsa_get_key (tfm );
122
129
MPI c , m = mpi_alloc (0 );
123
130
int ret = 0 ;
124
131
int sign ;
@@ -156,7 +163,7 @@ static int rsa_dec(struct akcipher_request *req)
156
163
static int rsa_sign (struct akcipher_request * req )
157
164
{
158
165
struct crypto_akcipher * tfm = crypto_akcipher_reqtfm (req );
159
- const struct rsa_key * pkey = rsa_get_key (tfm );
166
+ const struct rsa_mpi_key * pkey = rsa_get_key (tfm );
160
167
MPI m , s = mpi_alloc (0 );
161
168
int ret = 0 ;
162
169
int sign ;
@@ -195,7 +202,7 @@ static int rsa_sign(struct akcipher_request *req)
195
202
static int rsa_verify (struct akcipher_request * req )
196
203
{
197
204
struct crypto_akcipher * tfm = crypto_akcipher_reqtfm (req );
198
- const struct rsa_key * pkey = rsa_get_key (tfm );
205
+ const struct rsa_mpi_key * pkey = rsa_get_key (tfm );
199
206
MPI s , m = mpi_alloc (0 );
200
207
int ret = 0 ;
201
208
int sign ;
@@ -233,6 +240,16 @@ static int rsa_verify(struct akcipher_request *req)
233
240
return ret ;
234
241
}
235
242
243
+ static void rsa_free_mpi_key (struct rsa_mpi_key * key )
244
+ {
245
+ mpi_free (key -> d );
246
+ mpi_free (key -> e );
247
+ mpi_free (key -> n );
248
+ key -> d = NULL ;
249
+ key -> e = NULL ;
250
+ key -> n = NULL ;
251
+ }
252
+
236
253
static int rsa_check_key_length (unsigned int len )
237
254
{
238
255
switch (len ) {
@@ -251,49 +268,87 @@ static int rsa_check_key_length(unsigned int len)
251
268
static int rsa_set_pub_key (struct crypto_akcipher * tfm , const void * key ,
252
269
unsigned int keylen )
253
270
{
254
- struct rsa_key * pkey = akcipher_tfm_ctx (tfm );
271
+ struct rsa_mpi_key * mpi_key = akcipher_tfm_ctx (tfm );
272
+ struct rsa_key raw_key = {0 };
255
273
int ret ;
256
274
257
- ret = rsa_parse_pub_key (pkey , key , keylen );
275
+ /* Free the old MPI key if any */
276
+ rsa_free_mpi_key (mpi_key );
277
+
278
+ ret = rsa_parse_pub_key (& raw_key , key , keylen );
258
279
if (ret )
259
280
return ret ;
260
281
261
- if (rsa_check_key_length (mpi_get_size (pkey -> n ) << 3 )) {
262
- rsa_free_key (pkey );
263
- ret = - EINVAL ;
282
+ mpi_key -> e = mpi_read_raw_data (raw_key .e , raw_key .e_sz );
283
+ if (!mpi_key -> e )
284
+ goto err ;
285
+
286
+ mpi_key -> n = mpi_read_raw_data (raw_key .n , raw_key .n_sz );
287
+ if (!mpi_key -> n )
288
+ goto err ;
289
+
290
+ if (rsa_check_key_length (mpi_get_size (mpi_key -> n ) << 3 )) {
291
+ rsa_free_mpi_key (mpi_key );
292
+ return - EINVAL ;
264
293
}
265
- return ret ;
294
+
295
+ return 0 ;
296
+
297
+ err :
298
+ rsa_free_mpi_key (mpi_key );
299
+ return - ENOMEM ;
266
300
}
267
301
268
302
static int rsa_set_priv_key (struct crypto_akcipher * tfm , const void * key ,
269
303
unsigned int keylen )
270
304
{
271
- struct rsa_key * pkey = akcipher_tfm_ctx (tfm );
305
+ struct rsa_mpi_key * mpi_key = akcipher_tfm_ctx (tfm );
306
+ struct rsa_key raw_key = {0 };
272
307
int ret ;
273
308
274
- ret = rsa_parse_priv_key (pkey , key , keylen );
309
+ /* Free the old MPI key if any */
310
+ rsa_free_mpi_key (mpi_key );
311
+
312
+ ret = rsa_parse_priv_key (& raw_key , key , keylen );
275
313
if (ret )
276
314
return ret ;
277
315
278
- if (rsa_check_key_length (mpi_get_size (pkey -> n ) << 3 )) {
279
- rsa_free_key (pkey );
280
- ret = - EINVAL ;
316
+ mpi_key -> d = mpi_read_raw_data (raw_key .d , raw_key .d_sz );
317
+ if (!mpi_key -> d )
318
+ goto err ;
319
+
320
+ mpi_key -> e = mpi_read_raw_data (raw_key .e , raw_key .e_sz );
321
+ if (!mpi_key -> e )
322
+ goto err ;
323
+
324
+ mpi_key -> n = mpi_read_raw_data (raw_key .n , raw_key .n_sz );
325
+ if (!mpi_key -> n )
326
+ goto err ;
327
+
328
+ if (rsa_check_key_length (mpi_get_size (mpi_key -> n ) << 3 )) {
329
+ rsa_free_mpi_key (mpi_key );
330
+ return - EINVAL ;
281
331
}
282
- return ret ;
332
+
333
+ return 0 ;
334
+
335
+ err :
336
+ rsa_free_mpi_key (mpi_key );
337
+ return - ENOMEM ;
283
338
}
284
339
285
340
static int rsa_max_size (struct crypto_akcipher * tfm )
286
341
{
287
- struct rsa_key * pkey = akcipher_tfm_ctx (tfm );
342
+ struct rsa_mpi_key * pkey = akcipher_tfm_ctx (tfm );
288
343
289
344
return pkey -> n ? mpi_get_size (pkey -> n ) : - EINVAL ;
290
345
}
291
346
292
347
static void rsa_exit_tfm (struct crypto_akcipher * tfm )
293
348
{
294
- struct rsa_key * pkey = akcipher_tfm_ctx (tfm );
349
+ struct rsa_mpi_key * pkey = akcipher_tfm_ctx (tfm );
295
350
296
- rsa_free_key (pkey );
351
+ rsa_free_mpi_key (pkey );
297
352
}
298
353
299
354
static struct akcipher_alg rsa = {
@@ -310,7 +365,7 @@ static struct akcipher_alg rsa = {
310
365
.cra_driver_name = "rsa-generic" ,
311
366
.cra_priority = 100 ,
312
367
.cra_module = THIS_MODULE ,
313
- .cra_ctxsize = sizeof (struct rsa_key ),
368
+ .cra_ctxsize = sizeof (struct rsa_mpi_key ),
314
369
},
315
370
};
316
371
0 commit comments