From c5824f39e18733937f20cdfe8f03c31b9606411b Mon Sep 17 00:00:00 2001 From: Robin Rolf Date: Sat, 26 Oct 2024 17:58:24 +0000 Subject: [PATCH] feat: ThreadedEncryptionTransport --- .../Editor/EncryptionTransportInspector.cs | 3 + .../Encryption/EncryptedConnection.cs | 58 ++-- .../Encryption/ThreadedEncryptionTransport.cs | 314 ++++++++++++++++++ .../ThreadedEncryptionTransport.cs.meta | 11 + 4 files changed, 360 insertions(+), 26 deletions(-) create mode 100644 Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs create mode 100644 Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs index cb50bf2b3..6c39f909f 100644 --- a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs @@ -77,5 +77,8 @@ public override void OnInspectorGUI() serializedObject.ApplyModifiedProperties(); } + + [CustomEditor(typeof(ThreadedEncryptionTransport), true)] + class EncryptionThreadedTransportInspector : EncryptionTransportInspector {} } } diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs index 24a52f3ca..f2a18ca63 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs @@ -1,6 +1,7 @@ using System; using System.Security.Cryptography; using System.Text; +using System.Threading; using Mirror.BouncyCastle.Crypto; using Mirror.BouncyCastle.Crypto.Agreement; using Mirror.BouncyCastle.Crypto.Digests; @@ -48,20 +49,19 @@ public class EncryptedConnection // Set up a global cipher instance, it is initialised/reset before use // (AesFastEngine used to exist, but was removed due to side channel issues) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) - static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); + static readonly ThreadLocal Cipher = new ThreadLocal(() => new GcmBlockCipher(AesUtilities.CreateEngine())); // Set up a global HKDF with a SHA-256 digest - static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); + static readonly ThreadLocal Hkdf = new ThreadLocal(() => new HkdfBytesGenerator(new Sha256Digest())); // Global byte array to store nonce sent by the remote side, they're used immediately after - static readonly byte[] ReceiveNonce = new byte[NonceSize]; + static readonly ThreadLocal ReceiveNonce = new ThreadLocal(() => new byte[NonceSize]); // Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes* - static readonly byte[] TMPRemoteSaltBuffer = new byte[HkdfSaltSize]; + static readonly ThreadLocal TMPRemoteSaltBuffer = new ThreadLocal(() => new byte[HkdfSaltSize]); + // buffer for encrypt/decrypt operations, resized larger as needed - // this is also the buffer that will be returned to mirror via ArraySegment - // so any thread safety concerns would need to take extra care here - static byte[] TMPCryptBuffer = new byte[2048]; + static ThreadLocal TMPCryptBuffer = new ThreadLocal(() => new byte[2048]); // packet headers enum OpCodes : byte @@ -189,7 +189,7 @@ public void OnReceiveRaw(ArraySegment data, int channel) } ArraySegment ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize); - reader.ReadBytes(ReceiveNonce, NonceSize); + reader.ReadBytes(ReceiveNonce.Value, NonceSize); Profiler.BeginSample("EncryptedConnection.Decrypt"); ArraySegment plaintext = Decrypt(ciphertext); @@ -233,8 +233,8 @@ public void OnReceiveRaw(ArraySegment data, int channel) state = State.WaitingHandshakeReply; ResetTimeouts(); - reader.ReadBytes(TMPRemoteSaltBuffer, HkdfSaltSize); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer); + reader.ReadBytes(TMPRemoteSaltBuffer.Value, HkdfSaltSize); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer.Value); SendHandshakeFin(); break; case OpCodes.HandshakeFin: @@ -307,25 +307,28 @@ ArraySegment Encrypt(ArraySegment plaintext) // Need to make the nonce unique again before encrypting another message UpdateNonce(); // Re-initialize the cipher with our cached parameters - Cipher.Init(true, cipherParametersEncrypt); + Cipher.Value.Init(true, cipherParametersEncrypt); // Calculate the expected output size, this should always be input size + mac size - int outSize = Cipher.GetOutputSize(plaintext.Count); + int outSize = Cipher.Value.GetOutputSize(plaintext.Count); #if UNITY_EDITOR // expecting the outSize to be input size + MacSize if (outSize != plaintext.Count + MacSizeBytes) throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); #endif // Resize the static buffer to fit - EnsureSize(ref TMPCryptBuffer, outSize); + byte[] cryptBuffer = TMPCryptBuffer.Value; + EnsureSize(ref cryptBuffer, outSize); + TMPCryptBuffer.Value = cryptBuffer; + int resultLen; try { // Run the plain text through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, TMPCryptBuffer, 0); + Cipher.Value.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, cryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) - resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); + resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // @@ -339,7 +342,7 @@ ArraySegment Encrypt(ArraySegment plaintext) if (resultLen != outSize) throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); #endif - return new ArraySegment(TMPCryptBuffer, 0, resultLen); + return new ArraySegment(cryptBuffer, 0, resultLen); } ArraySegment Decrypt(ArraySegment ciphertext) @@ -351,25 +354,28 @@ ArraySegment Decrypt(ArraySegment ciphertext) return new ArraySegment(); } // Re-initialize the cipher with our cached parameters - Cipher.Init(false, cipherParametersDecrypt); + Cipher.Value.Init(false, cipherParametersDecrypt); // Calculate the expected output size, this should always be input size - mac size - int outSize = Cipher.GetOutputSize(ciphertext.Count); + int outSize = Cipher.Value.GetOutputSize(ciphertext.Count); #if UNITY_EDITOR // expecting the outSize to be input size - MacSize if (outSize != ciphertext.Count - MacSizeBytes) throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); #endif - // Resize the static buffer to fit - EnsureSize(ref TMPCryptBuffer, outSize); + + byte[] cryptBuffer = TMPCryptBuffer.Value; + EnsureSize(ref cryptBuffer, outSize); + TMPCryptBuffer.Value = cryptBuffer; + int resultLen; try { // Run the ciphertext through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, TMPCryptBuffer, 0); + Cipher.Value.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, cryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) - resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); + resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types catch (Exception e) @@ -382,7 +388,7 @@ ArraySegment Decrypt(ArraySegment ciphertext) if (resultLen != outSize) throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); #endif - return new ArraySegment(TMPCryptBuffer, 0, resultLen); + return new ArraySegment(cryptBuffer, 0, resultLen); } void UpdateNonce() @@ -477,13 +483,13 @@ void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) return; } - Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo)); + Hkdf.Value.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo)); // Allocate a buffer for the output key byte[] keyRaw = new byte[KeyLength]; // Generate the output keying material - Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length); + Hkdf.Value.GenerateBytes(keyRaw, 0, keyRaw.Length); KeyParameter key = new KeyParameter(keyRaw); @@ -493,7 +499,7 @@ void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // instead of creating a new one each encrypt/decrypt cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce); - cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); + cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce.Value); } /** diff --git a/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs new file mode 100644 index 000000000..0da64b528 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs @@ -0,0 +1,314 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using Mirror.BouncyCastle.Crypto; +using UnityEngine; +using UnityEngine.Profiling; +using UnityEngine.Serialization; +using Debug = UnityEngine.Debug; + +namespace Mirror.Transports.Encryption +{ + [HelpURL("https://mirror-networking.gitbook.io/docs/manual/transports/encryption-transport")] + public class ThreadedEncryptionTransport : ThreadedTransport, PortTransport + { + public override bool IsEncrypted => true; + public override string EncryptionCipher => "AES256-GCM"; + [FormerlySerializedAs("inner")] + public ThreadedTransport Inner; + + public ushort Port + { + get + { + if (Inner is PortTransport portTransport) + return portTransport.Port; + + Debug.LogError($"ThreadedEncryptionTransport can't get Port because {Inner} is not a PortTransport"); + return 0; + } + set + { + if (Inner is PortTransport portTransport) + { + portTransport.Port = value; + return; + } + Debug.LogError($"ThreadedEncryptionTransport can't set Port because {Inner} is not a PortTransport"); + } + } + + public enum ValidationMode + { + Off, + List, + Callback + } + + [FormerlySerializedAs("clientValidateServerPubKey")] + public ValidationMode ClientValidateServerPubKey; + [FormerlySerializedAs("clientTrustedPubKeySignatures")] + [Tooltip("List of public key fingerprints the client will accept")] + public string[] ClientTrustedPubKeySignatures; + /// + /// Called when a client connects to a server + /// ATTENTION: NOT THREAD SAFE. + /// This will be called on the worker thread. + /// + public Func OnClientValidateServerPubKey; + [FormerlySerializedAs("serverLoadKeyPairFromFile")] + public bool ServerLoadKeyPairFromFile; + [FormerlySerializedAs("serverKeypairPath")] + public string ServerKeypairPath = "./server-keys.json"; + + EncryptedConnection client; + + readonly Dictionary serverConnections = new Dictionary(); + + readonly List serverPendingConnections = + new List(); + + EncryptionCredentials credentials; + public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint; + public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized; + + // Used for threaded time keeping as unitys Time.time is not thread safe + Stopwatch stopwatch = Stopwatch.StartNew(); + + void ServerRemoveFromPending(EncryptedConnection con) + { + for (int i = 0; i < serverPendingConnections.Count; i++) + if (serverPendingConnections[i] == con) + { + // remove by swapping with last + int lastIndex = serverPendingConnections.Count - 1; + serverPendingConnections[i] = serverPendingConnections[lastIndex]; + serverPendingConnections.RemoveAt(lastIndex); + break; + } + } + + void HandleInnerServerDisconnected(int connId) + { + if (serverConnections.TryGetValue(connId, out EncryptedConnection con)) + { + ServerRemoveFromPending(con); + serverConnections.Remove(connId); + } + OnThreadedServerDisconnected(connId); + } + + void HandleInnerServerError(int connId, TransportError type, string msg) => OnThreadedServerError(connId, type, $"inner: {msg}"); + + void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) + { + if (serverConnections.TryGetValue(connId, out EncryptedConnection c)) + c.OnReceiveRaw(data, channel); + } + + void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId)); + + void HandleInnerServerConnected(int connId, string clientRemoteAddress) + { + Debug.Log($"[ThreadedEncryptionTransport] New connection #{connId} from {clientRemoteAddress}"); + EncryptedConnection ec = null; + ec = new EncryptedConnection( + credentials, + false, + (segment, channel) => Inner.ServerSend(connId, segment, channel), + (segment, channel) => OnThreadedServerReceive(connId, segment, channel), + () => + { + Debug.Log($"[ThreadedEncryptionTransport] Connection #{connId} is ready"); + // ReSharper disable once AccessToModifiedClosure + ServerRemoveFromPending(ec); + OnThreadedServerConnected(connId, new IPEndPoint(IPAddress.Parse(clientRemoteAddress), 0)); + }, + (type, msg) => + { + OnThreadedServerError(connId, type, msg); + ServerDisconnect(connId); + }); + serverConnections.Add(connId, ec); + serverPendingConnections.Add(ec); + } + + void HandleInnerClientDisconnected() + { + client = null; + OnThreadedClientDisconnected(); + } + + void HandleInnerClientError(TransportError arg1, string arg2) => OnThreadedClientError(arg1, $"inner: {arg2}"); + + void HandleInnerClientDataReceived(ArraySegment data, int channel) => client?.OnReceiveRaw(data, channel); + + void HandleInnerClientConnected() => + client = new EncryptedConnection( + credentials, + true, + (segment, channel) => Inner.ClientSend(segment, channel), + (segment, channel) => OnThreadedClientReceive(segment, channel), + () => + { + OnThreadedClientConnected(); + }, + (type, msg) => + { + OnThreadedClientError(type, msg); + ClientDisconnect(); + }, + HandleClientValidateServerPubKey); + + bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) + { + switch (ClientValidateServerPubKey) + { + case ValidationMode.Off: + return true; + case ValidationMode.List: + return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; + case ValidationMode.Callback: + return OnClientValidateServerPubKey(pubKeyInfo); + default: + throw new ArgumentOutOfRangeException(); + } + } + + protected override void Awake() + { + base.Awake(); + // check if encryption via hardware acceleration is supported. + // this can be useful to know for low end devices. + // + // hardware acceleration requires netcoreapp3.0 or later: + // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/AesUtilities.cs#L18 + // because AesEngine_x86 requires System.Runtime.Intrinsics.X86: + // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs + // which Unity does not support yet. + Debug.Log($"ThreadedEncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}"); + } + + public override bool Available() => Inner.Available(); + + protected override void ThreadedClientConnect(string address) + { + switch (ClientValidateServerPubKey) + { + case ValidationMode.Off: + break; + case ValidationMode.List: + if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0) + { + OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); + return; + } + break; + case ValidationMode.Callback: + if (OnClientValidateServerPubKey == null) + { + OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); + return; + } + break; + default: + throw new ArgumentOutOfRangeException(); + } + credentials = EncryptionCredentials.Generate(); + Inner.OnClientConnected = HandleInnerClientConnected; + Inner.OnClientDataReceived = HandleInnerClientDataReceived; + Inner.OnClientDataSent = (bytes, channel) => OnThreadedClientSend(bytes, channel); + Inner.OnClientError = HandleInnerClientError; + Inner.OnClientDisconnected = HandleInnerClientDisconnected; + Inner.ClientConnect(address); + } + + protected override void ThreadedClientConnect(Uri address) => Inner.ClientConnect(address); + + protected override void ThreadedClientSend(ArraySegment segment, int channelId) => + client?.Send(segment, channelId); + + protected override void ThreadedClientDisconnect() => Inner.ClientDisconnect(); + + protected override void ThreadedServerStart() + { + if (ServerLoadKeyPairFromFile) + credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath); + else + credentials = EncryptionCredentials.Generate(); +#pragma warning disable CS0618 // Type or member is obsolete + Inner.OnServerConnected = HandleInnerServerConnected; +#pragma warning restore CS0618 // Type or member is obsolete + Inner.OnServerConnectedWithAddress = HandleInnerServerConnected; + Inner.OnServerDataReceived = HandleInnerServerDataReceived; + Inner.OnServerDataSent = (connId, bytes, channel) => OnThreadedServerSend(connId, bytes, channel); + Inner.OnServerError = HandleInnerServerError; + Inner.OnServerDisconnected = HandleInnerServerDisconnected; + Inner.ServerStart(); + } + + protected override void ThreadedServerSend(int connectionId, ArraySegment segment, int channelId) + { + if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) + connection.Send(segment, channelId); + } + + protected override void ThreadedServerDisconnect(int connectionId) => + // cleanup is done via inners disconnect event + Inner.ServerDisconnect(connectionId); + + protected override void ThreadedClientEarlyUpdate() {} + + protected override void ThreadedServerStop() => Inner.ServerStop(); + + public override Uri ServerUri() => Inner.ServerUri(); + + public override int GetMaxPacketSize(int channelId = Channels.Reliable) => + Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; + + protected override void ThreadedShutdown() => Inner.Shutdown(); + + public override void ClientEarlyUpdate() + { + base.ClientEarlyUpdate(); + Inner.ClientEarlyUpdate(); + } + + public override void ClientLateUpdate() + { + base.ClientLateUpdate(); + Inner.ClientLateUpdate(); + } + + protected override void ThreadedClientLateUpdate() + { + Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate"); + client?.TickNonReady(stopwatch.Elapsed.TotalSeconds); + Profiler.EndSample(); + } + + protected override void ThreadedServerEarlyUpdate() {} + + public override void ServerEarlyUpdate() + { + base.ServerEarlyUpdate(); + Inner.ServerEarlyUpdate(); + } + + public override void ServerLateUpdate() + { + base.ServerLateUpdate(); + Inner.ServerLateUpdate(); + } + + protected override void ThreadedServerLateUpdate() + { + Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate"); + // Reverse iteration as entries can be removed while updating + for (int i = serverPendingConnections.Count - 1; i >= 0; i--) + serverPendingConnections[i].TickNonReady(stopwatch.Elapsed.TotalSeconds); + Profiler.EndSample(); + } + } +} diff --git a/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta new file mode 100644 index 000000000..44448f3ff --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5d3e310924fb49c195391b9699f20809 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3} + userData: + assetBundleName: + assetBundleVariant: