From 7ebbaa031923f34e435199f2dfb4c70f6e7c2531 Mon Sep 17 00:00:00 2001 From: Robin Rolf Date: Sat, 26 Oct 2024 16:29:19 +0000 Subject: [PATCH 1/2] code style --- .../EncryptionTransportTransportTest.cs | 2 +- .../Editor/EncryptionTransportInspector.cs | 10 +- .../Encryption/EncryptedConnection.cs | 287 ++++++++---------- .../Encryption/EncryptionCredentials.cs | 18 +- .../Encryption/EncryptionTransport.cs | 220 ++++++-------- 5 files changed, 227 insertions(+), 310 deletions(-) diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs index 269ce1b15..1fa25ffc5 100644 --- a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs @@ -22,7 +22,7 @@ public void Setup() GameObject gameObject = new GameObject(); encryption = gameObject.AddComponent(); - encryption.inner = inner; + encryption.Inner = inner; } [TearDown] diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs index 24557f910..cb50bf2b3 100644 --- a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs @@ -17,11 +17,11 @@ public class EncryptionTransportInspector : UnityEditor.Editor void OnEnable() { - innerProperty = serializedObject.FindProperty("inner"); - clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey"); - clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures"); - serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath"); - serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile"); + innerProperty = serializedObject.FindProperty("Inner"); + clientValidatesServerPubKeyProperty = serializedObject.FindProperty("ClientValidateServerPubKey"); + clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("ClientTrustedPubKeySignatures"); + serverKeypairPathProperty = serializedObject.FindProperty("ServerKeypairPath"); + serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("ServerLoadKeyPairFromFile"); } public override void OnInspectorGUI() diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs index 46c916310..24a52f3ca 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs @@ -14,32 +14,32 @@ namespace Mirror.Transports.Encryption public class EncryptedConnection { // 256-bit key - private const int KeyLength = 32; + const int KeyLength = 32; // 512-bit salt for the key derivation function - private const int HkdfSaltSize = KeyLength * 2; + const int HkdfSaltSize = KeyLength * 2; // Info tag for the HKDF, this just adds more entropy - private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport"); + static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport"); // fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed) - private const int NonceSize = 12; + const int NonceSize = 12; // this is the size of the "checksum" included in each encrypted payload // 16 bytes/128 bytes is the recommended value for best security // can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker. // Setting it lower than 12 bytes is not recommended - private const int MacSizeBytes = 16; + const int MacSizeBytes = 16; - private const int MacSizeBits = MacSizeBytes * 8; + const int MacSizeBits = MacSizeBytes * 8; // How much metadata overhead we have for regular packets public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize; // After how many seconds of not receiving a handshake packet we should time out - private const double DurationTimeout = 2; // 2s + const double DurationTimeout = 2; // 2s // After how many seconds to assume the last handshake packet got lost and to resend another one - private const double DurationResend = 0.05; // 50ms + const double DurationResend = 0.05; // 50ms // Static fields for allocation efficiency, makes this not thread safe @@ -48,20 +48,20 @@ public class EncryptedConnection // Set up a global cipher instance, it is initialised/reset before use // (AesFastEngine used to exist, but was removed due to side channel issues) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) - private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); + static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); // Set up a global HKDF with a SHA-256 digest - private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); + static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); // Global byte array to store nonce sent by the remote side, they're used immediately after - private static readonly byte[] ReceiveNonce = new byte[NonceSize]; + static readonly byte[] ReceiveNonce = new byte[NonceSize]; // Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes* - private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize]; + static readonly byte[] TMPRemoteSaltBuffer = new byte[HkdfSaltSize]; // buffer for encrypt/decrypt operations, resized larger as needed // this is also the buffer that will be returned to mirror via ArraySegment // so any thread safety concerns would need to take extra care here - private static byte[] _tmpCryptBuffer = new byte[2048]; + static byte[] TMPCryptBuffer = new byte[2048]; // packet headers enum OpCodes : byte @@ -70,7 +70,7 @@ enum OpCodes : byte Data = 1, HandshakeStart = 2, HandshakeAck = 3, - HandshakeFin = 4, + HandshakeFin = 4 } enum State @@ -91,37 +91,37 @@ enum State Ready } - private State _state = State.WaitingHandshake; + State state = State.WaitingHandshake; // Key exchange confirmed and data can be sent freely - public bool IsReady => _state == State.Ready; + public bool IsReady => state == State.Ready; // Callback to send off encrypted data - private Action, int> _send; + readonly Action, int> send; // Callback when received data has been decrypted - private Action, int> _receive; + readonly Action, int> receive; // Callback when the connection becomes ready - private Action _ready; + readonly Action ready; // On-error callback, disconnect expected - private Action _error; + readonly Action error; // Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance // (usually client validates the server key) - private Func _validateRemoteKey; + readonly Func validateRemoteKey; // Our asymmetric credentials for the initial DH exchange - private EncryptionCredentials _credentials; - private byte[] _hkdfSalt; + EncryptionCredentials credentials; + readonly byte[] hkdfSalt; // After no handshake packet in this many seconds, the handshake fails - private double _handshakeTimeout; + double handshakeTimeout; // When to assume the last handshake packet got lost and to resend another one - private double _nextHandshakeResend; + double nextHandshakeResend; // we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in // so we can update it without creating a new AeadParameters instance // this might break in the future! (will cause bad data) - private byte[] _nonce = new byte[NonceSize]; - private AeadParameters _cipherParametersEncrypt; - private AeadParameters _cipherParametersDecrypt; + byte[] nonce = new byte[NonceSize]; + AeadParameters cipherParametersEncrypt; + AeadParameters cipherParametersDecrypt; /* @@ -130,7 +130,7 @@ enum State * * The client does this, since the fin is not acked explicitly, but by receiving data to decrypt */ - private readonly bool _sendsFirst; + readonly bool sendsFirst; public EncryptedConnection(EncryptionCredentials credentials, bool isClient, @@ -140,28 +140,24 @@ public EncryptedConnection(EncryptionCredentials credentials, Action errorAction, Func validateRemoteKey = null) { - _credentials = credentials; - _sendsFirst = isClient; - if (!_sendsFirst) - { + this.credentials = credentials; + sendsFirst = isClient; + if (!sendsFirst) // salt is controlled by the server - _hkdfSalt = GenerateSecureBytes(HkdfSaltSize); - } - _send = sendAction; - _receive = receiveAction; - _ready = readyAction; - _error = errorAction; - _validateRemoteKey = validateRemoteKey; + hkdfSalt = GenerateSecureBytes(HkdfSaltSize); + send = sendAction; + receive = receiveAction; + ready = readyAction; + error = errorAction; + this.validateRemoteKey = validateRemoteKey; } // Generates a random starting nonce - private static byte[] GenerateSecureBytes(int size) + static byte[] GenerateSecureBytes(int size) { byte[] bytes = new byte[size]; using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) - { rng.GetBytes(bytes); - } return bytes; } @@ -170,7 +166,7 @@ public void OnReceiveRaw(ArraySegment data, int channel) { if (data.Count < 1) { - _error(TransportError.Unexpected, "Received empty packet"); + error(TransportError.Unexpected, "Received empty packet"); return; } @@ -181,18 +177,14 @@ public void OnReceiveRaw(ArraySegment data, int channel) { case OpCodes.Data: // first sender ready is implicit when data is received - if (_sendsFirst && _state == State.WaitingHandshakeReply) - { + if (sendsFirst && state == State.WaitingHandshakeReply) SetReady(); - } else if (!IsReady) - { - _error(TransportError.Unexpected, "Unexpected data while not ready."); - } + error(TransportError.Unexpected, "Unexpected data while not ready."); if (reader.Remaining < Overhead) { - _error(TransportError.Unexpected, "received data packet smaller than metadata size"); + error(TransportError.Unexpected, "received data packet smaller than metadata size"); return; } @@ -203,72 +195,62 @@ public void OnReceiveRaw(ArraySegment data, int channel) ArraySegment plaintext = Decrypt(ciphertext); Profiler.EndSample(); if (plaintext.Count == 0) - { // error return; - } - _receive(plaintext, channel); + receive(plaintext, channel); break; case OpCodes.HandshakeStart: - if (_sendsFirst) + if (sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); return; } - if (_state == State.WaitingHandshakeReply) - { + if (state == State.WaitingHandshakeReply) // this is fine, packets may arrive out of order return; - } - _state = State.WaitingHandshakeReply; + state = State.WaitingHandshakeReply; ResetTimeouts(); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), hkdfSalt); SendHandshakeAndPubKey(OpCodes.HandshakeAck); break; case OpCodes.HandshakeAck: - if (!_sendsFirst) + if (!sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); return; } if (IsReady) - { // this is fine, packets may arrive out of order return; - } - if (_state == State.WaitingHandshakeReply) - { + if (state == State.WaitingHandshakeReply) // this is fine, packets may arrive out of order return; - } - _state = State.WaitingHandshakeReply; + state = State.WaitingHandshakeReply; ResetTimeouts(); - reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer); + reader.ReadBytes(TMPRemoteSaltBuffer, HkdfSaltSize); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer); SendHandshakeFin(); break; case OpCodes.HandshakeFin: - if (_sendsFirst) + if (sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); return; } if (IsReady) - { // this is fine, packets may arrive out of order return; - } - if (_state != State.WaitingHandshakeReply) + if (state != State.WaitingHandshakeReply) { - _error(TransportError.Unexpected, + error(TransportError.Unexpected, "Received HandshakeFin packet, we didn't expect this yet."); return; } @@ -277,24 +259,25 @@ public void OnReceiveRaw(ArraySegment data, int channel) break; default: - _error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); + error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); break; } } } - private void SetReady() + + void SetReady() { // done with credentials, null out the reference - _credentials = null; + credentials = null; - _state = State.Ready; - _ready(); + state = State.Ready; + ready(); } - private void ResetTimeouts() + void ResetTimeouts() { - _handshakeTimeout = 0; - _nextHandshakeResend = -1; + handshakeTimeout = 0; + nextHandshakeResend = -1; } public void Send(ArraySegment data, int channel) @@ -307,161 +290,143 @@ public void Send(ArraySegment data, int channel) Profiler.EndSample(); if (encrypted.Count == 0) - { // error return; - } writer.WriteBytes(encrypted.Array, 0, encrypted.Count); // write nonce after since Encrypt will update it - writer.WriteBytes(_nonce, 0, NonceSize); - _send(writer.ToArraySegment(), channel); + writer.WriteBytes(nonce, 0, NonceSize); + send(writer.ToArraySegment(), channel); } } - private ArraySegment Encrypt(ArraySegment plaintext) + ArraySegment Encrypt(ArraySegment plaintext) { if (plaintext.Count == 0) - { // Invalid return new ArraySegment(); - } // Need to make the nonce unique again before encrypting another message UpdateNonce(); // Re-initialize the cipher with our cached parameters - Cipher.Init(true, _cipherParametersEncrypt); + Cipher.Init(true, cipherParametersEncrypt); // Calculate the expected output size, this should always be input size + mac size int outSize = Cipher.GetOutputSize(plaintext.Count); #if UNITY_EDITOR // expecting the outSize to be input size + MacSize if (outSize != plaintext.Count + MacSizeBytes) - { throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); - } #endif // Resize the static buffer to fit - EnsureSize(ref _tmpCryptBuffer, outSize); + EnsureSize(ref TMPCryptBuffer, outSize); int resultLen; try { // Run the plain text through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0); + Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, TMPCryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) - resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // catch (Exception e) { - _error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); return new ArraySegment(); } #if UNITY_EDITOR // expecting the result length to match the previously calculated input size + MacSize if (resultLen != outSize) - { throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); - } #endif - return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + return new ArraySegment(TMPCryptBuffer, 0, resultLen); } - private ArraySegment Decrypt(ArraySegment ciphertext) + ArraySegment Decrypt(ArraySegment ciphertext) { if (ciphertext.Count <= MacSizeBytes) { - _error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); + error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); // Invalid return new ArraySegment(); } // Re-initialize the cipher with our cached parameters - Cipher.Init(false, _cipherParametersDecrypt); + Cipher.Init(false, cipherParametersDecrypt); // Calculate the expected output size, this should always be input size - mac size int outSize = Cipher.GetOutputSize(ciphertext.Count); #if UNITY_EDITOR // expecting the outSize to be input size - MacSize if (outSize != ciphertext.Count - MacSizeBytes) - { throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); - } #endif // Resize the static buffer to fit - EnsureSize(ref _tmpCryptBuffer, outSize); + EnsureSize(ref TMPCryptBuffer, outSize); int resultLen; try { // Run the ciphertext through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0); + Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, TMPCryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) - resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types catch (Exception e) { - _error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); + error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); return new ArraySegment(); } #if UNITY_EDITOR // expecting the result length to match the previously calculated input size + MacSize if (resultLen != outSize) - { throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); - } #endif - return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + return new ArraySegment(TMPCryptBuffer, 0, resultLen); } - private void UpdateNonce() + void UpdateNonce() { // increment the nonce by one // we need to ensure the nonce is *always* unique and not reused // easiest way to do this is by simply incrementing it for (int i = 0; i < NonceSize; i++) { - _nonce[i]++; - if (_nonce[i] != 0) - { + nonce[i]++; + if (nonce[i] != 0) break; - } } } - private static void EnsureSize(ref byte[] buffer, int size) + static void EnsureSize(ref byte[] buffer, int size) { if (buffer.Length < size) - { // double buffer to avoid constantly resizing by a few bytes Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2)); - } } - private void SendHandshakeAndPubKey(OpCodes opcode) + void SendHandshakeAndPubKey(OpCodes opcode) { using (NetworkWriterPooled writer = NetworkWriterPool.Get()) { writer.WriteByte((byte)opcode); if (opcode == OpCodes.HandshakeAck) - { - writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize); - } - writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length); - _send(writer.ToArraySegment(), Channels.Unreliable); + writer.WriteBytes(hkdfSalt, 0, HkdfSaltSize); + writer.WriteBytes(credentials.PublicKeySerialized, 0, credentials.PublicKeySerialized.Length); + send(writer.ToArraySegment(), Channels.Unreliable); } } - private void SendHandshakeFin() + void SendHandshakeFin() { using (NetworkWriterPooled writer = NetworkWriterPool.Get()) { writer.WriteByte((byte)OpCodes.HandshakeFin); - _send(writer.ToArraySegment(), Channels.Unreliable); + send(writer.ToArraySegment(), Channels.Unreliable); } } - private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) + void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) { AsymmetricKeyParameter remotePubKey; try @@ -470,11 +435,11 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) } catch (Exception e) { - _error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); return; } - if (_validateRemoteKey != null) + if (validateRemoteKey != null) { PubKeyInfo info = new PubKeyInfo { @@ -482,9 +447,9 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) Serialized = remotePubKeyRaw, Key = remotePubKey }; - if (!_validateRemoteKey(info)) + if (!validateRemoteKey(info)) { - _error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); + error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); return; } } @@ -493,7 +458,7 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) // This gives us the same key on the other side, with our public key and their remote // It's like magic, but with math! ECDHBasicAgreement ecdh = new ECDHBasicAgreement(); - ecdh.Init(_credentials.PrivateKey); + ecdh.Init(credentials.PrivateKey); byte[] sharedSecret; try { @@ -502,13 +467,13 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) catch (Exception e) { - _error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); return; } if (salt.Length != HkdfSaltSize) { - _error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}."); + error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}."); return; } @@ -523,12 +488,12 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) KeyParameter key = new KeyParameter(keyRaw); // generate a starting nonce - _nonce = GenerateSecureBytes(NonceSize); + nonce = GenerateSecureBytes(NonceSize); // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // instead of creating a new one each encrypt/decrypt - _cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce); - _cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); + cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce); + cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); } /** @@ -537,53 +502,41 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) public void TickNonReady(double time) { if (IsReady) + return; + + // Timeout reset + if (handshakeTimeout == 0) + handshakeTimeout = time + DurationTimeout; + else if (time > handshakeTimeout) { + error?.Invoke(TransportError.Timeout, $"Timed out during {state}, this probably just means the other side went away which is fine."); return; } // Timeout reset - if (_handshakeTimeout == 0) + if (nextHandshakeResend < 0) { - _handshakeTimeout = time + DurationTimeout; - } - else if (time > _handshakeTimeout) - { - _error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine."); + nextHandshakeResend = time + DurationResend; return; } - // Timeout reset - if (_nextHandshakeResend < 0) - { - _nextHandshakeResend = time + DurationResend; - return; - } - - if (time < _nextHandshakeResend) - { + if (time < nextHandshakeResend) // Resend isn't due yet return; - } - _nextHandshakeResend = time + DurationResend; - switch (_state) + nextHandshakeResend = time + DurationResend; + switch (state) { case State.WaitingHandshake: - if (_sendsFirst) - { + if (sendsFirst) SendHandshakeAndPubKey(OpCodes.HandshakeStart); - } break; case State.WaitingHandshakeReply: - if (_sendsFirst) - { + if (sendsFirst) SendHandshakeFin(); - } else - { SendHandshakeAndPubKey(OpCodes.HandshakeAck); - } break; case State.Ready: // IsReady is checked above & early-returned diff --git a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs index f12e4c919..5834366e5 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs @@ -51,13 +51,11 @@ public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey) return publicKeyInfo.ToAsn1Object().GetDerEncoded(); } - public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment pubKey) - { + public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment pubKey) => // And then we do this to deserialize from the DER (from above) // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // to a byte[] first and then shoved through a MemoryStream - return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); - } + PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) { @@ -66,13 +64,11 @@ public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) return privateKeyInfo.ToAsn1Object().GetDerEncoded(); } - public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment privateKey) - { + public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment privateKey) => // And then we do this to deserialize from the DER (from above) // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // to a byte[] first and then shoved through a MemoryStream - return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); - } + PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); public static string PubKeyFingerprint(ArraySegment publicKeyBytes) { @@ -90,7 +86,7 @@ public void SaveToFile(string path) { PublicKeyFingerprint = PublicKeyFingerprint, PublicKey = Convert.ToBase64String(PublicKeySerialized), - PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)), + PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)) }); File.WriteAllText(path, json); } @@ -104,9 +100,7 @@ public static EncryptionCredentials LoadFromFile(string path) byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey); if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment(publicKeyBytes))) - { throw new Exception("Saved public key fingerprint does not match public key."); - } return new EncryptionCredentials { PublicKeySerialized = publicKeyBytes, @@ -115,7 +109,7 @@ public static EncryptionCredentials LoadFromFile(string path) }; } - private class SerializedPair + class SerializedPair { public string PublicKeyFingerprint; public string PublicKey; diff --git a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs index 1357dce9b..30a0f539b 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs @@ -12,28 +12,27 @@ public class EncryptionTransport : Transport, PortTransport { public override bool IsEncrypted => true; public override string EncryptionCipher => "AES256-GCM"; - public Transport inner; + [FormerlySerializedAs("inner")] + public Transport Inner; public ushort Port { get { - if (inner is PortTransport portTransport) - { + if (Inner is PortTransport portTransport) return portTransport.Port; - } - Debug.LogError($"EncryptionTransport can't get Port because {inner} is not a PortTransport"); + Debug.LogError($"EncryptionTransport can't get Port because {Inner} is not a PortTransport"); return 0; } set { - if (inner is PortTransport portTransport) + if (Inner is PortTransport portTransport) { portTransport.Port = value; return; } - Debug.LogError($"EncryptionTransport can't set Port because {inner} is not a PortTransport"); + Debug.LogError($"EncryptionTransport can't set Port because {Inner} is not a PortTransport"); } } @@ -41,81 +40,78 @@ public enum ValidationMode { Off, List, - Callback, + Callback } - public ValidationMode clientValidateServerPubKey; + [FormerlySerializedAs("clientValidateServerPubKey")] + public ValidationMode ClientValidateServerPubKey; + [FormerlySerializedAs("clientTrustedPubKeySignatures")] [Tooltip("List of public key fingerprints the client will accept")] - public string[] clientTrustedPubKeySignatures; - public Func onClientValidateServerPubKey; - public bool serverLoadKeyPairFromFile; - public string serverKeypairPath = "./server-keys.json"; + public string[] ClientTrustedPubKeySignatures; + public Func OnClientValidateServerPubKey; + [FormerlySerializedAs("serverLoadKeyPairFromFile")] + public bool ServerLoadKeyPairFromFile; + [FormerlySerializedAs("serverKeypairPath")] + public string ServerKeypairPath = "./server-keys.json"; - private EncryptedConnection _client; + EncryptedConnection client; - private Dictionary _serverConnections = new Dictionary(); + readonly Dictionary serverConnections = new Dictionary(); - private List _serverPendingConnections = + readonly List serverPendingConnections = new List(); - private EncryptionCredentials _credentials; - public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint; - public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized; + EncryptionCredentials credentials; + public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint; + public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized; - private void ServerRemoveFromPending(EncryptedConnection con) + void ServerRemoveFromPending(EncryptedConnection con) { - for (int i = 0; i < _serverPendingConnections.Count; i++) - { - if (_serverPendingConnections[i] == con) + for (int i = 0; i < serverPendingConnections.Count; i++) + if (serverPendingConnections[i] == con) { // remove by swapping with last - int lastIndex = _serverPendingConnections.Count - 1; - _serverPendingConnections[i] = _serverPendingConnections[lastIndex]; - _serverPendingConnections.RemoveAt(lastIndex); + int lastIndex = serverPendingConnections.Count - 1; + serverPendingConnections[i] = serverPendingConnections[lastIndex]; + serverPendingConnections.RemoveAt(lastIndex); break; } - } } - private void HandleInnerServerDisconnected(int connId) + void HandleInnerServerDisconnected(int connId) { - if (_serverConnections.TryGetValue(connId, out EncryptedConnection con)) + if (serverConnections.TryGetValue(connId, out EncryptedConnection con)) { ServerRemoveFromPending(con); - _serverConnections.Remove(connId); + serverConnections.Remove(connId); } OnServerDisconnected?.Invoke(connId); } - private void HandleInnerServerError(int connId, TransportError type, string msg) - { - OnServerError?.Invoke(connId, type, $"inner: {msg}"); - } + void HandleInnerServerError(int connId, TransportError type, string msg) => OnServerError?.Invoke(connId, type, $"inner: {msg}"); - private void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) + void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) { - if (_serverConnections.TryGetValue(connId, out EncryptedConnection c)) - { + if (serverConnections.TryGetValue(connId, out EncryptedConnection c)) c.OnReceiveRaw(data, channel); - } } - private void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, inner.ServerGetClientAddress(connId)); + void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId)); - private void HandleInnerServerConnected(int connId, string clientRemoteAddress) + void HandleInnerServerConnected(int connId, string clientRemoteAddress) { - Debug.Log($"[EncryptionTransport] New connection #{connId}"); + Debug.Log($"[EncryptionTransport] New connection #{connId} from {clientRemoteAddress}"); EncryptedConnection ec = null; ec = new EncryptedConnection( - _credentials, + credentials, false, - (segment, channel) => inner.ServerSend(connId, segment, channel), + (segment, channel) => Inner.ServerSend(connId, segment, channel), (segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel), () => { Debug.Log($"[EncryptionTransport] Connection #{connId} is ready"); + // ReSharper disable once AccessToModifiedClosure ServerRemoveFromPending(ec); - //OnServerConnected?.Invoke(connId); OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress); }, (type, msg) => @@ -123,32 +119,25 @@ private void HandleInnerServerConnected(int connId, string clientRemoteAddress) OnServerError?.Invoke(connId, type, msg); ServerDisconnect(connId); }); - _serverConnections.Add(connId, ec); - _serverPendingConnections.Add(ec); + serverConnections.Add(connId, ec); + serverPendingConnections.Add(ec); } - private void HandleInnerClientDisconnected() + void HandleInnerClientDisconnected() { - _client = null; + client = null; OnClientDisconnected?.Invoke(); } - private void HandleInnerClientError(TransportError arg1, string arg2) - { - OnClientError?.Invoke(arg1, $"inner: {arg2}"); - } + void HandleInnerClientError(TransportError arg1, string arg2) => OnClientError?.Invoke(arg1, $"inner: {arg2}"); - private void HandleInnerClientDataReceived(ArraySegment data, int channel) - { - _client?.OnReceiveRaw(data, channel); - } + void HandleInnerClientDataReceived(ArraySegment data, int channel) => client?.OnReceiveRaw(data, channel); - private void HandleInnerClientConnected() - { - _client = new EncryptedConnection( - _credentials, + void HandleInnerClientConnected() => + client = new EncryptedConnection( + credentials, true, - (segment, channel) => inner.ClientSend(segment, channel), + (segment, channel) => Inner.ClientSend(segment, channel), (segment, channel) => OnClientDataReceived?.Invoke(segment, channel), () => { @@ -160,25 +149,23 @@ private void HandleInnerClientConnected() ClientDisconnect(); }, HandleClientValidateServerPubKey); - } - private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) + bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) { - switch (clientValidateServerPubKey) + switch (ClientValidateServerPubKey) { case ValidationMode.Off: return true; case ValidationMode.List: - return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; + return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; case ValidationMode.Callback: - return onClientValidateServerPubKey(pubKeyInfo); + return OnClientValidateServerPubKey(pubKeyInfo); default: throw new ArgumentOutOfRangeException(); } } - void Awake() - { + void Awake() => // check if encryption via hardware acceleration is supported. // this can be useful to know for low end devices. // @@ -188,27 +175,26 @@ void Awake() // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs // which Unity does not support yet. Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}"); - } - public override bool Available() => inner.Available(); + public override bool Available() => Inner.Available(); - public override bool ClientConnected() => _client != null && _client.IsReady; + public override bool ClientConnected() => client != null && client.IsReady; public override void ClientConnect(string address) { - switch (clientValidateServerPubKey) + switch (ClientValidateServerPubKey) { case ValidationMode.Off: break; case ValidationMode.List: - if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0) + if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0) { OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); return; } break; case ValidationMode.Callback: - if (onClientValidateServerPubKey == null) + if (OnClientValidateServerPubKey == null) { OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); return; @@ -217,95 +203,79 @@ public override void ClientConnect(string address) default: throw new ArgumentOutOfRangeException(); } - _credentials = EncryptionCredentials.Generate(); - inner.OnClientConnected = HandleInnerClientConnected; - inner.OnClientDataReceived = HandleInnerClientDataReceived; - inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); - inner.OnClientError = HandleInnerClientError; - inner.OnClientDisconnected = HandleInnerClientDisconnected; - inner.ClientConnect(address); + credentials = EncryptionCredentials.Generate(); + Inner.OnClientConnected = HandleInnerClientConnected; + Inner.OnClientDataReceived = HandleInnerClientDataReceived; + Inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); + Inner.OnClientError = HandleInnerClientError; + Inner.OnClientDisconnected = HandleInnerClientDisconnected; + Inner.ClientConnect(address); } public override void ClientSend(ArraySegment segment, int channelId = Channels.Reliable) => - _client?.Send(segment, channelId); + client?.Send(segment, channelId); - public override void ClientDisconnect() => inner.ClientDisconnect(); + public override void ClientDisconnect() => Inner.ClientDisconnect(); - public override Uri ServerUri() => inner.ServerUri(); + public override Uri ServerUri() => Inner.ServerUri(); - public override bool ServerActive() => inner.ServerActive(); + public override bool ServerActive() => Inner.ServerActive(); public override void ServerStart() { - if (serverLoadKeyPairFromFile) - { - _credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath); - } + if (ServerLoadKeyPairFromFile) + credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath); else - { - _credentials = EncryptionCredentials.Generate(); - } + credentials = EncryptionCredentials.Generate(); #pragma warning disable CS0618 // Type or member is obsolete - inner.OnServerConnected = HandleInnerServerConnected; + Inner.OnServerConnected = HandleInnerServerConnected; #pragma warning restore CS0618 // Type or member is obsolete - inner.OnServerConnectedWithAddress = HandleInnerServerConnected; - inner.OnServerDataReceived = HandleInnerServerDataReceived; - inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); - inner.OnServerError = HandleInnerServerError; - inner.OnServerDisconnected = HandleInnerServerDisconnected; - inner.ServerStart(); + Inner.OnServerConnectedWithAddress = HandleInnerServerConnected; + Inner.OnServerDataReceived = HandleInnerServerDataReceived; + Inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); + Inner.OnServerError = HandleInnerServerError; + Inner.OnServerDisconnected = HandleInnerServerDisconnected; + Inner.ServerStart(); } public override void ServerSend(int connectionId, ArraySegment segment, int channelId = Channels.Reliable) { - if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) - { + if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) connection.Send(segment, channelId); - } } - public override void ServerDisconnect(int connectionId) - { + public override void ServerDisconnect(int connectionId) => // cleanup is done via inners disconnect event - inner.ServerDisconnect(connectionId); - } + Inner.ServerDisconnect(connectionId); - public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId); + public override string ServerGetClientAddress(int connectionId) => Inner.ServerGetClientAddress(connectionId); - public override void ServerStop() => inner.ServerStop(); + public override void ServerStop() => Inner.ServerStop(); public override int GetMaxPacketSize(int channelId = Channels.Reliable) => - inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; + Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; - public override void Shutdown() => inner.Shutdown(); + public override void Shutdown() => Inner.Shutdown(); - public override void ClientEarlyUpdate() - { - inner.ClientEarlyUpdate(); - } + public override void ClientEarlyUpdate() => Inner.ClientEarlyUpdate(); public override void ClientLateUpdate() { - inner.ClientLateUpdate(); + Inner.ClientLateUpdate(); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); - _client?.TickNonReady(NetworkTime.localTime); + client?.TickNonReady(NetworkTime.localTime); Profiler.EndSample(); } - public override void ServerEarlyUpdate() - { - inner.ServerEarlyUpdate(); - } + public override void ServerEarlyUpdate() => Inner.ServerEarlyUpdate(); public override void ServerLateUpdate() { - inner.ServerLateUpdate(); + Inner.ServerLateUpdate(); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); // Reverse iteration as entries can be removed while updating - for (int i = _serverPendingConnections.Count - 1; i >= 0; i--) - { - _serverPendingConnections[i].TickNonReady(NetworkTime.time); - } + for (int i = serverPendingConnections.Count - 1; i >= 0; i--) + serverPendingConnections[i].TickNonReady(NetworkTime.time); Profiler.EndSample(); } } From c5824f39e18733937f20cdfe8f03c31b9606411b Mon Sep 17 00:00:00 2001 From: Robin Rolf Date: Sat, 26 Oct 2024 17:58:24 +0000 Subject: [PATCH 2/2] feat: ThreadedEncryptionTransport --- .../Editor/EncryptionTransportInspector.cs | 3 + .../Encryption/EncryptedConnection.cs | 58 ++-- .../Encryption/ThreadedEncryptionTransport.cs | 314 ++++++++++++++++++ .../ThreadedEncryptionTransport.cs.meta | 11 + 4 files changed, 360 insertions(+), 26 deletions(-) create mode 100644 Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs create mode 100644 Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs index cb50bf2b3..6c39f909f 100644 --- a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs @@ -77,5 +77,8 @@ public override void OnInspectorGUI() serializedObject.ApplyModifiedProperties(); } + + [CustomEditor(typeof(ThreadedEncryptionTransport), true)] + class EncryptionThreadedTransportInspector : EncryptionTransportInspector {} } } diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs index 24a52f3ca..f2a18ca63 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs @@ -1,6 +1,7 @@ using System; using System.Security.Cryptography; using System.Text; +using System.Threading; using Mirror.BouncyCastle.Crypto; using Mirror.BouncyCastle.Crypto.Agreement; using Mirror.BouncyCastle.Crypto.Digests; @@ -48,20 +49,19 @@ public class EncryptedConnection // Set up a global cipher instance, it is initialised/reset before use // (AesFastEngine used to exist, but was removed due to side channel issues) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) - static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); + static readonly ThreadLocal Cipher = new ThreadLocal(() => new GcmBlockCipher(AesUtilities.CreateEngine())); // Set up a global HKDF with a SHA-256 digest - static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); + static readonly ThreadLocal Hkdf = new ThreadLocal(() => new HkdfBytesGenerator(new Sha256Digest())); // Global byte array to store nonce sent by the remote side, they're used immediately after - static readonly byte[] ReceiveNonce = new byte[NonceSize]; + static readonly ThreadLocal ReceiveNonce = new ThreadLocal(() => new byte[NonceSize]); // Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes* - static readonly byte[] TMPRemoteSaltBuffer = new byte[HkdfSaltSize]; + static readonly ThreadLocal TMPRemoteSaltBuffer = new ThreadLocal(() => new byte[HkdfSaltSize]); + // buffer for encrypt/decrypt operations, resized larger as needed - // this is also the buffer that will be returned to mirror via ArraySegment - // so any thread safety concerns would need to take extra care here - static byte[] TMPCryptBuffer = new byte[2048]; + static ThreadLocal TMPCryptBuffer = new ThreadLocal(() => new byte[2048]); // packet headers enum OpCodes : byte @@ -189,7 +189,7 @@ public void OnReceiveRaw(ArraySegment data, int channel) } ArraySegment ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize); - reader.ReadBytes(ReceiveNonce, NonceSize); + reader.ReadBytes(ReceiveNonce.Value, NonceSize); Profiler.BeginSample("EncryptedConnection.Decrypt"); ArraySegment plaintext = Decrypt(ciphertext); @@ -233,8 +233,8 @@ public void OnReceiveRaw(ArraySegment data, int channel) state = State.WaitingHandshakeReply; ResetTimeouts(); - reader.ReadBytes(TMPRemoteSaltBuffer, HkdfSaltSize); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer); + reader.ReadBytes(TMPRemoteSaltBuffer.Value, HkdfSaltSize); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer.Value); SendHandshakeFin(); break; case OpCodes.HandshakeFin: @@ -307,25 +307,28 @@ ArraySegment Encrypt(ArraySegment plaintext) // Need to make the nonce unique again before encrypting another message UpdateNonce(); // Re-initialize the cipher with our cached parameters - Cipher.Init(true, cipherParametersEncrypt); + Cipher.Value.Init(true, cipherParametersEncrypt); // Calculate the expected output size, this should always be input size + mac size - int outSize = Cipher.GetOutputSize(plaintext.Count); + int outSize = Cipher.Value.GetOutputSize(plaintext.Count); #if UNITY_EDITOR // expecting the outSize to be input size + MacSize if (outSize != plaintext.Count + MacSizeBytes) throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); #endif // Resize the static buffer to fit - EnsureSize(ref TMPCryptBuffer, outSize); + byte[] cryptBuffer = TMPCryptBuffer.Value; + EnsureSize(ref cryptBuffer, outSize); + TMPCryptBuffer.Value = cryptBuffer; + int resultLen; try { // Run the plain text through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, TMPCryptBuffer, 0); + Cipher.Value.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, cryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) - resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); + resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // @@ -339,7 +342,7 @@ ArraySegment Encrypt(ArraySegment plaintext) if (resultLen != outSize) throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); #endif - return new ArraySegment(TMPCryptBuffer, 0, resultLen); + return new ArraySegment(cryptBuffer, 0, resultLen); } ArraySegment Decrypt(ArraySegment ciphertext) @@ -351,25 +354,28 @@ ArraySegment Decrypt(ArraySegment ciphertext) return new ArraySegment(); } // Re-initialize the cipher with our cached parameters - Cipher.Init(false, cipherParametersDecrypt); + Cipher.Value.Init(false, cipherParametersDecrypt); // Calculate the expected output size, this should always be input size - mac size - int outSize = Cipher.GetOutputSize(ciphertext.Count); + int outSize = Cipher.Value.GetOutputSize(ciphertext.Count); #if UNITY_EDITOR // expecting the outSize to be input size - MacSize if (outSize != ciphertext.Count - MacSizeBytes) throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); #endif - // Resize the static buffer to fit - EnsureSize(ref TMPCryptBuffer, outSize); + + byte[] cryptBuffer = TMPCryptBuffer.Value; + EnsureSize(ref cryptBuffer, outSize); + TMPCryptBuffer.Value = cryptBuffer; + int resultLen; try { // Run the ciphertext through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, TMPCryptBuffer, 0); + Cipher.Value.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, cryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) - resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); + resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types catch (Exception e) @@ -382,7 +388,7 @@ ArraySegment Decrypt(ArraySegment ciphertext) if (resultLen != outSize) throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); #endif - return new ArraySegment(TMPCryptBuffer, 0, resultLen); + return new ArraySegment(cryptBuffer, 0, resultLen); } void UpdateNonce() @@ -477,13 +483,13 @@ void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) return; } - Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo)); + Hkdf.Value.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo)); // Allocate a buffer for the output key byte[] keyRaw = new byte[KeyLength]; // Generate the output keying material - Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length); + Hkdf.Value.GenerateBytes(keyRaw, 0, keyRaw.Length); KeyParameter key = new KeyParameter(keyRaw); @@ -493,7 +499,7 @@ void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // instead of creating a new one each encrypt/decrypt cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce); - cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); + cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce.Value); } /** diff --git a/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs new file mode 100644 index 000000000..0da64b528 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs @@ -0,0 +1,314 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using Mirror.BouncyCastle.Crypto; +using UnityEngine; +using UnityEngine.Profiling; +using UnityEngine.Serialization; +using Debug = UnityEngine.Debug; + +namespace Mirror.Transports.Encryption +{ + [HelpURL("https://mirror-networking.gitbook.io/docs/manual/transports/encryption-transport")] + public class ThreadedEncryptionTransport : ThreadedTransport, PortTransport + { + public override bool IsEncrypted => true; + public override string EncryptionCipher => "AES256-GCM"; + [FormerlySerializedAs("inner")] + public ThreadedTransport Inner; + + public ushort Port + { + get + { + if (Inner is PortTransport portTransport) + return portTransport.Port; + + Debug.LogError($"ThreadedEncryptionTransport can't get Port because {Inner} is not a PortTransport"); + return 0; + } + set + { + if (Inner is PortTransport portTransport) + { + portTransport.Port = value; + return; + } + Debug.LogError($"ThreadedEncryptionTransport can't set Port because {Inner} is not a PortTransport"); + } + } + + public enum ValidationMode + { + Off, + List, + Callback + } + + [FormerlySerializedAs("clientValidateServerPubKey")] + public ValidationMode ClientValidateServerPubKey; + [FormerlySerializedAs("clientTrustedPubKeySignatures")] + [Tooltip("List of public key fingerprints the client will accept")] + public string[] ClientTrustedPubKeySignatures; + /// + /// Called when a client connects to a server + /// ATTENTION: NOT THREAD SAFE. + /// This will be called on the worker thread. + /// + public Func OnClientValidateServerPubKey; + [FormerlySerializedAs("serverLoadKeyPairFromFile")] + public bool ServerLoadKeyPairFromFile; + [FormerlySerializedAs("serverKeypairPath")] + public string ServerKeypairPath = "./server-keys.json"; + + EncryptedConnection client; + + readonly Dictionary serverConnections = new Dictionary(); + + readonly List serverPendingConnections = + new List(); + + EncryptionCredentials credentials; + public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint; + public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized; + + // Used for threaded time keeping as unitys Time.time is not thread safe + Stopwatch stopwatch = Stopwatch.StartNew(); + + void ServerRemoveFromPending(EncryptedConnection con) + { + for (int i = 0; i < serverPendingConnections.Count; i++) + if (serverPendingConnections[i] == con) + { + // remove by swapping with last + int lastIndex = serverPendingConnections.Count - 1; + serverPendingConnections[i] = serverPendingConnections[lastIndex]; + serverPendingConnections.RemoveAt(lastIndex); + break; + } + } + + void HandleInnerServerDisconnected(int connId) + { + if (serverConnections.TryGetValue(connId, out EncryptedConnection con)) + { + ServerRemoveFromPending(con); + serverConnections.Remove(connId); + } + OnThreadedServerDisconnected(connId); + } + + void HandleInnerServerError(int connId, TransportError type, string msg) => OnThreadedServerError(connId, type, $"inner: {msg}"); + + void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) + { + if (serverConnections.TryGetValue(connId, out EncryptedConnection c)) + c.OnReceiveRaw(data, channel); + } + + void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId)); + + void HandleInnerServerConnected(int connId, string clientRemoteAddress) + { + Debug.Log($"[ThreadedEncryptionTransport] New connection #{connId} from {clientRemoteAddress}"); + EncryptedConnection ec = null; + ec = new EncryptedConnection( + credentials, + false, + (segment, channel) => Inner.ServerSend(connId, segment, channel), + (segment, channel) => OnThreadedServerReceive(connId, segment, channel), + () => + { + Debug.Log($"[ThreadedEncryptionTransport] Connection #{connId} is ready"); + // ReSharper disable once AccessToModifiedClosure + ServerRemoveFromPending(ec); + OnThreadedServerConnected(connId, new IPEndPoint(IPAddress.Parse(clientRemoteAddress), 0)); + }, + (type, msg) => + { + OnThreadedServerError(connId, type, msg); + ServerDisconnect(connId); + }); + serverConnections.Add(connId, ec); + serverPendingConnections.Add(ec); + } + + void HandleInnerClientDisconnected() + { + client = null; + OnThreadedClientDisconnected(); + } + + void HandleInnerClientError(TransportError arg1, string arg2) => OnThreadedClientError(arg1, $"inner: {arg2}"); + + void HandleInnerClientDataReceived(ArraySegment data, int channel) => client?.OnReceiveRaw(data, channel); + + void HandleInnerClientConnected() => + client = new EncryptedConnection( + credentials, + true, + (segment, channel) => Inner.ClientSend(segment, channel), + (segment, channel) => OnThreadedClientReceive(segment, channel), + () => + { + OnThreadedClientConnected(); + }, + (type, msg) => + { + OnThreadedClientError(type, msg); + ClientDisconnect(); + }, + HandleClientValidateServerPubKey); + + bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) + { + switch (ClientValidateServerPubKey) + { + case ValidationMode.Off: + return true; + case ValidationMode.List: + return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; + case ValidationMode.Callback: + return OnClientValidateServerPubKey(pubKeyInfo); + default: + throw new ArgumentOutOfRangeException(); + } + } + + protected override void Awake() + { + base.Awake(); + // check if encryption via hardware acceleration is supported. + // this can be useful to know for low end devices. + // + // hardware acceleration requires netcoreapp3.0 or later: + // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/AesUtilities.cs#L18 + // because AesEngine_x86 requires System.Runtime.Intrinsics.X86: + // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs + // which Unity does not support yet. + Debug.Log($"ThreadedEncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}"); + } + + public override bool Available() => Inner.Available(); + + protected override void ThreadedClientConnect(string address) + { + switch (ClientValidateServerPubKey) + { + case ValidationMode.Off: + break; + case ValidationMode.List: + if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0) + { + OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); + return; + } + break; + case ValidationMode.Callback: + if (OnClientValidateServerPubKey == null) + { + OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); + return; + } + break; + default: + throw new ArgumentOutOfRangeException(); + } + credentials = EncryptionCredentials.Generate(); + Inner.OnClientConnected = HandleInnerClientConnected; + Inner.OnClientDataReceived = HandleInnerClientDataReceived; + Inner.OnClientDataSent = (bytes, channel) => OnThreadedClientSend(bytes, channel); + Inner.OnClientError = HandleInnerClientError; + Inner.OnClientDisconnected = HandleInnerClientDisconnected; + Inner.ClientConnect(address); + } + + protected override void ThreadedClientConnect(Uri address) => Inner.ClientConnect(address); + + protected override void ThreadedClientSend(ArraySegment segment, int channelId) => + client?.Send(segment, channelId); + + protected override void ThreadedClientDisconnect() => Inner.ClientDisconnect(); + + protected override void ThreadedServerStart() + { + if (ServerLoadKeyPairFromFile) + credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath); + else + credentials = EncryptionCredentials.Generate(); +#pragma warning disable CS0618 // Type or member is obsolete + Inner.OnServerConnected = HandleInnerServerConnected; +#pragma warning restore CS0618 // Type or member is obsolete + Inner.OnServerConnectedWithAddress = HandleInnerServerConnected; + Inner.OnServerDataReceived = HandleInnerServerDataReceived; + Inner.OnServerDataSent = (connId, bytes, channel) => OnThreadedServerSend(connId, bytes, channel); + Inner.OnServerError = HandleInnerServerError; + Inner.OnServerDisconnected = HandleInnerServerDisconnected; + Inner.ServerStart(); + } + + protected override void ThreadedServerSend(int connectionId, ArraySegment segment, int channelId) + { + if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) + connection.Send(segment, channelId); + } + + protected override void ThreadedServerDisconnect(int connectionId) => + // cleanup is done via inners disconnect event + Inner.ServerDisconnect(connectionId); + + protected override void ThreadedClientEarlyUpdate() {} + + protected override void ThreadedServerStop() => Inner.ServerStop(); + + public override Uri ServerUri() => Inner.ServerUri(); + + public override int GetMaxPacketSize(int channelId = Channels.Reliable) => + Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; + + protected override void ThreadedShutdown() => Inner.Shutdown(); + + public override void ClientEarlyUpdate() + { + base.ClientEarlyUpdate(); + Inner.ClientEarlyUpdate(); + } + + public override void ClientLateUpdate() + { + base.ClientLateUpdate(); + Inner.ClientLateUpdate(); + } + + protected override void ThreadedClientLateUpdate() + { + Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate"); + client?.TickNonReady(stopwatch.Elapsed.TotalSeconds); + Profiler.EndSample(); + } + + protected override void ThreadedServerEarlyUpdate() {} + + public override void ServerEarlyUpdate() + { + base.ServerEarlyUpdate(); + Inner.ServerEarlyUpdate(); + } + + public override void ServerLateUpdate() + { + base.ServerLateUpdate(); + Inner.ServerLateUpdate(); + } + + protected override void ThreadedServerLateUpdate() + { + Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate"); + // Reverse iteration as entries can be removed while updating + for (int i = serverPendingConnections.Count - 1; i >= 0; i--) + serverPendingConnections[i].TickNonReady(stopwatch.Elapsed.TotalSeconds); + Profiler.EndSample(); + } + } +} diff --git a/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta new file mode 100644 index 000000000..44448f3ff --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/ThreadedEncryptionTransport.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5d3e310924fb49c195391b9699f20809 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3} + userData: + assetBundleName: + assetBundleVariant: