diff --git a/src/libraries/Common/src/Interop/Android/System.Security.Cryptography.Native.Android/Interop.Ssl.cs b/src/libraries/Common/src/Interop/Android/System.Security.Cryptography.Native.Android/Interop.Ssl.cs index 133dc3ec445d13..b9aed3973f845a 100644 --- a/src/libraries/Common/src/Interop/Android/System.Security.Cryptography.Native.Android/Interop.Ssl.cs +++ b/src/libraries/Common/src/Interop/Android/System.Security.Cryptography.Native.Android/Interop.Ssl.cs @@ -2,13 +2,134 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; + using Microsoft.Win32.SafeHandles; +using SafeSslHandle = System.Net.SafeSslHandle; + internal static partial class Interop { internal static partial class AndroidCrypto { + internal unsafe delegate PAL_SSLStreamStatus SSLReadCallback(byte* data, int* length); + internal unsafe delegate void SSLWriteCallback(byte* data, int length); + + internal enum PAL_SSLStreamStatus + { + OK = 0, + NeedData = 1, + Error = 2, + Renegotiate = 3, + Closed = 4, + }; + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamCreate")] + internal static extern SafeSslHandle SSLStreamCreate(); + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamCreateWithCertificates")] + private static extern SafeSslHandle SSLStreamCreateWithCertificates( + ref byte pkcs8PrivateKey, + int pkcs8PrivateKeyLen, + PAL_KeyAlgorithm algorithm, + IntPtr[] certs, + int certsLen); + internal static SafeSslHandle SSLStreamCreateWithCertificates(ReadOnlySpan pkcs8PrivateKey, PAL_KeyAlgorithm algorithm, IntPtr[] certificates) + { + return SSLStreamCreateWithCertificates( + ref MemoryMarshal.GetReference(pkcs8PrivateKey), + pkcs8PrivateKey.Length, + algorithm, + certificates, + certificates.Length); + } + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamInitialize")] + private static extern int SSLStreamInitializeImpl( + SafeSslHandle sslHandle, + [MarshalAs(UnmanagedType.U1)] bool isServer, + SSLReadCallback streamRead, + SSLWriteCallback streamWrite, + int appBufferSize); + internal static void SSLStreamInitialize( + SafeSslHandle sslHandle, + bool isServer, + SSLReadCallback streamRead, + SSLWriteCallback streamWrite, + int appBufferSize) + { + int ret = SSLStreamInitializeImpl(sslHandle, isServer, streamRead, streamWrite, appBufferSize); + if (ret != SUCCESS) + throw new SslException(); + } + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamConfigureParameters")] + private static extern int SSLStreamConfigureParametersImpl( + SafeSslHandle sslHandle, + [MarshalAs(UnmanagedType.LPUTF8Str)] string targetHost); + internal static void SSLStreamConfigureParameters( + SafeSslHandle sslHandle, + string targetHost) + { + int ret = SSLStreamConfigureParametersImpl(sslHandle, targetHost); + if (ret != SUCCESS) + throw new SslException(); + } + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamHandshake")] + internal static extern PAL_SSLStreamStatus SSLStreamHandshake(SafeSslHandle sslHandle); + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamGetApplicationProtocol")] + private static extern int SSLStreamGetApplicationProtocol(SafeSslHandle ssl, [Out] byte[]? buf, ref int len); + internal static byte[]? SSLStreamGetApplicationProtocol(SafeSslHandle ssl) + { + int len = 0; + int ret = SSLStreamGetApplicationProtocol(ssl, null, ref len); + if (ret != INSUFFICIENT_BUFFER) + return null; + + byte[] bytes = new byte[len]; + ret = SSLStreamGetApplicationProtocol(ssl, bytes, ref len); + if (ret != SUCCESS) + return null; + + return bytes; + } + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamRead")] + private static unsafe extern PAL_SSLStreamStatus SSLStreamRead( + SafeSslHandle sslHandle, + byte* buffer, + int length, + out int bytesRead); + internal static unsafe PAL_SSLStreamStatus SSLStreamRead( + SafeSslHandle sslHandle, + Span buffer, + out int bytesRead) + { + fixed (byte* bufferPtr = buffer) + { + return SSLStreamRead(sslHandle, bufferPtr, buffer.Length, out bytesRead); + } + } + + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamWrite")] + private static unsafe extern PAL_SSLStreamStatus SSLStreamWrite( + SafeSslHandle sslHandle, + byte* buffer, + int length); + internal static unsafe PAL_SSLStreamStatus SSLStreamWrite( + SafeSslHandle sslHandle, + ReadOnlyMemory buffer) + { + using (MemoryHandle memHandle = buffer.Pin()) + { + return SSLStreamWrite(sslHandle, (byte*)memHandle.Pointer, buffer.Length); + } + } + [DllImport(Interop.Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamRelease")] internal static extern void SSLStreamRelease(IntPtr ptr); @@ -23,6 +144,78 @@ internal SslException(int errorCode) HResult = errorCode; } } + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamGetProtocol")] + private static extern int SSLStreamGetProtocol(SafeSslHandle ssl, out IntPtr protocol); + internal static string SSLStreamGetProtocol(SafeSslHandle ssl) + { + IntPtr protocolPtr; + int ret = SSLStreamGetProtocol(ssl, out protocolPtr); + if (ret != SUCCESS) + throw new SslException(); + + if (protocolPtr == IntPtr.Zero) + return string.Empty; + + string protocol = Marshal.PtrToStringUni(protocolPtr)!; + Marshal.FreeHGlobal(protocolPtr); + return protocol; + } + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamGetPeerCertificate")] + private static extern int SSLStreamGetPeerCertificate(SafeSslHandle ssl, out SafeX509Handle cert); + internal static SafeX509Handle SSLStreamGetPeerCertificate(SafeSslHandle ssl) + { + SafeX509Handle cert; + int ret = Interop.AndroidCrypto.SSLStreamGetPeerCertificate(ssl, out cert); + if (ret != SUCCESS) + throw new SslException(); + + return cert; + } + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamGetPeerCertificates")] + private static extern int SSLStreamGetPeerCertificates( + SafeSslHandle ssl, + [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 2)] out IntPtr[] certs, + out int count); + internal static IntPtr[] SSLStreamGetPeerCertificates(SafeSslHandle ssl) + { + IntPtr[] ptrs; + int count; + int ret = Interop.AndroidCrypto.SSLStreamGetPeerCertificates(ssl, out ptrs, out count); + if (ret != SUCCESS) + throw new SslException(); + + return ptrs; + } + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamGetCipherSuite")] + private static extern int SSLStreamGetCipherSuite(SafeSslHandle ssl, out IntPtr cipherSuite); + internal static string SSLStreamGetCipherSuite(SafeSslHandle ssl) + { + IntPtr cipherSuitePtr; + int ret = SSLStreamGetCipherSuite(ssl, out cipherSuitePtr); + if (ret != SUCCESS) + throw new SslException(); + + if (cipherSuitePtr == IntPtr.Zero) + return string.Empty; + + string cipherSuite = Marshal.PtrToStringUni(cipherSuitePtr)!; + Marshal.FreeHGlobal(cipherSuitePtr); + return cipherSuite; + } + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamShutdown")] + [return: MarshalAs(UnmanagedType.U1)] + internal static extern bool SSLStreamShutdown(SafeSslHandle ssl); + + [DllImport(Libraries.CryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamVerifyHostname")] + [return: MarshalAs(UnmanagedType.U1)] + internal static extern bool SSLStreamVerifyHostname( + SafeSslHandle ssl, + [MarshalAs(UnmanagedType.LPUTF8Str)] string hostname); } } diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.c b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.c index f5f6819e9afbfd..dea66dc52d4f1d 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.c +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.c @@ -14,6 +14,10 @@ jmethodID g_ByteArrayInputStreamReset; jclass g_Enum; jmethodID g_EnumOrdinal; +// java/lang/String +jclass g_String; +jmethodID g_StringGetBytes; + // java/lang/Throwable jclass g_ThrowableClass; jmethodID g_ThrowableGetCause; @@ -73,8 +77,10 @@ jmethodID g_bitLengthMethod; jmethodID g_sigNumMethod; // javax/net/ssl/SSLParameters -jclass g_sslParamsClass; -jmethodID g_sslParamsGetProtocolsMethod; +jclass g_SSLParametersClass; +jmethodID g_SSLParametersCtor; +jmethodID g_SSLParametersGetProtocols; +jmethodID g_SSLParametersSetServerNames; // javax/net/ssl/SSLContext jclass g_sslCtxClass; @@ -114,6 +120,7 @@ jmethodID g_keyPairGenGenKeyPairMethod; // java/security/KeyStore jclass g_KeyStoreClass; +jmethodID g_KeyStoreGetDefaultType; jmethodID g_KeyStoreGetInstance; jmethodID g_KeyStoreAliases; jmethodID g_KeyStoreContainsAlias; @@ -322,6 +329,10 @@ jmethodID g_EllipticCurveGetB; jmethodID g_EllipticCurveGetField; jmethodID g_EllipticCurveGetSeed; +// java/security/spec/PKCS8EncodedKeySpec +jclass g_PKCS8EncodedKeySpec; +jmethodID g_PKCS8EncodedKeySpecCtor; + // java/security/spec/X509EncodedKeySpec jclass g_X509EncodedKeySpecClass; jmethodID g_X509EncodedKeySpecCtor; @@ -365,49 +376,68 @@ jmethodID g_IteratorNext; jclass g_ListClass; jmethodID g_ListGet; +// javax/net/ssl/HostnameVerifier +jclass g_HostnameVerifier; +jmethodID g_HostnameVerifierVerify; + +// javax/net/ssl/HttpsURLConnection +jclass g_HttpsURLConnection; +jmethodID g_HttpsURLConnectionGetDefaultHostnameVerifier; + +// javax/net/ssl/KeyManagerFactory +jclass g_KeyManagerFactory; +jmethodID g_KeyManagerFactoryGetInstance; +jmethodID g_KeyManagerFactoryInit; +jmethodID g_KeyManagerFactoryGetKeyManagers; + +// javax/net/ssl/SNIHostName +jclass g_SNIHostName; +jmethodID g_SNIHostNameCtor; + // javax/net/ssl/SSLEngine jclass g_SSLEngine; -jmethodID g_SSLEngineSetUseClientModeMethod; -jmethodID g_SSLEngineGetSessionMethod; -jmethodID g_SSLEngineBeginHandshakeMethod; -jmethodID g_SSLEngineWrapMethod; -jmethodID g_SSLEngineUnwrapMethod; -jmethodID g_SSLEngineCloseInboundMethod; -jmethodID g_SSLEngineCloseOutboundMethod; -jmethodID g_SSLEngineGetHandshakeStatusMethod; +jmethodID g_SSLEngineGetApplicationProtocol; +jmethodID g_SSLEngineSetUseClientMode; +jmethodID g_SSLEngineGetSession; +jmethodID g_SSLEngineBeginHandshake; +jmethodID g_SSLEngineWrap; +jmethodID g_SSLEngineUnwrap; +jmethodID g_SSLEngineCloseOutbound; +jmethodID g_SSLEngineGetHandshakeStatus; +jmethodID g_SSLEngineSetSSLParameters; // java/nio/ByteBuffer jclass g_ByteBuffer; -jmethodID g_ByteBufferAllocateMethod; -jmethodID g_ByteBufferPutMethod; -jmethodID g_ByteBufferPut2Method; -jmethodID g_ByteBufferPut3Method; -jmethodID g_ByteBufferFlipMethod; -jmethodID g_ByteBufferGetMethod; -jmethodID g_ByteBufferPutBufferMethod; -jmethodID g_ByteBufferLimitMethod; -jmethodID g_ByteBufferRemainingMethod; -jmethodID g_ByteBufferCompactMethod; -jmethodID g_ByteBufferPositionMethod; +jmethodID g_ByteBufferAllocate; +jmethodID g_ByteBufferCompact; +jmethodID g_ByteBufferFlip; +jmethodID g_ByteBufferGet; +jmethodID g_ByteBufferLimit; +jmethodID g_ByteBufferPosition; +jmethodID g_ByteBufferPutBuffer; +jmethodID g_ByteBufferPutByteArray; +jmethodID g_ByteBufferPutByteArrayWithLength; +jmethodID g_ByteBufferRemaining; // javax/net/ssl/SSLContext jclass g_SSLContext; +jmethodID g_SSLContextGetDefault; jmethodID g_SSLContextGetInstanceMethod; jmethodID g_SSLContextInitMethod; jmethodID g_SSLContextCreateSSLEngineMethod; // javax/net/ssl/SSLSession jclass g_SSLSession; -jmethodID g_SSLSessionGetApplicationBufferSizeMethod; -jmethodID g_SSLSessionGetPacketBufferSizeMethod; +jmethodID g_SSLSessionGetApplicationBufferSize; +jmethodID g_SSLSessionGetCipherSuite; +jmethodID g_SSLSessionGetPacketBufferSize; +jmethodID g_SSLSessionGetPeerCertificates; +jmethodID g_SSLSessionGetProtocol; // javax/net/ssl/SSLEngineResult jclass g_SSLEngineResult; -jmethodID g_SSLEngineResultGetStatusMethod; -jmethodID g_SSLEngineResultGetHandshakeStatusMethod; - -// javax/net/ssl/TrustManager -jclass g_TrustManager; +jmethodID g_SSLEngineResultGetStatus; +jmethodID g_SSLEngineResultGetHandshakeStatus; // javax/crypto/KeyAgreement jclass g_KeyAgreementClass; @@ -518,11 +548,6 @@ bool TryGetJNIException(JNIEnv* env, jthrowable *ex, bool printException) return true; } -void AssertOnJNIExceptions(JNIEnv* env) -{ - assert(!CheckJNIExceptions(env)); -} - void SaveTo(uint8_t* src, uint8_t** dst, size_t len, bool overwrite) { assert(overwrite || !(*dst)); @@ -602,6 +627,9 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_Enum = GetClassGRef(env, "java/lang/Enum"); g_EnumOrdinal = GetMethod(env, false, g_Enum, "ordinal", "()I"); + g_String = GetClassGRef(env, "java/lang/String"); + g_StringGetBytes = GetMethod(env, false, g_String, "getBytes", "()[B"); + g_ThrowableClass = GetClassGRef(env, "java/lang/Throwable"); g_ThrowableGetCause = GetMethod(env, false, g_ThrowableClass, "getCause", "()Ljava/lang/Throwable;"); g_ThrowableGetMessage = GetMethod(env, false, g_ThrowableClass, "getMessage", "()Ljava/lang/String;"); @@ -655,8 +683,10 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_bitLengthMethod = GetMethod(env, false, g_bigNumClass, "bitLength", "()I"); g_sigNumMethod = GetMethod(env, false, g_bigNumClass, "signum", "()I"); - g_sslParamsClass = GetClassGRef(env, "javax/net/ssl/SSLParameters"); - g_sslParamsGetProtocolsMethod = GetMethod(env, false, g_sslParamsClass, "getProtocols", "()[Ljava/lang/String;"); + g_SSLParametersClass = GetClassGRef(env, "javax/net/ssl/SSLParameters"); + g_SSLParametersCtor = GetMethod(env, false, g_SSLParametersClass, "", "()V"); + g_SSLParametersGetProtocols = GetMethod(env, false, g_SSLParametersClass, "getProtocols", "()[Ljava/lang/String;"); + g_SSLParametersSetServerNames = GetMethod(env, false, g_SSLParametersClass, "setServerNames", "(Ljava/util/List;)V"); g_sslCtxClass = GetClassGRef(env, "javax/net/ssl/SSLContext"); g_sslCtxGetDefaultMethod = GetMethod(env, true, g_sslCtxClass, "getDefault", "()Ljavax/net/ssl/SSLContext;"); @@ -752,6 +782,7 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_keyPairGenGenKeyPairMethod = GetMethod(env, false, g_keyPairGenClass, "genKeyPair", "()Ljava/security/KeyPair;"); g_KeyStoreClass = GetClassGRef(env, "java/security/KeyStore"); + g_KeyStoreGetDefaultType = GetMethod(env, true, g_KeyStoreClass, "getDefaultType", "()Ljava/lang/String;"); g_KeyStoreGetInstance = GetMethod(env, true, g_KeyStoreClass, "getInstance", "(Ljava/lang/String;)Ljava/security/KeyStore;"); g_KeyStoreAliases = GetMethod(env, false, g_KeyStoreClass, "aliases", "()Ljava/util/Enumeration;"); g_KeyStoreContainsAlias = GetMethod(env, false, g_KeyStoreClass, "containsAlias", "(Ljava/lang/String;)Z"); @@ -859,6 +890,9 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_EllipticCurveGetField = GetMethod(env, false, g_EllipticCurveClass, "getField", "()Ljava/security/spec/ECField;"); g_EllipticCurveGetSeed = GetMethod(env, false, g_EllipticCurveClass, "getSeed", "()[B"); + g_PKCS8EncodedKeySpec = GetClassGRef(env, "java/security/spec/PKCS8EncodedKeySpec"); + g_PKCS8EncodedKeySpecCtor = GetMethod(env, false, g_PKCS8EncodedKeySpec, "", "([B)V"); + g_X509EncodedKeySpecClass = GetClassGRef(env, "java/security/spec/X509EncodedKeySpec"); g_X509EncodedKeySpecCtor = GetMethod(env, false, g_X509EncodedKeySpecClass, "", "([B)V"); @@ -893,43 +927,59 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_ListClass = GetClassGRef(env, "java/util/List"); g_ListGet = GetMethod(env, false, g_ListClass, "get", "(I)Ljava/lang/Object;"); - g_SSLEngine = GetClassGRef(env, "javax/net/ssl/SSLEngine"); - g_SSLEngineSetUseClientModeMethod = GetMethod(env, false, g_SSLEngine, "setUseClientMode", "(Z)V"); - g_SSLEngineGetSessionMethod = GetMethod(env, false, g_SSLEngine, "getSession", "()Ljavax/net/ssl/SSLSession;"); - g_SSLEngineBeginHandshakeMethod = GetMethod(env, false, g_SSLEngine, "beginHandshake", "()V"); - g_SSLEngineWrapMethod = GetMethod(env, false, g_SSLEngine, "wrap", "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)Ljavax/net/ssl/SSLEngineResult;"); - g_SSLEngineUnwrapMethod = GetMethod(env, false, g_SSLEngine, "unwrap", "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)Ljavax/net/ssl/SSLEngineResult;"); - g_SSLEngineGetHandshakeStatusMethod = GetMethod(env, false, g_SSLEngine, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;"); - g_SSLEngineCloseInboundMethod = GetMethod(env, false, g_SSLEngine, "closeInbound", "()V"); - g_SSLEngineCloseOutboundMethod = GetMethod(env, false, g_SSLEngine, "closeOutbound", "()V"); - - g_ByteBuffer = GetClassGRef(env, "java/nio/ByteBuffer"); - g_ByteBufferAllocateMethod = GetMethod(env, true, g_ByteBuffer, "allocate", "(I)Ljava/nio/ByteBuffer;"); - g_ByteBufferPutMethod = GetMethod(env, false, g_ByteBuffer, "put", "(Ljava/nio/ByteBuffer;)Ljava/nio/ByteBuffer;"); - g_ByteBufferPut2Method = GetMethod(env, false, g_ByteBuffer, "put", "([B)Ljava/nio/ByteBuffer;"); - g_ByteBufferPut3Method = GetMethod(env, false, g_ByteBuffer, "put", "([BII)Ljava/nio/ByteBuffer;"); - g_ByteBufferFlipMethod = GetMethod(env, false, g_ByteBuffer, "flip", "()Ljava/nio/Buffer;"); - g_ByteBufferLimitMethod = GetMethod(env, false, g_ByteBuffer, "limit", "()I"); - g_ByteBufferGetMethod = GetMethod(env, false, g_ByteBuffer, "get", "([B)Ljava/nio/ByteBuffer;"); - g_ByteBufferPutBufferMethod = GetMethod(env, false, g_ByteBuffer, "put", "(Ljava/nio/ByteBuffer;)Ljava/nio/ByteBuffer;"); - g_ByteBufferRemainingMethod = GetMethod(env, false, g_ByteBuffer, "remaining", "()I"); - g_ByteBufferCompactMethod = GetMethod(env, false, g_ByteBuffer, "compact", "()Ljava/nio/ByteBuffer;"); - g_ByteBufferPositionMethod = GetMethod(env, false, g_ByteBuffer, "position", "()I"); + g_HostnameVerifier = GetClassGRef(env, "javax/net/ssl/HostnameVerifier"); + g_HostnameVerifierVerify = GetMethod(env, false, g_HostnameVerifier, "verify", "(Ljava/lang/String;Ljavax/net/ssl/SSLSession;)Z"); + + g_HttpsURLConnection = GetClassGRef(env, "javax/net/ssl/HttpsURLConnection"); + g_HttpsURLConnectionGetDefaultHostnameVerifier = GetMethod(env, true, g_HttpsURLConnection, "getDefaultHostnameVerifier", "()Ljavax/net/ssl/HostnameVerifier;"); + + g_KeyManagerFactory = GetClassGRef(env, "javax/net/ssl/KeyManagerFactory"); + g_KeyManagerFactoryGetInstance = GetMethod(env, true, g_KeyManagerFactory, "getInstance", "(Ljava/lang/String;)Ljavax/net/ssl/KeyManagerFactory;"); + g_KeyManagerFactoryInit = GetMethod(env, false, g_KeyManagerFactory, "init", "(Ljava/security/KeyStore;[C)V"); + g_KeyManagerFactoryGetKeyManagers = GetMethod(env, false, g_KeyManagerFactory, "getKeyManagers", "()[Ljavax/net/ssl/KeyManager;"); + + g_SNIHostName = GetClassGRef(env, "javax/net/ssl/SNIHostName"); + g_SNIHostNameCtor = GetMethod(env, false, g_SNIHostName, "", "(Ljava/lang/String;)V"); + + g_SSLEngine = GetClassGRef(env, "javax/net/ssl/SSLEngine"); + g_SSLEngineGetApplicationProtocol = GetMethod(env, false, g_SSLEngine, "getApplicationProtocol", "()Ljava/lang/String;"); + g_SSLEngineSetUseClientMode = GetMethod(env, false, g_SSLEngine, "setUseClientMode", "(Z)V"); + g_SSLEngineGetSession = GetMethod(env, false, g_SSLEngine, "getSession", "()Ljavax/net/ssl/SSLSession;"); + g_SSLEngineBeginHandshake = GetMethod(env, false, g_SSLEngine, "beginHandshake", "()V"); + g_SSLEngineWrap = GetMethod(env, false, g_SSLEngine, "wrap", "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)Ljavax/net/ssl/SSLEngineResult;"); + g_SSLEngineUnwrap = GetMethod(env, false, g_SSLEngine, "unwrap", "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)Ljavax/net/ssl/SSLEngineResult;"); + g_SSLEngineGetHandshakeStatus = GetMethod(env, false, g_SSLEngine, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;"); + g_SSLEngineCloseOutbound = GetMethod(env, false, g_SSLEngine, "closeOutbound", "()V"); + g_SSLEngineSetSSLParameters = GetMethod(env, false, g_SSLEngine, "setSSLParameters", "(Ljavax/net/ssl/SSLParameters;)V"); + + g_ByteBuffer = GetClassGRef(env, "java/nio/ByteBuffer"); + g_ByteBufferAllocate = GetMethod(env, true, g_ByteBuffer, "allocate", "(I)Ljava/nio/ByteBuffer;"); + g_ByteBufferCompact = GetMethod(env, false, g_ByteBuffer, "compact", "()Ljava/nio/ByteBuffer;"); + g_ByteBufferFlip = GetMethod(env, false, g_ByteBuffer, "flip", "()Ljava/nio/Buffer;"); + g_ByteBufferGet = GetMethod(env, false, g_ByteBuffer, "get", "([B)Ljava/nio/ByteBuffer;"); + g_ByteBufferLimit = GetMethod(env, false, g_ByteBuffer, "limit", "()I"); + g_ByteBufferPosition = GetMethod(env, false, g_ByteBuffer, "position", "()I"); + g_ByteBufferPutBuffer = GetMethod(env, false, g_ByteBuffer, "put", "(Ljava/nio/ByteBuffer;)Ljava/nio/ByteBuffer;"); + g_ByteBufferPutByteArray = GetMethod(env, false, g_ByteBuffer, "put", "([B)Ljava/nio/ByteBuffer;"); + g_ByteBufferPutByteArrayWithLength = GetMethod(env, false, g_ByteBuffer, "put", "([BII)Ljava/nio/ByteBuffer;"); + g_ByteBufferRemaining = GetMethod(env, false, g_ByteBuffer, "remaining", "()I"); g_SSLContext = GetClassGRef(env, "javax/net/ssl/SSLContext"); + g_SSLContextGetDefault = GetMethod(env, true, g_SSLContext, "getDefault", "()Ljavax/net/ssl/SSLContext;"); g_SSLContextGetInstanceMethod = GetMethod(env, true, g_SSLContext, "getInstance", "(Ljava/lang/String;)Ljavax/net/ssl/SSLContext;"); g_SSLContextInitMethod = GetMethod(env, false, g_SSLContext, "init", "([Ljavax/net/ssl/KeyManager;[Ljavax/net/ssl/TrustManager;Ljava/security/SecureRandom;)V"); g_SSLContextCreateSSLEngineMethod = GetMethod(env, false, g_SSLContext, "createSSLEngine", "()Ljavax/net/ssl/SSLEngine;"); - g_SSLSession = GetClassGRef(env, "javax/net/ssl/SSLSession"); - g_SSLSessionGetApplicationBufferSizeMethod = GetMethod(env, false, g_SSLSession, "getApplicationBufferSize", "()I"); - g_SSLSessionGetPacketBufferSizeMethod = GetMethod(env, false, g_SSLSession, "getPacketBufferSize", "()I"); - - g_SSLEngineResult = GetClassGRef(env, "javax/net/ssl/SSLEngineResult"); - g_SSLEngineResultGetStatusMethod = GetMethod(env, false, g_SSLEngineResult, "getStatus", "()Ljavax/net/ssl/SSLEngineResult$Status;"); - g_SSLEngineResultGetHandshakeStatusMethod = GetMethod(env, false, g_SSLEngineResult, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;"); + g_SSLSession = GetClassGRef(env, "javax/net/ssl/SSLSession"); + g_SSLSessionGetApplicationBufferSize = GetMethod(env, false, g_SSLSession, "getApplicationBufferSize", "()I"); + g_SSLSessionGetCipherSuite = GetMethod(env, false, g_SSLSession, "getCipherSuite", "()Ljava/lang/String;"); + g_SSLSessionGetPacketBufferSize = GetMethod(env, false, g_SSLSession, "getPacketBufferSize", "()I"); + g_SSLSessionGetPeerCertificates = GetMethod(env, false, g_SSLSession, "getPeerCertificates", "()[Ljava/security/cert/Certificate;"); + g_SSLSessionGetProtocol = GetMethod(env, false, g_SSLSession, "getProtocol", "()Ljava/lang/String;"); - g_TrustManager = GetClassGRef(env, "javax/net/ssl/TrustManager"); + g_SSLEngineResult = GetClassGRef(env, "javax/net/ssl/SSLEngineResult"); + g_SSLEngineResultGetStatus = GetMethod(env, false, g_SSLEngineResult, "getStatus", "()Ljavax/net/ssl/SSLEngineResult$Status;"); + g_SSLEngineResultGetHandshakeStatus = GetMethod(env, false, g_SSLEngineResult, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;"); g_KeyAgreementClass = GetClassGRef(env, "javax/crypto/KeyAgreement"); g_KeyAgreementGetInstance = GetMethod(env, true, g_KeyAgreementClass, "getInstance", "(Ljava/lang/String;)Ljavax/crypto/KeyAgreement;"); diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.h b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.h index 1414e740ad3af5..6095640415b864 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.h +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_jni.h @@ -11,6 +11,7 @@ #define FAIL 0 #define SUCCESS 1 +#define INSUFFICIENT_BUFFER -1 extern JavaVM* gJvm; @@ -23,6 +24,10 @@ extern jmethodID g_ByteArrayInputStreamReset; extern jclass g_Enum; extern jmethodID g_EnumOrdinal; +// java/lang/String +extern jclass g_String; +extern jmethodID g_StringGetBytes; + // java/lang/Throwable extern jclass g_ThrowableClass; extern jmethodID g_ThrowableGetCause; @@ -82,8 +87,10 @@ extern jmethodID g_bitLengthMethod; extern jmethodID g_sigNumMethod; // javax/net/ssl/SSLParameters -extern jclass g_sslParamsClass; -extern jmethodID g_sslParamsGetProtocolsMethod; +extern jclass g_SSLParametersClass; +extern jmethodID g_SSLParametersCtor; +extern jmethodID g_SSLParametersGetProtocols; +extern jmethodID g_SSLParametersSetServerNames; // javax/net/ssl/SSLContext extern jclass g_sslCtxClass; @@ -204,6 +211,7 @@ extern jmethodID g_keyPairGenGenKeyPairMethod; // java/security/KeyStore extern jclass g_KeyStoreClass; +extern jmethodID g_KeyStoreGetDefaultType; extern jmethodID g_KeyStoreGetInstance; extern jmethodID g_KeyStoreAliases; extern jmethodID g_KeyStoreContainsAlias; @@ -331,6 +339,10 @@ extern jmethodID g_EllipticCurveGetB; extern jmethodID g_EllipticCurveGetField; extern jmethodID g_EllipticCurveGetSeed; +// java/security/spec/PKCS8EncodedKeySpec +extern jclass g_PKCS8EncodedKeySpec; +extern jmethodID g_PKCS8EncodedKeySpecCtor; + // java/security/spec/X509EncodedKeySpec extern jclass g_X509EncodedKeySpecClass; extern jmethodID g_X509EncodedKeySpecCtor; @@ -351,6 +363,11 @@ extern jclass g_CollectionClass; extern jmethodID g_CollectionIterator; extern jmethodID g_CollectionSize; +// java/util/ArrayList +extern jclass g_ArrayList; +extern jmethodID g_ArrayListCtor; +extern jmethodID g_ArrayListAdd; + // java/util/Date extern jclass g_DateClass; extern jmethodID g_DateCtor; @@ -374,49 +391,69 @@ extern jmethodID g_IteratorNext; extern jclass g_ListClass; extern jmethodID g_ListGet; +// javax/net/ssl/HostnameVerifier +extern jclass g_HostnameVerifier; +extern jmethodID g_HostnameVerifierVerify; + +// javax/net/ssl/HttpsURLConnection +extern jclass g_HttpsURLConnection; +extern jmethodID g_HttpsURLConnectionGetDefaultHostnameVerifier; + +// javax/net/ssl/KeyManagerFactory +extern jclass g_KeyManagerFactory; +extern jmethodID g_KeyManagerFactoryGetInstance; +extern jmethodID g_KeyManagerFactoryInit; +extern jmethodID g_KeyManagerFactoryGetKeyManagers; + +// javax/net/ssl/SNIHostName +extern jclass g_SNIHostName; +extern jmethodID g_SNIHostNameCtor; + // javax/net/ssl/SSLEngine extern jclass g_SSLEngine; -extern jmethodID g_SSLEngineSetUseClientModeMethod; -extern jmethodID g_SSLEngineGetSessionMethod; -extern jmethodID g_SSLEngineBeginHandshakeMethod; -extern jmethodID g_SSLEngineWrapMethod; -extern jmethodID g_SSLEngineUnwrapMethod; -extern jmethodID g_SSLEngineCloseInboundMethod; -extern jmethodID g_SSLEngineCloseOutboundMethod; -extern jmethodID g_SSLEngineGetHandshakeStatusMethod; +extern jmethodID g_SSLEngineGetApplicationProtocol; +extern jmethodID g_SSLEngineSetUseClientMode; +extern jmethodID g_SSLEngineGetSession; +extern jmethodID g_SSLEngineBeginHandshake; +extern jmethodID g_SSLEngineWrap; +extern jmethodID g_SSLEngineUnwrap; +extern jmethodID g_SSLEngineCloseOutbound; +extern jmethodID g_SSLEngineGetHandshakeStatus; +extern jmethodID g_SSLEngineSetSSLParameters; // java/nio/ByteBuffer extern jclass g_ByteBuffer; -extern jmethodID g_ByteBufferAllocateMethod; -extern jmethodID g_ByteBufferPutMethod; -extern jmethodID g_ByteBufferPut2Method; -extern jmethodID g_ByteBufferPut3Method; -extern jmethodID g_ByteBufferFlipMethod; -extern jmethodID g_ByteBufferGetMethod; -extern jmethodID g_ByteBufferLimitMethod; -extern jmethodID g_ByteBufferRemainingMethod; -extern jmethodID g_ByteBufferPutBufferMethod; -extern jmethodID g_ByteBufferCompactMethod; -extern jmethodID g_ByteBufferPositionMethod; +extern jmethodID g_ByteBufferAllocate; +extern jmethodID g_ByteBufferCompact; +extern jmethodID g_ByteBufferFlip; +extern jmethodID g_ByteBufferGet; +extern jmethodID g_ByteBufferLimit; +extern jmethodID g_ByteBufferPosition; +extern jmethodID g_ByteBufferPutBuffer; +extern jmethodID g_ByteBufferPutByteArray; +extern jmethodID g_ByteBufferPutByteArrayWithLength; +extern jmethodID g_ByteBufferRemaining; // javax/net/ssl/SSLContext extern jclass g_SSLContext; +extern jmethodID g_SSLContextGetDefault; extern jmethodID g_SSLContextGetInstanceMethod; extern jmethodID g_SSLContextInitMethod; extern jmethodID g_SSLContextCreateSSLEngineMethod; +extern jmethodID g_SSLContextCreateSSLEngineWithPeer; // javax/net/ssl/SSLSession extern jclass g_SSLSession; -extern jmethodID g_SSLSessionGetApplicationBufferSizeMethod; -extern jmethodID g_SSLSessionGetPacketBufferSizeMethod; +extern jmethodID g_SSLSessionGetApplicationBufferSize; +extern jmethodID g_SSLSessionGetCipherSuite; +extern jmethodID g_SSLSessionGetPacketBufferSize; +extern jmethodID g_SSLSessionGetPeerCertificates; +extern jmethodID g_SSLSessionGetProtocol; // javax/net/ssl/SSLEngineResult extern jclass g_SSLEngineResult; -extern jmethodID g_SSLEngineResultGetStatusMethod; -extern jmethodID g_SSLEngineResultGetHandshakeStatusMethod; - -// javax/net/ssl/TrustManager -extern jclass g_TrustManager; +extern jmethodID g_SSLEngineResultGetStatus; +extern jmethodID g_SSLEngineResultGetHandshakeStatus; // javax/crypto/KeyAgreement extern jclass g_KeyAgreementClass; @@ -425,13 +462,18 @@ extern jmethodID g_KeyAgreementInit; extern jmethodID g_KeyAgreementDoPhase; extern jmethodID g_KeyAgreementGenerateSecret; -// JNI helpers +// Logging helpers #define LOG_DEBUG(fmt, ...) ((void)__android_log_print(ANDROID_LOG_DEBUG, "DOTNET", "%s: " fmt, __FUNCTION__, ## __VA_ARGS__)) #define LOG_INFO(fmt, ...) ((void)__android_log_print(ANDROID_LOG_INFO, "DOTNET", "%s: " fmt, __FUNCTION__, ## __VA_ARGS__)) #define LOG_ERROR(fmt, ...) ((void)__android_log_print(ANDROID_LOG_ERROR, "DOTNET", "%s: " fmt, __FUNCTION__, ## __VA_ARGS__)) + +// JNI helpers - assume there is a JNIEnv* variable named env #define JSTRING(str) ((jstring)(*env)->NewStringUTF(env, str)) #define ON_EXCEPTION_PRINT_AND_GOTO(label) if (CheckJNIExceptions(env)) goto label +// Explicitly ignore jobject return value +#define IGNORE_RETURN(retval) (*env)->DeleteLocalRef(env, retval) + #define INIT_LOCALS(name, ...) \ enum { __VA_ARGS__, count_##name }; \ jobject name[count_##name] = { 0 } \ @@ -470,9 +512,6 @@ bool TryClearJNIExceptions(JNIEnv* env); // Get any pending JNI exception. Returns true if there was an exception, false otherwise. bool TryGetJNIException(JNIEnv* env, jthrowable *ex, bool printException); -// Assert on any JNI exceptions. Prints the exception before asserting. -void AssertOnJNIExceptions(JNIEnv* env); - jmethodID GetMethod(JNIEnv *env, bool isStatic, jclass klass, const char* name, const char* sig); jmethodID GetOptionalMethod(JNIEnv *env, bool isStatic, jclass klass, const char* name, const char* sig); jfieldID GetField(JNIEnv *env, bool isStatic, jclass klass, const char* name, const char* sig); diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_ssl.c b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_ssl.c index fb7eac36c0b6f6..d2866f906ae383 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_ssl.c +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_ssl.c @@ -8,7 +8,7 @@ int32_t CryptoNative_OpenSslGetProtocolSupport(SslProtocols protocol) JNIEnv* env = GetJNIEnv(); jobject sslCtxObj = (*env)->CallStaticObjectMethod(env, g_sslCtxClass, g_sslCtxGetDefaultMethod); jobject sslParametersObj = (*env)->CallObjectMethod(env, sslCtxObj, g_sslCtxGetDefaultSslParamsMethod); - jobjectArray protocols = (jobjectArray)(*env)->CallObjectMethod(env, sslParametersObj, g_sslParamsGetProtocolsMethod); + jobjectArray protocols = (jobjectArray)(*env)->CallObjectMethod(env, sslParametersObj, g_SSLParametersGetProtocols); int protocolsCount = (*env)->GetArrayLength(env, protocols); int supported = 0; diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.c b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.c index b8ac3849ad6bc7..6f1c8b238fc717 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.c +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.c @@ -3,73 +3,77 @@ #include "pal_sslstream.h" -void checkHandshakeStatus(JNIEnv* env, SSLStream* sslStream, int handshakeStatus); - -static int getHandshakeStatus(JNIEnv* env, SSLStream* sslStream, jobject engineResult) +// javax/net/ssl/SSLEngineResult$HandshakeStatus +enum { - AssertOnJNIExceptions(env); - int status = -1; - if (engineResult) - status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, engineResult, g_SSLEngineResultGetHandshakeStatusMethod)); - else - status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetHandshakeStatusMethod)); - AssertOnJNIExceptions(env); - return status; -} + HANDSHAKE_STATUS__NOT_HANDSHAKING = 0, + HANDSHAKE_STATUS__FINISHED = 1, + HANDSHAKE_STATUS__NEED_TASK = 2, + HANDSHAKE_STATUS__NEED_WRAP = 3, + HANDSHAKE_STATUS__NEED_UNWRAP = 4, +}; + +// javax/net/ssl/SSLEngineResult$Status +enum +{ + STATUS__BUFFER_UNDERFLOW = 0, + STATUS__BUFFER_OVERFLOW = 1, + STATUS__OK = 2, + STATUS__CLOSED = 3, +}; -static void close(JNIEnv* env, SSLStream* sslStream) { - /* - sslEngine.closeOutbound(); - checkHandshakeStatus(); - */ +static uint16_t* AllocateString(JNIEnv* env, jstring source); + +static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream); +static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); +static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); - AssertOnJNIExceptions(env); - (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutboundMethod); - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, NULL)); +static bool IsHandshaking(int handshakeStatus) +{ + return handshakeStatus != HANDSHAKE_STATUS__NOT_HANDSHAKING && handshakeStatus != HANDSHAKE_STATUS__FINISHED; } -static void handleEndOfStream(JNIEnv* env, SSLStream* sslStream) { - /* - sslEngine.closeInbound(); - close(); - */ +static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslStream) +{ + // sslEngine.closeOutbound(); + (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound); - AssertOnJNIExceptions(env); - (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseInboundMethod); - close(env, sslStream); - AssertOnJNIExceptions(env); + // Call wrap to clear any remaining data + int unused; + return DoWrap(env, sslStream, &unused); } -static void flush(JNIEnv* env, SSLStream* sslStream) +static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslStream) { /* netOutBuffer.flip(); byte[] data = new byte[netOutBuffer.limit()]; netOutBuffer.get(data); - WriteToOutputStream(data, 0, data.length); + streamWriter(data, 0, data.length); netOutBuffer.compact(); */ - AssertOnJNIExceptions(env); - - // DeleteLocalRef because we don't need the return value (Buffer) - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferFlipMethod)); - int bufferLimit = (*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferLimitMethod); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferFlip)); + int bufferLimit = (*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferLimit); jbyteArray data = (*env)->NewByteArray(env, bufferLimit); - // DeleteLocalRef because we don't need the return value (Buffer) - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferGetMethod, data)); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferGet, data)); + if (CheckJNIExceptions(env)) + return SSLStreamStatus_Error; uint8_t* dataPtr = (uint8_t*)malloc((size_t)bufferLimit); - (*env)->GetByteArrayRegion(env, data, 0, bufferLimit, (jbyte*) dataPtr); - sslStream->streamWriter(dataPtr, 0, (uint32_t)bufferLimit); + (*env)->GetByteArrayRegion(env, data, 0, bufferLimit, (jbyte*)dataPtr); + sslStream->streamWriter(dataPtr, bufferLimit); free(dataPtr); - // DeleteLocalRef because we don't need the return value (Buffer) - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferCompactMethod)); - AssertOnJNIExceptions(env); + + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferCompact)); + if (CheckJNIExceptions(env)) + return SSLStreamStatus_Error; + + return SSLStreamStatus_OK; } -static jobject ensureRemaining(JNIEnv* env, SSLStream* sslStream, jobject oldBuffer, int newRemaining) +static jobject EnsureRemaining(JNIEnv* env, SSLStream* sslStream, jobject oldBuffer, int newRemaining) { /* if (oldBuffer.remaining() < newRemaining) { @@ -82,16 +86,14 @@ static jobject ensureRemaining(JNIEnv* env, SSLStream* sslStream, jobject oldBuf } */ - AssertOnJNIExceptions(env); - - int oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemainingMethod); + int oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemaining); if (oldRemaining < newRemaining) { - // DeleteLocalRef because we don't need the return value (Buffer) - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferFlipMethod)); - jobject newBuffer = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocateMethod, oldRemaining + newRemaining)); - // DeleteLocalRef because we don't need the return value (Buffer) - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, newBuffer, g_ByteBufferPutBufferMethod, oldBuffer)); + IGNORE_RETURN((*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferFlip)); + jobject newBuffer = ToGRef( + env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, oldRemaining + newRemaining)); + + IGNORE_RETURN((*env)->CallObjectMethod(env, newBuffer, g_ByteBufferPutBuffer, oldBuffer)); ReleaseGRef(env, oldBuffer); return newBuffer; } @@ -101,314 +103,694 @@ static jobject ensureRemaining(JNIEnv* env, SSLStream* sslStream, jobject oldBuf } } -static void doWrap(JNIEnv* env, SSLStream* sslStream) +static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) { - /* - appOutBuffer.flip(); - final SSLEngineResult result; - try { - result = sslEngine.wrap(appOutBuffer, netOutBuffer); - } catch (SSLException e) { - return; - } - appOutBuffer.compact(); - - final SSLEngineResult.Status status = result.getStatus(); - switch (status) { - case OK: - flush(); - checkHandshakeStatus(result.getHandshakeStatus()); - if (appOutBuffer.position() > 0) doWrap(); - break; - case CLOSED: - flush(); - checkHandshakeStatus(result.getHandshakeStatus()); - close(); - break; - case BUFFER_OVERFLOW: - netOutBuffer = ensureRemaining(netOutBuffer, sslEngine.getSession().getPacketBufferSize()); - doWrap(); - break; - } - */ - - AssertOnJNIExceptions(env); - - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlipMethod)); - jobject sslEngineResult = (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineWrapMethod, sslStream->appOutBuffer, sslStream->netOutBuffer); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompactMethod)); - - int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslEngineResult, g_SSLEngineResultGetStatusMethod)); + // appOutBuffer.flip(); + // SSLEngineResult result = sslEngine.wrap(appOutBuffer, netOutBuffer); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip)); + jobject result = (*env)->CallObjectMethod( + env, sslStream->sslEngine, g_SSLEngineWrap, sslStream->appOutBuffer, sslStream->netOutBuffer); + if (CheckJNIExceptions(env)) + return SSLStreamStatus_Error; + + // appOutBuffer.compact(); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); + + // handshakeStatus = result.getHandshakeStatus(); + // SSLEngineResult.Status status = result.getStatus(); + *handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus)); + int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus)); switch (status) { case STATUS__OK: - flush(env, sslStream); - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, sslEngineResult)); - if ((*env)->CallIntMethod(env, sslStream->appOutBuffer, g_ByteBufferPositionMethod) > 0) - doWrap(env, sslStream); - break; + { + return Flush(env, sslStream); + } case STATUS__CLOSED: - flush(env, sslStream); - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, sslEngineResult)); - close(env, sslStream); - break; - case STATUS__BUFFER_OVERFLOW: - sslStream->netOutBuffer = ensureRemaining(env, sslStream, sslStream->netOutBuffer, (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSizeMethod)); - doWrap(env, sslStream); - break; - } -} - -static void doUnwrap(JNIEnv* env, SSLStream* sslStream) -{ - /* - if (netInBuffer.position() == 0) { - byte[] tmp = new byte[netInBuffer.limit()]; - - int count = ReadFromInputStream(tmp, 0, tmp.length); - if (count == -1) { - handleEndOfStream(); - return; - } - netInBuffer.put(tmp, 0, count); + (void)Flush(env, sslStream); + (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound); + return SSLStreamStatus_Closed; } - - netInBuffer.flip(); - final SSLEngineResult result; - try { - result = sslEngine.unwrap(netInBuffer, appInBuffer); - } catch (SSLException e) { - return; + case STATUS__BUFFER_OVERFLOW: + { + // Expand buffer + // int newRemaining = sslSession.getPacketBufferSize(); + int newRemaining = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize); + sslStream->netOutBuffer = EnsureRemaining(env, sslStream, sslStream->netOutBuffer, newRemaining); + return SSLStreamStatus_OK; } - netInBuffer.compact(); - final SSLEngineResult.Status status = result.getStatus(); - switch (status) { - case OK: - checkHandshakeStatus(result.getHandshakeStatus()); - break; - case CLOSED: - checkHandshakeStatus(result.getHandshakeStatus()); - close(); - break; - case BUFFER_UNDERFLOW: - netInBuffer = ensureRemaining(netInBuffer, sslEngine.getSession().getPacketBufferSize()); - doUnwrap(); - break; - case BUFFER_OVERFLOW: - appInBuffer = ensureRemaining(appInBuffer, sslEngine.getSession().getApplicationBufferSize()); - doUnwrap(); - break; + default: + { + LOG_ERROR("Unknown SSLEngineResult status: %d", status); + return SSLStreamStatus_Error; } - */ - - AssertOnJNIExceptions(env); + } +} - if ((*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferPositionMethod) == 0) +static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) +{ + // if (netInBuffer.position() == 0) + // { + // byte[] tmp = new byte[netInBuffer.limit()]; + // int count = streamReader(tmp, 0, tmp.length); + // netInBuffer.put(tmp, 0, count); + // } + if ((*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferPosition) == 0) { - int netInBufferLimit = (*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferLimitMethod); + int netInBufferLimit = (*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferLimit); jbyteArray tmp = (*env)->NewByteArray(env, netInBufferLimit); uint8_t* tmpNative = (uint8_t*)malloc((size_t)netInBufferLimit); - int count = sslStream->streamReader(tmpNative, 0, (uint32_t)netInBufferLimit); - if (count == -1) + int count = netInBufferLimit; + PAL_SSLStreamStatus status = sslStream->streamReader(tmpNative, &count); + if (status != SSLStreamStatus_OK) { - handleEndOfStream(env, sslStream); - return; + return status; } + (*env)->SetByteArrayRegion(env, tmp, 0, count, (jbyte*)(tmpNative)); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPut3Method, tmp, 0, count)); + IGNORE_RETURN( + (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPutByteArrayWithLength, tmp, 0, count)); free(tmpNative); (*env)->DeleteLocalRef(env, tmp); } - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferFlipMethod)); - jobject sslEngineResult = (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineUnwrapMethod, sslStream->netInBuffer, sslStream->appInBuffer); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferCompactMethod)); - - int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslEngineResult, g_SSLEngineResultGetStatusMethod)); + // netInBuffer.flip(); + // SSLEngineResult result = sslEngine.unwrap(netInBuffer, appInBuffer); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferFlip)); + jobject result = (*env)->CallObjectMethod( + env, sslStream->sslEngine, g_SSLEngineUnwrap, sslStream->netInBuffer, sslStream->appInBuffer); + if (CheckJNIExceptions(env)) + return SSLStreamStatus_Error; + + // netInBuffer.compact(); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferCompact)); + + // handshakeStatus = result.getHandshakeStatus(); + // SSLEngineResult.Status status = result.getStatus(); + *handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus)); + int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus)); switch (status) { case STATUS__OK: - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, sslEngineResult)); - break; + { + return SSLStreamStatus_OK; + } case STATUS__CLOSED: - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, sslEngineResult)); - close(env, sslStream); - break; + { + return Close(env, sslStream); + } case STATUS__BUFFER_UNDERFLOW: - sslStream->netInBuffer = ensureRemaining(env, sslStream, sslStream->netInBuffer, (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSizeMethod)); - doUnwrap(env, sslStream); - break; + { + // Expand buffer + // int newRemaining = sslSession.getPacketBufferSize(); + int newRemaining = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize); + sslStream->netInBuffer = EnsureRemaining(env, sslStream, sslStream->netInBuffer, newRemaining); + return SSLStreamStatus_OK; + } case STATUS__BUFFER_OVERFLOW: - sslStream->appInBuffer = ensureRemaining(env, sslStream, sslStream->appInBuffer, (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetApplicationBufferSizeMethod)); - doUnwrap(env, sslStream); - break; + { + // Expand buffer + // int newRemaining = sslSession.getApplicationBufferSize(); + int newRemaining = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetApplicationBufferSize); + sslStream->appInBuffer = EnsureRemaining(env, sslStream, sslStream->appInBuffer, newRemaining); + return SSLStreamStatus_OK; + } + default: + { + LOG_ERROR("Unknown SSLEngineResult status: %d", status); + return SSLStreamStatus_Error; + } } } -void checkHandshakeStatus(JNIEnv* env, SSLStream* sslStream, int handshakeStatus) +static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream) { - /* - switch (handshakeStatus) { - case NEED_WRAP: - doWrap(); + assert(env != NULL); + assert(sslStream != NULL); + + PAL_SSLStreamStatus status = SSLStreamStatus_OK; + int handshakeStatus = + GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetHandshakeStatus)); + while (IsHandshaking(handshakeStatus) && status == SSLStreamStatus_OK) + { + switch (handshakeStatus) + { + case HANDSHAKE_STATUS__NEED_WRAP: + status = DoWrap(env, sslStream, &handshakeStatus); break; - case NEED_UNWRAP: - doUnwrap(); + case HANDSHAKE_STATUS__NEED_UNWRAP: + status = DoUnwrap(env, sslStream, &handshakeStatus); break; - case NEED_TASK: - Runnable task; - while ((task = sslEngine.getDelegatedTask()) != null) task.run(); - checkHandshakeStatus(); + case HANDSHAKE_STATUS__NOT_HANDSHAKING: + case HANDSHAKE_STATUS__FINISHED: + status = SSLStreamStatus_OK; break; + case HANDSHAKE_STATUS__NEED_TASK: + assert(0 && "unexpected NEED_TASK handshake status"); } - */ + } + + return status; +} + +static void FreeSSLStream(JNIEnv* env, SSLStream* sslStream) +{ + assert(sslStream != NULL); + ReleaseGRef(env, sslStream->sslContext); + ReleaseGRef(env, sslStream->sslEngine); + ReleaseGRef(env, sslStream->sslSession); + ReleaseGRef(env, sslStream->appOutBuffer); + ReleaseGRef(env, sslStream->netOutBuffer); + ReleaseGRef(env, sslStream->netInBuffer); + ReleaseGRef(env, sslStream->appInBuffer); + free(sslStream); +} + +SSLStream* AndroidCryptoNative_SSLStreamCreate(void) +{ + JNIEnv* env = GetJNIEnv(); + + // TODO: [AndroidCrypto] If we have certificates, get an SSLContext instance with the highest available + // protocol - TLSv1.2 (API level 16+) or TLSv1.3 (API level 29+), use KeyManagerFactory to create key + // managers that will return the certificates, and initialize the SSLContext with the key managers. + + // SSLContext sslContext = SSLContext.getDefault(); + jobject sslContext = (*env)->CallStaticObjectMethod(env, g_SSLContext, g_SSLContextGetDefault); + if (CheckJNIExceptions(env)) + return NULL; + + SSLStream* sslStream = malloc(sizeof(SSLStream)); + memset(sslStream, 0, sizeof(SSLStream)); + sslStream->sslContext = ToGRef(env, sslContext); + return sslStream; +} + +static int32_t AddCertChainToStore( + JNIEnv* env, + jobject store, + uint8_t* pkcs8PrivateKey, + int32_t pkcs8PrivateKeyLen, + PAL_KeyAlgorithm algorithm, + jobject* /*X509Certificate[]*/ certs, + int32_t certsLen) +{ + int32_t ret = FAIL; + INIT_LOCALS(loc, keyBytes, keySpec, algorithmName, keyFactory, privateKey, certArray, alias); + + // byte[] keyBytes = new byte[] { }; + // PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes); + loc[keyBytes] = (*env)->NewByteArray(env, pkcs8PrivateKeyLen); + (*env)->SetByteArrayRegion(env, loc[keyBytes], 0, pkcs8PrivateKeyLen, (jbyte*)pkcs8PrivateKey); + loc[keySpec] = (*env)->NewObject(env, g_PKCS8EncodedKeySpec, g_PKCS8EncodedKeySpecCtor, loc[keyBytes]); - AssertOnJNIExceptions(env); - switch (handshakeStatus) + switch (algorithm) { - case HANDSHAKE_STATUS__NEED_WRAP: - doWrap(env, sslStream); + case PAL_DSA: + loc[algorithmName] = JSTRING("DSA"); break; - case HANDSHAKE_STATUS__NEED_UNWRAP: - doUnwrap(env, sslStream); + case PAL_EC: + loc[algorithmName] = JSTRING("EC"); break; - case HANDSHAKE_STATUS__NEED_TASK: - assert(0 && "unexpected NEED_TASK handshake status"); + case PAL_RSA: + loc[algorithmName] = JSTRING("RSA"); + break; + default: + LOG_ERROR("Unknown key algorithm: %d", algorithm); + goto cleanup; } + + // KeyFactory keyFactory = KeyFactory.getInstance(algorithmName); + // PrivateKey privateKey = keyFactory.generatePrivate(spec); + loc[keyFactory] = (*env)->CallStaticObjectMethod(env, g_KeyFactoryClass, g_KeyFactoryGetInstanceMethod, loc[algorithmName]); + loc[privateKey] = (*env)->CallObjectMethod(env, loc[keyFactory], g_KeyFactoryGenPrivateMethod, loc[keySpec]); + + // X509Certificate[] certArray = new X509Certificate[certsLen]; + loc[certArray] = (*env)->NewObjectArray(env, certsLen, g_X509CertClass, NULL); + for (int i = 0; i < certsLen; ++i) + { + (*env)->SetObjectArrayElement(env, loc[certArray], i, certs[i]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + } + + // store.setKeyEntry("SSLCertificateContext", privateKey, null, certArray); + loc[alias] = JSTRING("SSLCertificateContext"); + (*env)->CallVoidMethod(env, store, g_KeyStoreSetKeyEntry, loc[alias], loc[privateKey], NULL, loc[certArray]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + ret = SUCCESS; + +cleanup: + RELEASE_LOCALS(loc, env); + return ret; } -SSLStream* AndroidCryptoNative_SSLStreamCreateAndStartHandshake( - STREAM_READER streamReader, - STREAM_WRITER streamWriter, - int tlsVersion, - int appOutBufferSize, - int appInBufferSize) +SSLStream* AndroidCryptoNative_SSLStreamCreateWithCertificates(uint8_t* pkcs8PrivateKey, int32_t pkcs8PrivateKeyLen, PAL_KeyAlgorithm algorithm, jobject* /*X509Certificate[]*/ certs, int32_t certsLen) { + SSLStream* sslStream = NULL; JNIEnv* env = GetJNIEnv(); - /* - SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); - sslContext.init(null, new TrustManage r[]{trustAllCerts}, null); - this.sslEngine = sslContext.createSSLEngine(); - this.sslEngine.setUseClientMode(true); - SSLSession sslSession = sslEngine.getSession(); - final int applicationBufferSize = sslSession.getApplicationBufferSize(); - final int packetBufferSize = sslSession.getPacketBufferSize(); - this.appOutBuffer = ByteBuffer.allocate(appOutBufferSize); - this.netOutBuffer = ByteBuffer.allocate(packetBufferSize); - this.netInBuffer = ByteBuffer.allocate(packetBufferSize); - this.appInBuffer = ByteBuffer.allocate(Math.max(applicationBufferSize, appInBufferSize)); - sslEngine.beginHandshake(); - */ - SSLStream* sslStream = malloc(sizeof(SSLStream)); + INIT_LOCALS(loc, tls13, sslContext, ksType, keyStore, kmfType, kmf, keyManagers); - jobject tlsVerStr = NULL; - if (tlsVersion == 11) - tlsVerStr = JSTRING("TLSv1.1"); - else if (tlsVersion == 12) - tlsVerStr = JSTRING("TLSv1.2"); - else if (tlsVersion == 13) - tlsVerStr = JSTRING("TLSv1.3"); - else - assert(0 && "unknown tlsVersion"); - - sslStream->sslContext = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_SSLContext, g_SSLContextGetInstanceMethod, tlsVerStr)); + // SSLContext sslContext = SSLContext.getInstance("TLSv1.3"); + loc[tls13] = JSTRING("TLSv1.3"); + loc[sslContext] = (*env)->CallStaticObjectMethod(env, g_SSLContext, g_SSLContextGetInstanceMethod, loc[tls13]); + if (TryClearJNIExceptions(env)) + { + // TLSv1.3 is only supported on API level 29+ - fall back to TLSv1.2 (which is supported on API level 16+) + // sslContext = SSLContext.getInstance("TLSv1.2"); + jobject tls12 = JSTRING("TLSv1.2"); + loc[sslContext] = (*env)->CallStaticObjectMethod(env, g_SSLContext, g_SSLContextGetInstanceMethod, tls12); + ReleaseLRef(env, tls12); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + } - // TODO: set TrustManager[] argument to be able to intercept cert validation process (and callback to C#). - (*env)->CallVoidMethod(env, sslStream->sslContext, g_SSLContextInitMethod, NULL, NULL, NULL); - sslStream->sslEngine = ToGRef(env, (*env)->CallObjectMethod(env, sslStream->sslContext, g_SSLContextCreateSSLEngineMethod)); - sslStream->sslSession = ToGRef(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetSessionMethod)); + // String ksType = KeyStore.getDefaultType(); + // KeyStore keyStore = KeyStore.getInstance(ksType); + // keyStore.load(null, null); + loc[ksType] = (*env)->CallStaticObjectMethod(env, g_KeyStoreClass, g_KeyStoreGetDefaultType); + loc[keyStore] = (*env)->CallStaticObjectMethod(env, g_KeyStoreClass, g_KeyStoreGetInstance, loc[ksType]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + (*env)->CallVoidMethod(env, loc[keyStore], g_KeyStoreLoad, NULL, NULL); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + int32_t status = AddCertChainToStore(env, loc[keyStore], pkcs8PrivateKey, pkcs8PrivateKeyLen, algorithm, certs, certsLen); + if (status != SUCCESS) + goto cleanup; + + // String kmfType = "PKIX"; + // KeyManagerFactory kmf = KeyManagerFactory.getInstance(kmfType); + loc[kmfType] = JSTRING("PKIX"); + loc[kmf] = (*env)->CallStaticObjectMethod(env, g_KeyManagerFactory, g_KeyManagerFactoryGetInstance, loc[kmfType]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + // kmf.init(keyStore, null); + (*env)->CallVoidMethod(env, loc[kmf], g_KeyManagerFactoryInit, loc[keyStore], NULL); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + // KeyManager[] keyManagers = kmf.getKeyManagers(); + // sslContext.init(keyManagers, null, null); + loc[keyManagers] = (*env)->CallObjectMethod(env, loc[kmf], g_KeyManagerFactoryGetKeyManagers); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + (*env)->CallVoidMethod(env, loc[sslContext], g_SSLContextInitMethod, loc[keyManagers], NULL, NULL); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + sslStream = malloc(sizeof(SSLStream)); + memset(sslStream, 0, sizeof(SSLStream)); + sslStream->sslContext = ToGRef(env, loc[sslContext]); + loc[sslContext] = NULL; + +cleanup: + RELEASE_LOCALS(loc, env); + return sslStream; +} - int applicationBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetApplicationBufferSizeMethod); - int packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSizeMethod); +int32_t AndroidCryptoNative_SSLStreamInitialize(SSLStream* sslStream, + bool isServer, + STREAM_READER streamReader, + STREAM_WRITER streamWriter, + int appBufferSize) +{ + assert(sslStream != NULL); + assert(sslStream->sslContext != NULL); + assert(sslStream->sslEngine == NULL); + assert(sslStream->sslSession == NULL); - (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineSetUseClientModeMethod, true); + int32_t ret = FAIL; + JNIEnv* env = GetJNIEnv(); - sslStream->appOutBuffer = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocateMethod, appOutBufferSize)); - sslStream->netOutBuffer = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocateMethod, packetBufferSize)); - sslStream->appInBuffer = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocateMethod, - applicationBufferSize > appInBufferSize ? applicationBufferSize : appInBufferSize)); - sslStream->netInBuffer = ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocateMethod, packetBufferSize)); + // SSLEngine sslEngine = sslContext.createSSLEngine(); + // sslEngine.setUseClientMode(!isServer); + jobject sslEngine = (*env)->CallObjectMethod(env, sslStream->sslContext, g_SSLContextCreateSSLEngineMethod); + ON_EXCEPTION_PRINT_AND_GOTO(exit); + sslStream->sslEngine = ToGRef(env, sslEngine); + (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineSetUseClientMode, !isServer); + ON_EXCEPTION_PRINT_AND_GOTO(exit); + + // SSLSession sslSession = sslEngine.getSession(); + sslStream->sslSession = ToGRef(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetSession)); + + // int applicationBufferSize = sslSession.getApplicationBufferSize(); + // int packetBufferSize = sslSession.getPacketBufferSize(); + int applicationBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetApplicationBufferSize); + int packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize); + + // ByteBuffer appInBuffer = ByteBuffer.allocate(Math.max(applicationBufferSize, appBufferSize)); + // ByteBuffer appOutBuffer = ByteBuffer.allocate(appBufferSize); + // ByteBuffer netOutBuffer = ByteBuffer.allocate(packetBufferSize); + // ByteBuffer netInBuffer = ByteBuffer.allocate(packetBufferSize); + int appInBufferSize = applicationBufferSize > appBufferSize ? applicationBufferSize : appBufferSize; + sslStream->appInBuffer = + ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, appInBufferSize)); + sslStream->appOutBuffer = + ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, appBufferSize)); + sslStream->netOutBuffer = + ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, packetBufferSize)); + sslStream->netInBuffer = + ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, packetBufferSize)); sslStream->streamReader = streamReader; sslStream->streamWriter = streamWriter; - (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineBeginHandshakeMethod); + ret = SUCCESS; - checkHandshakeStatus(env, sslStream, getHandshakeStatus(env, sslStream, NULL)); - (*env)->DeleteLocalRef(env, tlsVerStr); - AssertOnJNIExceptions(env); - return sslStream; +exit: + return ret; } -int AndroidCryptoNative_SSLStreamRead(SSLStream* sslStream, uint8_t* buffer, int offset, int length) +int32_t AndroidCryptoNative_SSLStreamConfigureParameters(SSLStream* sslStream, char* targetHost) { + assert(sslStream != NULL); + assert(targetHost != NULL); + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + INIT_LOCALS(loc, hostStr, nameList, hostName, params); + + // ArrayList nameList = new ArrayList(); + // SNIHostName hostName = new SNIHostName(targetHost); + // nameList.add(hostName); + loc[hostStr] = JSTRING(targetHost); + loc[nameList] = (*env)->NewObject(env, g_ArrayListClass, g_ArrayListCtor); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + loc[hostName] = (*env)->NewObject(env, g_SNIHostName, g_SNIHostNameCtor, loc[hostStr]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + (*env)->CallBooleanMethod(env, loc[nameList], g_ArrayListAdd, loc[hostName]); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + // SSLParameters params = new SSLParameters(); + // params.setServerNames(nameList); + // sslEngine.setSSLParameters(params); + loc[params] = (*env)->NewObject(env, g_SSLParametersClass, g_SSLParametersCtor); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + (*env)->CallVoidMethod(env, loc[params], g_SSLParametersSetServerNames, loc[nameList]); + (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineSetSSLParameters, loc[params]); + + ret = SUCCESS; + +cleanup: + RELEASE_LOCALS(loc, env); + return ret; +} + +PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamHandshake(SSLStream* sslStream) +{ + assert(sslStream != NULL); + JNIEnv* env = GetJNIEnv(); + + // sslEngine.beginHandshake(); + (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineBeginHandshake); + if (CheckJNIExceptions(env)) + return SSLStreamStatus_Error; + + return DoHandshake(env, sslStream); +} + +PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamRead(SSLStream* sslStream, uint8_t* buffer, int length, int* read) +{ + assert(sslStream != NULL); + assert(read != NULL); + + jbyteArray data = NULL; + JNIEnv* env = GetJNIEnv(); + PAL_SSLStreamStatus ret = SSLStreamStatus_Error; + *read = 0; + /* - while (true) { + appInBuffer.flip(); + if (appInBuffer.remaining() == 0) { + appInBuffer.compact(); + DoUnwrap(); appInBuffer.flip(); - try { - if (appInBuffer.remaining() > 0) { - byte[] data = new byte[appInBuffer.remaining()]; - appInBuffer.get(data); - return data; - } - } finally { - appInBuffer.compact(); - } - doUnwrap(); + } + if (appInBuffer.remaining() > 0) { + byte[] data = new byte[appInBuffer.remaining()]; + appInBuffer.get(data); + appInBuffer.compact(); + return SSLStreamStatus_OK; + } else { + return SSLStreamStatus_NeedData; } */ - AssertOnJNIExceptions(env); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferFlipMethod)); - int rem = (*env)->CallIntMethod(env, sslStream->appInBuffer, g_ByteBufferRemainingMethod); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferFlip)); + int rem = (*env)->CallIntMethod(env, sslStream->appInBuffer, g_ByteBufferRemaining); + if (rem == 0) + { + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferCompact)); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + int handshakeStatus; + PAL_SSLStreamStatus unwrapStatus = DoUnwrap(env, sslStream, &handshakeStatus); + if (unwrapStatus != SSLStreamStatus_OK) + { + ret = unwrapStatus; + goto cleanup; + } + + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferFlip)); + + if (IsHandshaking(handshakeStatus)) + { + ret = SSLStreamStatus_Renegotiate; + goto cleanup; + } + + rem = (*env)->CallIntMethod(env, sslStream->appInBuffer, g_ByteBufferRemaining); + } + if (rem > 0) { - jbyteArray data = (*env)->NewByteArray(env, rem); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferGetMethod, data)); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferCompactMethod)); - (*env)->GetByteArrayRegion(env, data, 0, rem, (jbyte*) buffer); - AssertOnJNIExceptions(env); - return rem; + data = (*env)->NewByteArray(env, rem); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferGet, data)); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferCompact)); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + (*env)->GetByteArrayRegion(env, data, 0, rem, (jbyte*)buffer); + *read = rem; + ret = SSLStreamStatus_OK; } else { - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appInBuffer, g_ByteBufferCompactMethod)); - doUnwrap(env, sslStream); - AssertOnJNIExceptions(env); - return AndroidCryptoNative_SSLStreamRead(sslStream, buffer, offset, length); + ret = SSLStreamStatus_NeedData; } + +cleanup: + ReleaseLRef(env, data); + return ret; } -void AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uint8_t* buffer, int offset, int length) +PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uint8_t* buffer, int length) { - /* - appOutBuffer.put(message); - doWrap(); - */ + assert(sslStream != NULL); JNIEnv* env = GetJNIEnv(); + PAL_SSLStreamStatus ret = SSLStreamStatus_Error; + + // byte[] data = new byte[] { } + // appOutBuffer.put(data); jbyteArray data = (*env)->NewByteArray(env, length); - (*env)->SetByteArrayRegion(env, data, 0, length, (jbyte*)(buffer + offset)); - (*env)->DeleteLocalRef(env, (*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPut2Method, data)); + (*env)->SetByteArrayRegion(env, data, 0, length, (jbyte*)buffer); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArray, data)); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + int handshakeStatus; + ret = DoWrap(env, sslStream, &handshakeStatus); + if (ret == SSLStreamStatus_OK && IsHandshaking(handshakeStatus)) + { + ret = SSLStreamStatus_Renegotiate; + } + +cleanup: (*env)->DeleteLocalRef(env, data); - doWrap(env, sslStream); - AssertOnJNIExceptions(env); + return ret; } void AndroidCryptoNative_SSLStreamRelease(SSLStream* sslStream) { + if (sslStream == NULL) + return; + JNIEnv* env = GetJNIEnv(); - ReleaseGRef(env, sslStream->sslContext); - ReleaseGRef(env, sslStream->sslEngine); - ReleaseGRef(env, sslStream->sslSession); - ReleaseGRef(env, sslStream->appOutBuffer); - ReleaseGRef(env, sslStream->netOutBuffer); - ReleaseGRef(env, sslStream->netInBuffer); - ReleaseGRef(env, sslStream->appInBuffer); - free(sslStream); - AssertOnJNIExceptions(env); + FreeSSLStream(env, sslStream); +} + +int32_t AndroidCryptoNative_SSLStreamGetApplicationProtocol(SSLStream* sslStream, uint8_t* out, int* outLen) +{ + assert(sslStream != NULL); + + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + + // String protocol = sslEngine.getApplicationProtocol(); + jstring protocol = (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetApplicationProtocol); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + if (protocol == NULL) + goto cleanup; + + jsize len = (*env)->GetStringUTFLength(env, protocol); + bool insufficientBuffer = *outLen < len; + *outLen = len; + if (insufficientBuffer) + return INSUFFICIENT_BUFFER; + + (*env)->GetStringUTFRegion(env, protocol, 0, len, (char*)out); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + ret = SUCCESS; + +cleanup: + (*env)->DeleteLocalRef(env, protocol); + return ret; +} + +int32_t AndroidCryptoNative_SSLStreamGetCipherSuite(SSLStream* sslStream, uint16_t** out) +{ + assert(sslStream != NULL); + assert(out != NULL); + + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + *out = NULL; + + // String cipherSuite = sslSession.getCipherSuite(); + jstring cipherSuite = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetCipherSuite); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + *out = AllocateString(env, cipherSuite); + + ret = SUCCESS; + +cleanup: + (*env)->DeleteLocalRef(env, cipherSuite); + return ret; +} + +int32_t AndroidCryptoNative_SSLStreamGetProtocol(SSLStream* sslStream, uint16_t** out) +{ + assert(sslStream != NULL); + assert(out != NULL); + + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + *out = NULL; + + // String protocol = sslSession.getProtocol(); + jstring protocol = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetProtocol); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + *out = AllocateString(env, protocol); + + ret = SUCCESS; + +cleanup: + (*env)->DeleteLocalRef(env, protocol); + return ret; +} + +int32_t AndroidCryptoNative_SSLStreamGetPeerCertificate(SSLStream* sslStream, jobject* out) +{ + assert(sslStream != NULL); + assert(out != NULL); + + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + *out = NULL; + + // Certificate[] certs = sslSession.getPeerCertificates(); + // out = certs[0]; + jobjectArray certs = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetPeerCertificates); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + jsize len = (*env)->GetArrayLength(env, certs); + if (len > 0) + { + // First element is the peer's own certificate + jobject cert = (*env)->GetObjectArrayElement(env, certs, 0); + *out = ToGRef(env, cert); + } + + ret = SUCCESS; + +cleanup: + (*env)->DeleteLocalRef(env, certs); + return ret; +} + +int32_t AndroidCryptoNative_SSLStreamGetPeerCertificates(SSLStream* sslStream, jobject** out, int* outLen) +{ + assert(sslStream != NULL); + assert(out != NULL); + + JNIEnv* env = GetJNIEnv(); + int32_t ret = FAIL; + *out = NULL; + *outLen = 0; + + // Certificate[] certs = sslSession.getPeerCertificates(); + // for (int i = 0; i < certs.length; i++) { + // out[i] = certs[i]; + // } + jobjectArray certs = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetPeerCertificates); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + jsize len = (*env)->GetArrayLength(env, certs); + *outLen = len; + if (len > 0) + { + *out = malloc(sizeof(jobject) * (size_t)len); + for (int i = 0; i < len; i++) + { + jobject cert = (*env)->GetObjectArrayElement(env, certs, i); + (*out)[i] = ToGRef(env, cert); + } + } + + ret = SUCCESS; + +cleanup: + (*env)->DeleteLocalRef(env, certs); + return ret; +} + +bool AndroidCryptoNative_SSLStreamVerifyHostname(SSLStream* sslStream, char* hostname) +{ + assert(sslStream != NULL); + assert(hostname != NULL); + JNIEnv* env = GetJNIEnv(); + + bool ret = false; + INIT_LOCALS(loc, name, verifier); + + // HostnameVerifier verifier = HttpsURLConnection.getDefaultHostnameVerifier(); + // return verifier.verify(hostname, sslSession); + loc[name] = JSTRING(hostname); + loc[verifier] = + (*env)->CallStaticObjectMethod(env, g_HttpsURLConnection, g_HttpsURLConnectionGetDefaultHostnameVerifier); + ret = (*env)->CallBooleanMethod(env, loc[verifier], g_HostnameVerifierVerify, loc[name], sslStream->sslSession); + + RELEASE_LOCALS(loc, env); + return ret; +} + +bool AndroidCryptoNative_SSLStreamShutdown(SSLStream* sslStream) +{ + assert(sslStream != NULL); + JNIEnv* env = GetJNIEnv(); + + PAL_SSLStreamStatus status = Close(env, sslStream); + return status == SSLStreamStatus_Closed; +} + +static uint16_t* AllocateString(JNIEnv* env, jstring source) +{ + if (source == NULL) + return NULL; + + // Length with null terminator + jsize len = (*env)->GetStringLength(env, source); + + // +1 for null terminator. + uint16_t* buffer = malloc(sizeof(uint16_t) * (size_t)(len + 1)); + buffer[len] = '\0'; + + (*env)->GetStringRegion(env, source, 0, len, (jchar*)buffer); + return buffer; } diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.h b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.h index 348760cfdc0a0a..afe966935c9484 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.h +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_sslstream.h @@ -4,9 +4,10 @@ #pragma once #include "pal_jni.h" +#include "pal_x509.h" -typedef void (*STREAM_WRITER)(uint8_t*, uint32_t, uint32_t); -typedef int (*STREAM_READER)(uint8_t*, uint32_t, uint32_t); +typedef void (*STREAM_WRITER)(uint8_t*, int32_t); +typedef int (*STREAM_READER)(uint8_t*, int32_t*); typedef struct SSLStream { @@ -21,24 +22,133 @@ typedef struct SSLStream STREAM_WRITER streamWriter; } SSLStream; -#define TLS11 11 -#define TLS12 12 -#define TLS13 13 - -// javax/net/ssl/SSLEngineResult$HandshakeStatus -#define HANDSHAKE_STATUS__NOT_HANDSHAKING 0 -#define HANDSHAKE_STATUS__FINISHED 1 -#define HANDSHAKE_STATUS__NEED_TASK 2 -#define HANDSHAKE_STATUS__NEED_WRAP 3 -#define HANDSHAKE_STATUS__NEED_UNWRAP 4 - -// javax/net/ssl/SSLEngineResult$Status -#define STATUS__BUFFER_UNDERFLOW 0 -#define STATUS__BUFFER_OVERFLOW 1 -#define STATUS__OK 2 -#define STATUS__CLOSED 3 - -PALEXPORT SSLStream* AndroidCryptoNative_SSLStreamCreateAndStartHandshake(STREAM_READER streamReader, STREAM_WRITER streamWriter, int tlsVersion, int appOutBufferSize, int appInBufferSize); -PALEXPORT int AndroidCryptoNative_SSLStreamRead(SSLStream* sslStream, uint8_t* buffer, int offset, int length); -PALEXPORT void AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uint8_t* buffer, int offset, int length); +// Matches managed PAL_SSLStreamStatus enum +enum +{ + SSLStreamStatus_OK = 0, + SSLStreamStatus_NeedData = 1, + SSLStreamStatus_Error = 2, + SSLStreamStatus_Renegotiate = 3, + SSLStreamStatus_Closed = 4, +}; +typedef int32_t PAL_SSLStreamStatus; + +/* +Create an SSL context + +Returns NULL on failure +*/ +PALEXPORT SSLStream* AndroidCryptoNative_SSLStreamCreate(void); + +/* +Create an SSL context with the specified certificates + +Returns NULL on failure +*/ +PALEXPORT SSLStream* AndroidCryptoNative_SSLStreamCreateWithCertificates(uint8_t* pkcs8PrivateKey, int32_t pkcs8PrivateKeyLen, PAL_KeyAlgorithm algorithm, jobject* /*X509Certificate[]*/ certs, int32_t certsLen); + +/* +Initialize an SSL context + - isServer : true if the context should be created in server mode + - streamReader : callback for reading data from the connection + - streamWriter : callback for writing data to the connection + - appBufferSize : initial buffer size for applicaiton data + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamInitialize(SSLStream* sslStream, + bool isServer, + STREAM_READER streamReader, + STREAM_WRITER streamWriter, + int appBufferSize); + +/* +Set configuration parameters + - targetHost : SNI host name + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamConfigureParameters(SSLStream* sslStream, char* targetHost); + +/* +Start or continue the TLS handshake +*/ +PALEXPORT PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamHandshake(SSLStream* sslStream); + +/* +Read bytes from the connection into a buffer + - buffer : buffer to populate with the bytes read from the connection + - length : maximum number of bytes to read + - read : [out] number of bytes read from the connection and written into the buffer + +Unless data from a previous incomplete read is present, this will invoke the STREAM_READER callback. +*/ +PALEXPORT PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamRead(SSLStream* sslStream, + uint8_t* buffer, + int length, + int* read); +/* +Encodes bytes from a buffer + - buffer : data to encode + - length : length of buffer + +This will invoke the STREAM_WRITER callback with the processed data. +*/ +PALEXPORT PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uint8_t* buffer, int length); + +/* +Release the SSL context +*/ PALEXPORT void AndroidCryptoNative_SSLStreamRelease(SSLStream* sslStream); + +/* +Get the negotiated application protocol for the current session + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamGetApplicationProtocol(SSLStream* sslStream, uint8_t* out, int* outLen); + +/* +Get the name of the cipher suite for the current session + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamGetCipherSuite(SSLStream* sslStream, uint16_t** out); + +/* +Get the standard name of the protocol for the current session (e.g. TLSv1.2) + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamGetProtocol(SSLStream* sslStream, uint16_t** out); + +/* +Get the peer certificate for the current session + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamGetPeerCertificate(SSLStream* sslStream, + jobject* /*X509Certificate*/ out); + +/* +Get the peer certificates for the current session + +The peer's own certificate will be first, followed by any certificate authorities. + +Returns 1 on success, 0 otherwise +*/ +PALEXPORT int32_t AndroidCryptoNative_SSLStreamGetPeerCertificates(SSLStream* sslStream, + jobject** /*X509Certificate[]*/ out, + int* outLen); + +/* +Verify hostname using the peer certificate for the current session + +Returns true if hostname matches, false otherwise +*/ +PALEXPORT bool AndroidCryptoNative_SSLStreamVerifyHostname(SSLStream* sslStream, char* hostname); + +/* +Shut down the session +*/ +PALEXPORT bool AndroidCryptoNative_SSLStreamShutdown(SSLStream* sslStream); diff --git a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_x509.c b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_x509.c index 9cd92dadecdab2..f2468882674a28 100644 --- a/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_x509.c +++ b/src/libraries/Native/Unix/System.Security.Cryptography.Native.Android/pal_x509.c @@ -12,8 +12,6 @@ #include #include -#define INSUFFICIENT_BUFFER -1 - static int32_t PopulateByteArray(JNIEnv* env, jbyteArray source, uint8_t* dest, int32_t* len); static void FindCertStart(const uint8_t** buffer, int32_t* len); diff --git a/src/libraries/System.Net.Security/src/System.Net.Security.csproj b/src/libraries/System.Net.Security/src/System.Net.Security.csproj index 6c0a5b530fdf59..be5811068e77db 100644 --- a/src/libraries/System.Net.Security/src/System.Net.Security.csproj +++ b/src/libraries/System.Net.Security/src/System.Net.Security.csproj @@ -331,8 +331,12 @@ Link="Common\System\Net\Security\Unix\SafeFreeSslCredentials.cs" /> + + diff --git a/src/libraries/System.Net.Security/src/System/Net/CertificateValidationPal.Android.cs b/src/libraries/System.Net.Security/src/System/Net/CertificateValidationPal.Android.cs index 0289095ab53666..7643482f9336f5 100644 --- a/src/libraries/System.Net.Security/src/System/Net/CertificateValidationPal.Android.cs +++ b/src/libraries/System.Net.Security/src/System/Net/CertificateValidationPal.Android.cs @@ -24,10 +24,17 @@ internal static SslPolicyErrors VerifyCertificateProperties( ? SslPolicyErrors.None : SslPolicyErrors.RemoteCertificateChainErrors; - if (!checkCertName) - return errors; + if (checkCertName) + { + System.Diagnostics.Debug.Assert(hostName != null); + SafeDeleteSslContext sslContext = (SafeDeleteSslContext)securityContext; + if (!Interop.AndroidCrypto.SSLStreamVerifyHostname(sslContext.SslContext, hostName!)) + { + errors |= SslPolicyErrors.RemoteCertificateNameMismatch; + } + } - throw new NotImplementedException(nameof(VerifyCertificateProperties)); + return errors; } // @@ -63,7 +70,39 @@ internal static SslPolicyErrors VerifyCertificateProperties( if (sslContext == null) return null; - throw new NotImplementedException(nameof(GetRemoteCertificate)); + X509Certificate2? cert = null; + if (remoteCertificateStore == null) + { + // Constructing a new X509Certificate2 adds a global reference to the pointer, so we dispose this handle + using (SafeX509Handle handle = Interop.AndroidCrypto.SSLStreamGetPeerCertificate(sslContext)) + { + if (!handle.IsInvalid) + { + cert = new X509Certificate2(handle.DangerousGetHandle()); + } + } + } + else + { + IntPtr[] ptrs = Interop.AndroidCrypto.SSLStreamGetPeerCertificates(sslContext); + if (ptrs.Length > 0) + { + // This is intentionally a different object from the cert added to the remote certificate store + // to match the behaviour on other platforms. + cert = new X509Certificate2(ptrs[0]); + foreach (IntPtr ptr in ptrs) + { + // Constructing a new X509Certificate2 adds a global reference to the pointer, so we dispose this handle + using (var handle = new SafeX509Handle(ptr)) + { + remoteCertificateStore.Add(new X509Certificate2(handle.DangerousGetHandle())); + } + } + + } + } + + return cert; } // diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/Pal.Android/SafeDeleteSslContext.cs b/src/libraries/System.Net.Security/src/System/Net/Security/Pal.Android/SafeDeleteSslContext.cs index 016b33b4ee3dab..776bd6358ca7db 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/Pal.Android/SafeDeleteSslContext.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/Pal.Android/SafeDeleteSslContext.cs @@ -6,15 +6,22 @@ using System.Net.Http; using System.Net.Security; using System.Security.Authentication; +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using Microsoft.Win32.SafeHandles; +using PAL_KeyAlgorithm = Interop.AndroidCrypto.PAL_KeyAlgorithm; +using PAL_SSLStreamStatus = Interop.AndroidCrypto.PAL_SSLStreamStatus; + namespace System.Net { internal sealed class SafeDeleteSslContext : SafeDeleteContext { private const int InitialBufferSize = 2048; + private readonly SafeSslHandle _sslContext; + private readonly Interop.AndroidCrypto.SSLReadCallback _readCallback; + private readonly Interop.AndroidCrypto.SSLWriteCallback _writeCallback; private ArrayBuffer _inputBuffer = new ArrayBuffer(InitialBufferSize); private ArrayBuffer _outputBuffer = new ArrayBuffer(InitialBufferSize); @@ -26,8 +33,23 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthentication { Debug.Assert((credential != null) && !credential.IsInvalid, "Invalid credential used in SafeDeleteSslContext"); - _sslContext = new SafeSslHandle(); - throw new NotImplementedException(nameof(SafeDeleteSslContext)); + try + { + unsafe + { + _readCallback = ReadFromConnection; + _writeCallback = WriteToConnection; + } + + _sslContext = CreateSslContext(credential); + InitializeSslContext(_sslContext, _readCallback, _writeCallback, credential, authOptions); + } + catch (Exception ex) + { + Debug.Write("Exception Caught. - " + ex); + Dispose(); + throw; + } } public override bool IsInvalid => _sslContext?.IsInvalid ?? true; @@ -48,7 +70,7 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - private unsafe void WriteToConnection(byte* data, int offset, int dataLength) + private unsafe void WriteToConnection(byte* data, int dataLength) { var inputBuffer = new ReadOnlySpan(data, dataLength); @@ -57,19 +79,25 @@ private unsafe void WriteToConnection(byte* data, int offset, int dataLength) _outputBuffer.Commit(dataLength); } - private unsafe int ReadFromConnection(byte* data, int offset, int dataLength) + private unsafe PAL_SSLStreamStatus ReadFromConnection(byte* data, int* dataLength) { - if (dataLength == 0) - return 0; + int toRead = *dataLength; + if (toRead == 0) + return PAL_SSLStreamStatus.OK; if (_inputBuffer.ActiveLength == 0) - return 0; + { + *dataLength = 0; + return PAL_SSLStreamStatus.NeedData; + } - int toRead = Math.Min(dataLength, _inputBuffer.ActiveLength); + toRead = Math.Min(toRead, _inputBuffer.ActiveLength); _inputBuffer.ActiveSpan.Slice(0, toRead).CopyTo(new Span(data, toRead)); _inputBuffer.Discard(toRead); - return toRead; + + *dataLength = toRead; + return PAL_SSLStreamStatus.OK; } internal void Write(ReadOnlySpan buf) @@ -108,5 +136,81 @@ internal int ReadPendingWrites(byte[] buf, int offset, int count) return limit; } + + private static SafeSslHandle CreateSslContext(SafeFreeSslCredentials credential) + { + if (credential.CertificateContext == null) + { + return Interop.AndroidCrypto.SSLStreamCreate(); + } + + SslStreamCertificateContext context = credential.CertificateContext; + X509Certificate2 cert = context.Certificate; + Debug.Assert(context.Certificate.HasPrivateKey); + + PAL_KeyAlgorithm algorithm; + byte[] keyBytes; + using (AsymmetricAlgorithm key = GetPrivateKeyAlgorithm(cert, out algorithm)) + { + keyBytes = key.ExportPkcs8PrivateKey(); + } + IntPtr[] ptrs = new IntPtr[context.IntermediateCertificates.Length + 1]; + ptrs[0] = cert.Handle; + for (int i = 0; i < context.IntermediateCertificates.Length; i++) + { + ptrs[i + 1] = context.IntermediateCertificates[i].Handle; + } + + return Interop.AndroidCrypto.SSLStreamCreateWithCertificates(keyBytes, algorithm, ptrs); + } + + private static AsymmetricAlgorithm GetPrivateKeyAlgorithm(X509Certificate2 cert, out PAL_KeyAlgorithm algorithm) + { + AsymmetricAlgorithm? key = cert.GetRSAPrivateKey(); + if (key != null) + { + algorithm = PAL_KeyAlgorithm.RSA; + return key; + } + key = cert.GetECDsaPrivateKey(); + if (key != null) + { + algorithm = PAL_KeyAlgorithm.EC; + return key; + } + key = cert.GetDSAPrivateKey(); + if (key != null) + { + algorithm = PAL_KeyAlgorithm.DSA; + return key; + } + throw new NotSupportedException(SR.net_ssl_io_no_server_cert); + } + + private static void InitializeSslContext( + SafeSslHandle handle, + Interop.AndroidCrypto.SSLReadCallback readCallback, + Interop.AndroidCrypto.SSLWriteCallback writeCallback, + SafeFreeSslCredentials credential, + SslAuthenticationOptions authOptions) + { + bool isServer = authOptions.IsServer; + + if (authOptions.ApplicationProtocols != null + || authOptions.CipherSuitesPolicy != null + || credential.Protocols != SslProtocols.None + || (isServer && authOptions.RemoteCertRequired)) + { + // TODO: [AndroidCrypto] Handle non-system-default options + throw new NotImplementedException(nameof(SafeDeleteSslContext)); + } + + Interop.AndroidCrypto.SSLStreamInitialize(handle, isServer, readCallback, writeCallback, InitialBufferSize); + + if (!isServer && !string.IsNullOrEmpty(authOptions.TargetHost)) + { + Interop.AndroidCrypto.SSLStreamConfigureParameters(handle, authOptions.TargetHost); + } + } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs index 1d468dba8a5743..9635907c86f0b3 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs @@ -11,7 +11,24 @@ internal sealed partial class SslConnectionInfo { public SslConnectionInfo(SafeSslHandle sslContext) { - throw new NotImplementedException(nameof(SslConnectionInfo)); + string protocolString = Interop.AndroidCrypto.SSLStreamGetProtocol(sslContext); + SslProtocols protocol = protocolString switch + { +#pragma warning disable 0618 // Ssl2 and Ssl3 are deprecated. + "SSLv2" => SslProtocols.Ssl2, + "SSLv3" => SslProtocols.Ssl3, +#pragma warning restore + "TLSv1" => SslProtocols.Tls, + "TLSv1.1" => SslProtocols.Tls11, + "TLSv1.2" => SslProtocols.Tls12, + "TLSv1.3" => SslProtocols.Tls13, + _ => SslProtocols.None, + }; + Protocol = (int)protocol; + + // Enum value names should match the cipher suite name, so we just parse the + string cipherSuite = Interop.AndroidCrypto.SSLStreamGetCipherSuite(sslContext); + MapCipherSuite(Enum.Parse(cipherSuite)); } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs index 6245f00ae4570a..5890e1ccc14618 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs @@ -1,12 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; using System.Diagnostics; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; +using PAL_SSLStreamStatus = Interop.AndroidCrypto.PAL_SSLStreamStatus; + namespace System.Net.Security { internal static class SslStreamPal @@ -58,7 +59,7 @@ public static SafeFreeCredentials AcquireCredentialsHandle( if (context == null) return null; - throw new NotImplementedException(nameof(GetNegotiatedApplicationProtocol)); + return Interop.AndroidCrypto.SSLStreamGetApplicationProtocol(((SafeDeleteSslContext)context).SslContext); } public static SecurityStatusPal EncryptMessage( @@ -72,7 +73,37 @@ public static SecurityStatusPal EncryptMessage( resultSize = 0; Debug.Assert(input.Length > 0, $"{nameof(input.Length)} > 0 since {nameof(CanEncryptEmptyMessage)} is false"); - throw new NotImplementedException(nameof(EncryptMessage)); + try + { + SafeDeleteSslContext sslContext = (SafeDeleteSslContext)securityContext; + SafeSslHandle sslHandle = sslContext.SslContext; + + PAL_SSLStreamStatus ret = Interop.AndroidCrypto.SSLStreamWrite(sslHandle, input); + SecurityStatusPalErrorCode statusCode = ret switch + { + PAL_SSLStreamStatus.OK => SecurityStatusPalErrorCode.OK, + PAL_SSLStreamStatus.NeedData => SecurityStatusPalErrorCode.ContinueNeeded, + PAL_SSLStreamStatus.Renegotiate => SecurityStatusPalErrorCode.Renegotiate, + PAL_SSLStreamStatus.Closed => SecurityStatusPalErrorCode.ContextExpired, + _ => SecurityStatusPalErrorCode.InternalError + }; + + if (sslContext.BytesReadyForConnection <= output?.Length) + { + resultSize = sslContext.ReadPendingWrites(output, 0, output.Length); + } + else + { + output = sslContext.ReadPendingWrites()!; + resultSize = output.Length; + } + + return new SecurityStatusPal(statusCode); + } + catch (Exception e) + { + return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, e); + } } public static SecurityStatusPal DecryptMessage( @@ -81,7 +112,34 @@ public static SecurityStatusPal DecryptMessage( ref int offset, ref int count) { - throw new NotImplementedException(nameof(DecryptMessage)); + try + { + SafeDeleteSslContext sslContext = (SafeDeleteSslContext)securityContext; + SafeSslHandle sslHandle = sslContext.SslContext; + + sslContext.Write(buffer.AsSpan(offset, count)); + + PAL_SSLStreamStatus ret = Interop.AndroidCrypto.SSLStreamRead(sslHandle, buffer.AsSpan(offset, count), out int read); + if (ret == PAL_SSLStreamStatus.Error) + return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError); + + count = read; + + SecurityStatusPalErrorCode statusCode = ret switch + { + PAL_SSLStreamStatus.OK => SecurityStatusPalErrorCode.OK, + PAL_SSLStreamStatus.NeedData => SecurityStatusPalErrorCode.OK, + PAL_SSLStreamStatus.Renegotiate => SecurityStatusPalErrorCode.Renegotiate, + PAL_SSLStreamStatus.Closed => SecurityStatusPalErrorCode.ContextExpired, + _ => SecurityStatusPalErrorCode.InternalError + }; + + return new SecurityStatusPal(statusCode); + } + catch (Exception e) + { + return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, e); + } } public static ChannelBinding? QueryContextChannelBinding( @@ -134,12 +192,17 @@ private static SecurityStatusPal HandshakeInternal( SafeSslHandle sslHandle = sslContext!.SslContext; - // Do handshake - // Interop.AndroidCrypto.SSLStreamHandshake + PAL_SSLStreamStatus ret = Interop.AndroidCrypto.SSLStreamHandshake(sslHandle); + SecurityStatusPalErrorCode statusCode = ret switch + { + PAL_SSLStreamStatus.OK => SecurityStatusPalErrorCode.OK, + PAL_SSLStreamStatus.NeedData => SecurityStatusPalErrorCode.ContinueNeeded, + _ => SecurityStatusPalErrorCode.InternalError + }; outputBuffer = sslContext.ReadPendingWrites(); - return new SecurityStatusPal(SecurityStatusPalErrorCode.OK); + return new SecurityStatusPal(statusCode); } catch (Exception exc) { @@ -166,8 +229,7 @@ public static SecurityStatusPal ApplyShutdownToken( SafeSslHandle sslHandle = sslContext.SslContext; - // bool success = Interop.AndroidCrypto.SslShutdown(sslHandle); - bool success = true; + bool success = Interop.AndroidCrypto.SSLStreamShutdown(sslHandle); if (success) { return new SecurityStatusPal(SecurityStatusPalErrorCode.OK);