Skip to content

Commit d9faaf3

Browse files
committed
Port RSA to rust
1 parent 543cf43 commit d9faaf3

File tree

12 files changed

+767
-838
lines changed

12 files changed

+767
-838
lines changed

src/cryptography/hazmat/backends/openssl/backend.py

Lines changed: 15 additions & 255 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
from cryptography.hazmat.backends.openssl import aead
1515
from cryptography.hazmat.backends.openssl.ciphers import _CipherContext
1616
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
17-
from cryptography.hazmat.backends.openssl.rsa import (
18-
_RSAPrivateKey,
19-
_RSAPublicKey,
20-
)
2117
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
2218
from cryptography.hazmat.bindings.openssl import binding
2319
from cryptography.hazmat.primitives import hashes, serialization
@@ -63,7 +59,6 @@
6359
XTS,
6460
Mode,
6561
)
66-
from cryptography.hazmat.primitives.serialization import ssh
6762
from cryptography.hazmat.primitives.serialization.pkcs12 import (
6863
PBES,
6964
PKCS12Certificate,
@@ -358,24 +353,7 @@ def generate_rsa_private_key(
358353
self, public_exponent: int, key_size: int
359354
) -> rsa.RSAPrivateKey:
360355
rsa._verify_rsa_parameters(public_exponent, key_size)
361-
362-
rsa_cdata = self._lib.RSA_new()
363-
self.openssl_assert(rsa_cdata != self._ffi.NULL)
364-
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
365-
366-
bn = self._int_to_bn(public_exponent)
367-
bn = self._ffi.gc(bn, self._lib.BN_free)
368-
369-
res = self._lib.RSA_generate_key_ex(
370-
rsa_cdata, key_size, bn, self._ffi.NULL
371-
)
372-
self.openssl_assert(res == 1)
373-
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
374-
375-
# We can skip RSA key validation here since we just generated the key
376-
return _RSAPrivateKey(
377-
self, rsa_cdata, evp_pkey, unsafe_skip_rsa_key_validation=True
378-
)
356+
return rust_openssl.rsa.generate_private_key(public_exponent, key_size)
379357

380358
def generate_rsa_parameters_supported(
381359
self, public_exponent: int, key_size: int
@@ -401,46 +379,15 @@ def load_rsa_private_numbers(
401379
numbers.public_numbers.e,
402380
numbers.public_numbers.n,
403381
)
404-
rsa_cdata = self._lib.RSA_new()
405-
self.openssl_assert(rsa_cdata != self._ffi.NULL)
406-
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
407-
p = self._int_to_bn(numbers.p)
408-
q = self._int_to_bn(numbers.q)
409-
d = self._int_to_bn(numbers.d)
410-
dmp1 = self._int_to_bn(numbers.dmp1)
411-
dmq1 = self._int_to_bn(numbers.dmq1)
412-
iqmp = self._int_to_bn(numbers.iqmp)
413-
e = self._int_to_bn(numbers.public_numbers.e)
414-
n = self._int_to_bn(numbers.public_numbers.n)
415-
res = self._lib.RSA_set0_factors(rsa_cdata, p, q)
416-
self.openssl_assert(res == 1)
417-
res = self._lib.RSA_set0_key(rsa_cdata, n, e, d)
418-
self.openssl_assert(res == 1)
419-
res = self._lib.RSA_set0_crt_params(rsa_cdata, dmp1, dmq1, iqmp)
420-
self.openssl_assert(res == 1)
421-
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
422-
423-
return _RSAPrivateKey(
424-
self,
425-
rsa_cdata,
426-
evp_pkey,
427-
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
382+
return rust_openssl.rsa.from_private_numbers(
383+
numbers, unsafe_skip_rsa_key_validation
428384
)
429385

430386
def load_rsa_public_numbers(
431387
self, numbers: rsa.RSAPublicNumbers
432388
) -> rsa.RSAPublicKey:
433389
rsa._check_public_key_components(numbers.e, numbers.n)
434-
rsa_cdata = self._lib.RSA_new()
435-
self.openssl_assert(rsa_cdata != self._ffi.NULL)
436-
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
437-
e = self._int_to_bn(numbers.e)
438-
n = self._int_to_bn(numbers.n)
439-
res = self._lib.RSA_set0_key(rsa_cdata, n, e, self._ffi.NULL)
440-
self.openssl_assert(res == 1)
441-
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
442-
443-
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
390+
return rust_openssl.rsa.from_public_numbers(numbers)
444391

445392
def _create_evp_pkey_gc(self):
446393
evp_pkey = self._lib.EVP_PKEY_new()
@@ -500,13 +447,8 @@ def _evp_pkey_to_private_key(
500447
key_type = self._lib.EVP_PKEY_id(evp_pkey)
501448

502449
if key_type == self._lib.EVP_PKEY_RSA:
503-
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
504-
self.openssl_assert(rsa_cdata != self._ffi.NULL)
505-
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
506-
return _RSAPrivateKey(
507-
self,
508-
rsa_cdata,
509-
evp_pkey,
450+
return rust_openssl.rsa.private_key_from_ptr(
451+
int(self._ffi.cast("uintptr_t", evp_pkey)),
510452
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
511453
)
512454
elif (
@@ -573,10 +515,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
573515
key_type = self._lib.EVP_PKEY_id(evp_pkey)
574516

575517
if key_type == self._lib.EVP_PKEY_RSA:
576-
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
577-
self.openssl_assert(rsa_cdata != self._ffi.NULL)
578-
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
579-
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
518+
return rust_openssl.rsa.public_key_from_ptr(
519+
int(self._ffi.cast("uintptr_t", evp_pkey))
520+
)
580521
elif (
581522
key_type == self._lib.EVP_PKEY_RSA_PSS
582523
and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL
@@ -733,7 +674,9 @@ def load_pem_public_key(self, data: bytes) -> PublicKeyTypes:
733674
if rsa_cdata != self._ffi.NULL:
734675
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
735676
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
736-
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
677+
return rust_openssl.rsa.public_key_from_ptr(
678+
int(self._ffi.cast("uintptr_t", evp_pkey))
679+
)
737680
else:
738681
self._handle_key_loading_error()
739682

@@ -796,7 +739,9 @@ def load_der_public_key(self, data: bytes) -> PublicKeyTypes:
796739
if rsa_cdata != self._ffi.NULL:
797740
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
798741
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
799-
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
742+
return rust_openssl.rsa.public_key_from_ptr(
743+
int(self._ffi.cast("uintptr_t", evp_pkey))
744+
)
800745
else:
801746
self._handle_key_loading_error()
802747

@@ -984,191 +929,6 @@ def elliptic_curve_exchange_algorithm_supported(
984929
algorithm, ec.ECDH
985930
)
986931

987-
def _private_key_bytes(
988-
self,
989-
encoding: serialization.Encoding,
990-
format: serialization.PrivateFormat,
991-
encryption_algorithm: serialization.KeySerializationEncryption,
992-
key,
993-
evp_pkey,
994-
cdata,
995-
) -> bytes:
996-
# validate argument types
997-
if not isinstance(encoding, serialization.Encoding):
998-
raise TypeError("encoding must be an item from the Encoding enum")
999-
if not isinstance(format, serialization.PrivateFormat):
1000-
raise TypeError(
1001-
"format must be an item from the PrivateFormat enum"
1002-
)
1003-
if not isinstance(
1004-
encryption_algorithm, serialization.KeySerializationEncryption
1005-
):
1006-
raise TypeError(
1007-
"Encryption algorithm must be a KeySerializationEncryption "
1008-
"instance"
1009-
)
1010-
1011-
# validate password
1012-
if isinstance(encryption_algorithm, serialization.NoEncryption):
1013-
password = b""
1014-
elif isinstance(
1015-
encryption_algorithm, serialization.BestAvailableEncryption
1016-
):
1017-
password = encryption_algorithm.password
1018-
if len(password) > 1023:
1019-
raise ValueError(
1020-
"Passwords longer than 1023 bytes are not supported by "
1021-
"this backend"
1022-
)
1023-
elif (
1024-
isinstance(
1025-
encryption_algorithm, serialization._KeySerializationEncryption
1026-
)
1027-
and encryption_algorithm._format
1028-
is format
1029-
is serialization.PrivateFormat.OpenSSH
1030-
):
1031-
password = encryption_algorithm.password
1032-
else:
1033-
raise ValueError("Unsupported encryption type")
1034-
1035-
# PKCS8 + PEM/DER
1036-
if format is serialization.PrivateFormat.PKCS8:
1037-
if encoding is serialization.Encoding.PEM:
1038-
write_bio = self._lib.PEM_write_bio_PKCS8PrivateKey
1039-
elif encoding is serialization.Encoding.DER:
1040-
write_bio = self._lib.i2d_PKCS8PrivateKey_bio
1041-
else:
1042-
raise ValueError("Unsupported encoding for PKCS8")
1043-
return self._private_key_bytes_via_bio(
1044-
write_bio, evp_pkey, password
1045-
)
1046-
1047-
# TraditionalOpenSSL + PEM/DER
1048-
if format is serialization.PrivateFormat.TraditionalOpenSSL:
1049-
if self._fips_enabled and not isinstance(
1050-
encryption_algorithm, serialization.NoEncryption
1051-
):
1052-
raise ValueError(
1053-
"Encrypted traditional OpenSSL format is not "
1054-
"supported in FIPS mode."
1055-
)
1056-
key_type = self._lib.EVP_PKEY_id(evp_pkey)
1057-
1058-
if encoding is serialization.Encoding.PEM:
1059-
assert key_type == self._lib.EVP_PKEY_RSA
1060-
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
1061-
return self._private_key_bytes_via_bio(
1062-
write_bio, cdata, password
1063-
)
1064-
1065-
if encoding is serialization.Encoding.DER:
1066-
if password:
1067-
raise ValueError(
1068-
"Encryption is not supported for DER encoded "
1069-
"traditional OpenSSL keys"
1070-
)
1071-
assert key_type == self._lib.EVP_PKEY_RSA
1072-
write_bio = self._lib.i2d_RSAPrivateKey_bio
1073-
return self._bio_func_output(write_bio, cdata)
1074-
1075-
raise ValueError("Unsupported encoding for TraditionalOpenSSL")
1076-
1077-
# OpenSSH + PEM
1078-
if format is serialization.PrivateFormat.OpenSSH:
1079-
if encoding is serialization.Encoding.PEM:
1080-
return ssh._serialize_ssh_private_key(
1081-
key, password, encryption_algorithm
1082-
)
1083-
1084-
raise ValueError(
1085-
"OpenSSH private key format can only be used"
1086-
" with PEM encoding"
1087-
)
1088-
1089-
# Anything that key-specific code was supposed to handle earlier,
1090-
# like Raw.
1091-
raise ValueError("format is invalid with this key")
1092-
1093-
def _private_key_bytes_via_bio(
1094-
self, write_bio, evp_pkey, password
1095-
) -> bytes:
1096-
if not password:
1097-
evp_cipher = self._ffi.NULL
1098-
else:
1099-
# This is a curated value that we will update over time.
1100-
evp_cipher = self._lib.EVP_get_cipherbyname(b"aes-256-cbc")
1101-
1102-
return self._bio_func_output(
1103-
write_bio,
1104-
evp_pkey,
1105-
evp_cipher,
1106-
password,
1107-
len(password),
1108-
self._ffi.NULL,
1109-
self._ffi.NULL,
1110-
)
1111-
1112-
def _bio_func_output(self, write_bio, *args) -> bytes:
1113-
bio = self._create_mem_bio_gc()
1114-
res = write_bio(bio, *args)
1115-
self.openssl_assert(res == 1)
1116-
return self._read_mem_bio(bio)
1117-
1118-
def _public_key_bytes(
1119-
self,
1120-
encoding: serialization.Encoding,
1121-
format: serialization.PublicFormat,
1122-
key,
1123-
evp_pkey,
1124-
cdata,
1125-
) -> bytes:
1126-
if not isinstance(encoding, serialization.Encoding):
1127-
raise TypeError("encoding must be an item from the Encoding enum")
1128-
if not isinstance(format, serialization.PublicFormat):
1129-
raise TypeError(
1130-
"format must be an item from the PublicFormat enum"
1131-
)
1132-
1133-
# SubjectPublicKeyInfo + PEM/DER
1134-
if format is serialization.PublicFormat.SubjectPublicKeyInfo:
1135-
if encoding is serialization.Encoding.PEM:
1136-
write_bio = self._lib.PEM_write_bio_PUBKEY
1137-
elif encoding is serialization.Encoding.DER:
1138-
write_bio = self._lib.i2d_PUBKEY_bio
1139-
else:
1140-
raise ValueError(
1141-
"SubjectPublicKeyInfo works only with PEM or DER encoding"
1142-
)
1143-
return self._bio_func_output(write_bio, evp_pkey)
1144-
1145-
# PKCS1 + PEM/DER
1146-
if format is serialization.PublicFormat.PKCS1:
1147-
# Only RSA is supported here.
1148-
key_type = self._lib.EVP_PKEY_id(evp_pkey)
1149-
self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA)
1150-
1151-
if encoding is serialization.Encoding.PEM:
1152-
write_bio = self._lib.PEM_write_bio_RSAPublicKey
1153-
elif encoding is serialization.Encoding.DER:
1154-
write_bio = self._lib.i2d_RSAPublicKey_bio
1155-
else:
1156-
raise ValueError("PKCS1 works only with PEM or DER encoding")
1157-
return self._bio_func_output(write_bio, cdata)
1158-
1159-
# OpenSSH + OpenSSH
1160-
if format is serialization.PublicFormat.OpenSSH:
1161-
if encoding is serialization.Encoding.OpenSSH:
1162-
return ssh.serialize_ssh_public_key(key)
1163-
1164-
raise ValueError(
1165-
"OpenSSH format must be used with OpenSSH encoding"
1166-
)
1167-
1168-
# Anything that key-specific code was supposed to handle earlier,
1169-
# like Raw, CompressedPoint, UncompressedPoint
1170-
raise ValueError("format is invalid with this key")
1171-
1172932
def dh_supported(self) -> bool:
1173933
return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL
1174934

0 commit comments

Comments
 (0)