diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs index 269ce1b15..1fa25ffc5 100644 --- a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs @@ -22,7 +22,7 @@ public void Setup() GameObject gameObject = new GameObject(); encryption = gameObject.AddComponent(); - encryption.inner = inner; + encryption.Inner = inner; } [TearDown] diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs index 24557f910..cb50bf2b3 100644 --- a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs @@ -17,11 +17,11 @@ public class EncryptionTransportInspector : UnityEditor.Editor void OnEnable() { - innerProperty = serializedObject.FindProperty("inner"); - clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey"); - clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures"); - serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath"); - serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile"); + innerProperty = serializedObject.FindProperty("Inner"); + clientValidatesServerPubKeyProperty = serializedObject.FindProperty("ClientValidateServerPubKey"); + clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("ClientTrustedPubKeySignatures"); + serverKeypairPathProperty = serializedObject.FindProperty("ServerKeypairPath"); + serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("ServerLoadKeyPairFromFile"); } public override void OnInspectorGUI() diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs index 46c916310..24a52f3ca 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs @@ -14,32 +14,32 @@ namespace Mirror.Transports.Encryption public class EncryptedConnection { // 256-bit key - private const int KeyLength = 32; + const int KeyLength = 32; // 512-bit salt for the key derivation function - private const int HkdfSaltSize = KeyLength * 2; + const int HkdfSaltSize = KeyLength * 2; // Info tag for the HKDF, this just adds more entropy - private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport"); + static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport"); // fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed) - private const int NonceSize = 12; + const int NonceSize = 12; // this is the size of the "checksum" included in each encrypted payload // 16 bytes/128 bytes is the recommended value for best security // can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker. // Setting it lower than 12 bytes is not recommended - private const int MacSizeBytes = 16; + const int MacSizeBytes = 16; - private const int MacSizeBits = MacSizeBytes * 8; + const int MacSizeBits = MacSizeBytes * 8; // How much metadata overhead we have for regular packets public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize; // After how many seconds of not receiving a handshake packet we should time out - private const double DurationTimeout = 2; // 2s + const double DurationTimeout = 2; // 2s // After how many seconds to assume the last handshake packet got lost and to resend another one - private const double DurationResend = 0.05; // 50ms + const double DurationResend = 0.05; // 50ms // Static fields for allocation efficiency, makes this not thread safe @@ -48,20 +48,20 @@ public class EncryptedConnection // Set up a global cipher instance, it is initialised/reset before use // (AesFastEngine used to exist, but was removed due to side channel issues) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) - private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); + static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); // Set up a global HKDF with a SHA-256 digest - private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); + static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); // Global byte array to store nonce sent by the remote side, they're used immediately after - private static readonly byte[] ReceiveNonce = new byte[NonceSize]; + static readonly byte[] ReceiveNonce = new byte[NonceSize]; // Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes* - private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize]; + static readonly byte[] TMPRemoteSaltBuffer = new byte[HkdfSaltSize]; // buffer for encrypt/decrypt operations, resized larger as needed // this is also the buffer that will be returned to mirror via ArraySegment // so any thread safety concerns would need to take extra care here - private static byte[] _tmpCryptBuffer = new byte[2048]; + static byte[] TMPCryptBuffer = new byte[2048]; // packet headers enum OpCodes : byte @@ -70,7 +70,7 @@ enum OpCodes : byte Data = 1, HandshakeStart = 2, HandshakeAck = 3, - HandshakeFin = 4, + HandshakeFin = 4 } enum State @@ -91,37 +91,37 @@ enum State Ready } - private State _state = State.WaitingHandshake; + State state = State.WaitingHandshake; // Key exchange confirmed and data can be sent freely - public bool IsReady => _state == State.Ready; + public bool IsReady => state == State.Ready; // Callback to send off encrypted data - private Action, int> _send; + readonly Action, int> send; // Callback when received data has been decrypted - private Action, int> _receive; + readonly Action, int> receive; // Callback when the connection becomes ready - private Action _ready; + readonly Action ready; // On-error callback, disconnect expected - private Action _error; + readonly Action error; // Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance // (usually client validates the server key) - private Func _validateRemoteKey; + readonly Func validateRemoteKey; // Our asymmetric credentials for the initial DH exchange - private EncryptionCredentials _credentials; - private byte[] _hkdfSalt; + EncryptionCredentials credentials; + readonly byte[] hkdfSalt; // After no handshake packet in this many seconds, the handshake fails - private double _handshakeTimeout; + double handshakeTimeout; // When to assume the last handshake packet got lost and to resend another one - private double _nextHandshakeResend; + double nextHandshakeResend; // we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in // so we can update it without creating a new AeadParameters instance // this might break in the future! (will cause bad data) - private byte[] _nonce = new byte[NonceSize]; - private AeadParameters _cipherParametersEncrypt; - private AeadParameters _cipherParametersDecrypt; + byte[] nonce = new byte[NonceSize]; + AeadParameters cipherParametersEncrypt; + AeadParameters cipherParametersDecrypt; /* @@ -130,7 +130,7 @@ enum State * * The client does this, since the fin is not acked explicitly, but by receiving data to decrypt */ - private readonly bool _sendsFirst; + readonly bool sendsFirst; public EncryptedConnection(EncryptionCredentials credentials, bool isClient, @@ -140,28 +140,24 @@ public EncryptedConnection(EncryptionCredentials credentials, Action errorAction, Func validateRemoteKey = null) { - _credentials = credentials; - _sendsFirst = isClient; - if (!_sendsFirst) - { + this.credentials = credentials; + sendsFirst = isClient; + if (!sendsFirst) // salt is controlled by the server - _hkdfSalt = GenerateSecureBytes(HkdfSaltSize); - } - _send = sendAction; - _receive = receiveAction; - _ready = readyAction; - _error = errorAction; - _validateRemoteKey = validateRemoteKey; + hkdfSalt = GenerateSecureBytes(HkdfSaltSize); + send = sendAction; + receive = receiveAction; + ready = readyAction; + error = errorAction; + this.validateRemoteKey = validateRemoteKey; } // Generates a random starting nonce - private static byte[] GenerateSecureBytes(int size) + static byte[] GenerateSecureBytes(int size) { byte[] bytes = new byte[size]; using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) - { rng.GetBytes(bytes); - } return bytes; } @@ -170,7 +166,7 @@ public void OnReceiveRaw(ArraySegment data, int channel) { if (data.Count < 1) { - _error(TransportError.Unexpected, "Received empty packet"); + error(TransportError.Unexpected, "Received empty packet"); return; } @@ -181,18 +177,14 @@ public void OnReceiveRaw(ArraySegment data, int channel) { case OpCodes.Data: // first sender ready is implicit when data is received - if (_sendsFirst && _state == State.WaitingHandshakeReply) - { + if (sendsFirst && state == State.WaitingHandshakeReply) SetReady(); - } else if (!IsReady) - { - _error(TransportError.Unexpected, "Unexpected data while not ready."); - } + error(TransportError.Unexpected, "Unexpected data while not ready."); if (reader.Remaining < Overhead) { - _error(TransportError.Unexpected, "received data packet smaller than metadata size"); + error(TransportError.Unexpected, "received data packet smaller than metadata size"); return; } @@ -203,72 +195,62 @@ public void OnReceiveRaw(ArraySegment data, int channel) ArraySegment plaintext = Decrypt(ciphertext); Profiler.EndSample(); if (plaintext.Count == 0) - { // error return; - } - _receive(plaintext, channel); + receive(plaintext, channel); break; case OpCodes.HandshakeStart: - if (_sendsFirst) + if (sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); return; } - if (_state == State.WaitingHandshakeReply) - { + if (state == State.WaitingHandshakeReply) // this is fine, packets may arrive out of order return; - } - _state = State.WaitingHandshakeReply; + state = State.WaitingHandshakeReply; ResetTimeouts(); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), hkdfSalt); SendHandshakeAndPubKey(OpCodes.HandshakeAck); break; case OpCodes.HandshakeAck: - if (!_sendsFirst) + if (!sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); return; } if (IsReady) - { // this is fine, packets may arrive out of order return; - } - if (_state == State.WaitingHandshakeReply) - { + if (state == State.WaitingHandshakeReply) // this is fine, packets may arrive out of order return; - } - _state = State.WaitingHandshakeReply; + state = State.WaitingHandshakeReply; ResetTimeouts(); - reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize); - CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer); + reader.ReadBytes(TMPRemoteSaltBuffer, HkdfSaltSize); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer); SendHandshakeFin(); break; case OpCodes.HandshakeFin: - if (_sendsFirst) + if (sendsFirst) { - _error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); + error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); return; } if (IsReady) - { // this is fine, packets may arrive out of order return; - } - if (_state != State.WaitingHandshakeReply) + if (state != State.WaitingHandshakeReply) { - _error(TransportError.Unexpected, + error(TransportError.Unexpected, "Received HandshakeFin packet, we didn't expect this yet."); return; } @@ -277,24 +259,25 @@ public void OnReceiveRaw(ArraySegment data, int channel) break; default: - _error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); + error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); break; } } } - private void SetReady() + + void SetReady() { // done with credentials, null out the reference - _credentials = null; + credentials = null; - _state = State.Ready; - _ready(); + state = State.Ready; + ready(); } - private void ResetTimeouts() + void ResetTimeouts() { - _handshakeTimeout = 0; - _nextHandshakeResend = -1; + handshakeTimeout = 0; + nextHandshakeResend = -1; } public void Send(ArraySegment data, int channel) @@ -307,161 +290,143 @@ public void Send(ArraySegment data, int channel) Profiler.EndSample(); if (encrypted.Count == 0) - { // error return; - } writer.WriteBytes(encrypted.Array, 0, encrypted.Count); // write nonce after since Encrypt will update it - writer.WriteBytes(_nonce, 0, NonceSize); - _send(writer.ToArraySegment(), channel); + writer.WriteBytes(nonce, 0, NonceSize); + send(writer.ToArraySegment(), channel); } } - private ArraySegment Encrypt(ArraySegment plaintext) + ArraySegment Encrypt(ArraySegment plaintext) { if (plaintext.Count == 0) - { // Invalid return new ArraySegment(); - } // Need to make the nonce unique again before encrypting another message UpdateNonce(); // Re-initialize the cipher with our cached parameters - Cipher.Init(true, _cipherParametersEncrypt); + Cipher.Init(true, cipherParametersEncrypt); // Calculate the expected output size, this should always be input size + mac size int outSize = Cipher.GetOutputSize(plaintext.Count); #if UNITY_EDITOR // expecting the outSize to be input size + MacSize if (outSize != plaintext.Count + MacSizeBytes) - { throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); - } #endif // Resize the static buffer to fit - EnsureSize(ref _tmpCryptBuffer, outSize); + EnsureSize(ref TMPCryptBuffer, outSize); int resultLen; try { // Run the plain text through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0); + Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, TMPCryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) - resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // catch (Exception e) { - _error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); return new ArraySegment(); } #if UNITY_EDITOR // expecting the result length to match the previously calculated input size + MacSize if (resultLen != outSize) - { throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); - } #endif - return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + return new ArraySegment(TMPCryptBuffer, 0, resultLen); } - private ArraySegment Decrypt(ArraySegment ciphertext) + ArraySegment Decrypt(ArraySegment ciphertext) { if (ciphertext.Count <= MacSizeBytes) { - _error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); + error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); // Invalid return new ArraySegment(); } // Re-initialize the cipher with our cached parameters - Cipher.Init(false, _cipherParametersDecrypt); + Cipher.Init(false, cipherParametersDecrypt); // Calculate the expected output size, this should always be input size - mac size int outSize = Cipher.GetOutputSize(ciphertext.Count); #if UNITY_EDITOR // expecting the outSize to be input size - MacSize if (outSize != ciphertext.Count - MacSizeBytes) - { throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); - } #endif // Resize the static buffer to fit - EnsureSize(ref _tmpCryptBuffer, outSize); + EnsureSize(ref TMPCryptBuffer, outSize); int resultLen; try { // Run the ciphertext through the cipher, ProcessBytes will only process full blocks resultLen = - Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0); + Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, TMPCryptBuffer, 0); // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) - resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen); } // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types catch (Exception e) { - _error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); + error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); return new ArraySegment(); } #if UNITY_EDITOR // expecting the result length to match the previously calculated input size + MacSize if (resultLen != outSize) - { throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); - } #endif - return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + return new ArraySegment(TMPCryptBuffer, 0, resultLen); } - private void UpdateNonce() + void UpdateNonce() { // increment the nonce by one // we need to ensure the nonce is *always* unique and not reused // easiest way to do this is by simply incrementing it for (int i = 0; i < NonceSize; i++) { - _nonce[i]++; - if (_nonce[i] != 0) - { + nonce[i]++; + if (nonce[i] != 0) break; - } } } - private static void EnsureSize(ref byte[] buffer, int size) + static void EnsureSize(ref byte[] buffer, int size) { if (buffer.Length < size) - { // double buffer to avoid constantly resizing by a few bytes Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2)); - } } - private void SendHandshakeAndPubKey(OpCodes opcode) + void SendHandshakeAndPubKey(OpCodes opcode) { using (NetworkWriterPooled writer = NetworkWriterPool.Get()) { writer.WriteByte((byte)opcode); if (opcode == OpCodes.HandshakeAck) - { - writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize); - } - writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length); - _send(writer.ToArraySegment(), Channels.Unreliable); + writer.WriteBytes(hkdfSalt, 0, HkdfSaltSize); + writer.WriteBytes(credentials.PublicKeySerialized, 0, credentials.PublicKeySerialized.Length); + send(writer.ToArraySegment(), Channels.Unreliable); } } - private void SendHandshakeFin() + void SendHandshakeFin() { using (NetworkWriterPooled writer = NetworkWriterPool.Get()) { writer.WriteByte((byte)OpCodes.HandshakeFin); - _send(writer.ToArraySegment(), Channels.Unreliable); + send(writer.ToArraySegment(), Channels.Unreliable); } } - private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) + void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) { AsymmetricKeyParameter remotePubKey; try @@ -470,11 +435,11 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) } catch (Exception e) { - _error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); return; } - if (_validateRemoteKey != null) + if (validateRemoteKey != null) { PubKeyInfo info = new PubKeyInfo { @@ -482,9 +447,9 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) Serialized = remotePubKeyRaw, Key = remotePubKey }; - if (!_validateRemoteKey(info)) + if (!validateRemoteKey(info)) { - _error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); + error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); return; } } @@ -493,7 +458,7 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) // This gives us the same key on the other side, with our public key and their remote // It's like magic, but with math! ECDHBasicAgreement ecdh = new ECDHBasicAgreement(); - ecdh.Init(_credentials.PrivateKey); + ecdh.Init(credentials.PrivateKey); byte[] sharedSecret; try { @@ -502,13 +467,13 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) catch (Exception e) { - _error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); + error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); return; } if (salt.Length != HkdfSaltSize) { - _error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}."); + error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}."); return; } @@ -523,12 +488,12 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) KeyParameter key = new KeyParameter(keyRaw); // generate a starting nonce - _nonce = GenerateSecureBytes(NonceSize); + nonce = GenerateSecureBytes(NonceSize); // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // instead of creating a new one each encrypt/decrypt - _cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce); - _cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); + cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce); + cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); } /** @@ -537,53 +502,41 @@ private void CompleteExchange(ArraySegment remotePubKeyRaw, byte[] salt) public void TickNonReady(double time) { if (IsReady) + return; + + // Timeout reset + if (handshakeTimeout == 0) + handshakeTimeout = time + DurationTimeout; + else if (time > handshakeTimeout) { + error?.Invoke(TransportError.Timeout, $"Timed out during {state}, this probably just means the other side went away which is fine."); return; } // Timeout reset - if (_handshakeTimeout == 0) + if (nextHandshakeResend < 0) { - _handshakeTimeout = time + DurationTimeout; - } - else if (time > _handshakeTimeout) - { - _error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine."); + nextHandshakeResend = time + DurationResend; return; } - // Timeout reset - if (_nextHandshakeResend < 0) - { - _nextHandshakeResend = time + DurationResend; - return; - } - - if (time < _nextHandshakeResend) - { + if (time < nextHandshakeResend) // Resend isn't due yet return; - } - _nextHandshakeResend = time + DurationResend; - switch (_state) + nextHandshakeResend = time + DurationResend; + switch (state) { case State.WaitingHandshake: - if (_sendsFirst) - { + if (sendsFirst) SendHandshakeAndPubKey(OpCodes.HandshakeStart); - } break; case State.WaitingHandshakeReply: - if (_sendsFirst) - { + if (sendsFirst) SendHandshakeFin(); - } else - { SendHandshakeAndPubKey(OpCodes.HandshakeAck); - } break; case State.Ready: // IsReady is checked above & early-returned diff --git a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs index f12e4c919..5834366e5 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs @@ -51,13 +51,11 @@ public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey) return publicKeyInfo.ToAsn1Object().GetDerEncoded(); } - public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment pubKey) - { + public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment pubKey) => // And then we do this to deserialize from the DER (from above) // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // to a byte[] first and then shoved through a MemoryStream - return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); - } + PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) { @@ -66,13 +64,11 @@ public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) return privateKeyInfo.ToAsn1Object().GetDerEncoded(); } - public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment privateKey) - { + public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment privateKey) => // And then we do this to deserialize from the DER (from above) // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // to a byte[] first and then shoved through a MemoryStream - return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); - } + PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); public static string PubKeyFingerprint(ArraySegment publicKeyBytes) { @@ -90,7 +86,7 @@ public void SaveToFile(string path) { PublicKeyFingerprint = PublicKeyFingerprint, PublicKey = Convert.ToBase64String(PublicKeySerialized), - PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)), + PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)) }); File.WriteAllText(path, json); } @@ -104,9 +100,7 @@ public static EncryptionCredentials LoadFromFile(string path) byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey); if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment(publicKeyBytes))) - { throw new Exception("Saved public key fingerprint does not match public key."); - } return new EncryptionCredentials { PublicKeySerialized = publicKeyBytes, @@ -115,7 +109,7 @@ public static EncryptionCredentials LoadFromFile(string path) }; } - private class SerializedPair + class SerializedPair { public string PublicKeyFingerprint; public string PublicKey; diff --git a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs index 1357dce9b..30a0f539b 100644 --- a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs +++ b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs @@ -12,28 +12,27 @@ public class EncryptionTransport : Transport, PortTransport { public override bool IsEncrypted => true; public override string EncryptionCipher => "AES256-GCM"; - public Transport inner; + [FormerlySerializedAs("inner")] + public Transport Inner; public ushort Port { get { - if (inner is PortTransport portTransport) - { + if (Inner is PortTransport portTransport) return portTransport.Port; - } - Debug.LogError($"EncryptionTransport can't get Port because {inner} is not a PortTransport"); + Debug.LogError($"EncryptionTransport can't get Port because {Inner} is not a PortTransport"); return 0; } set { - if (inner is PortTransport portTransport) + if (Inner is PortTransport portTransport) { portTransport.Port = value; return; } - Debug.LogError($"EncryptionTransport can't set Port because {inner} is not a PortTransport"); + Debug.LogError($"EncryptionTransport can't set Port because {Inner} is not a PortTransport"); } } @@ -41,81 +40,78 @@ public enum ValidationMode { Off, List, - Callback, + Callback } - public ValidationMode clientValidateServerPubKey; + [FormerlySerializedAs("clientValidateServerPubKey")] + public ValidationMode ClientValidateServerPubKey; + [FormerlySerializedAs("clientTrustedPubKeySignatures")] [Tooltip("List of public key fingerprints the client will accept")] - public string[] clientTrustedPubKeySignatures; - public Func onClientValidateServerPubKey; - public bool serverLoadKeyPairFromFile; - public string serverKeypairPath = "./server-keys.json"; + public string[] ClientTrustedPubKeySignatures; + public Func OnClientValidateServerPubKey; + [FormerlySerializedAs("serverLoadKeyPairFromFile")] + public bool ServerLoadKeyPairFromFile; + [FormerlySerializedAs("serverKeypairPath")] + public string ServerKeypairPath = "./server-keys.json"; - private EncryptedConnection _client; + EncryptedConnection client; - private Dictionary _serverConnections = new Dictionary(); + readonly Dictionary serverConnections = new Dictionary(); - private List _serverPendingConnections = + readonly List serverPendingConnections = new List(); - private EncryptionCredentials _credentials; - public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint; - public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized; + EncryptionCredentials credentials; + public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint; + public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized; - private void ServerRemoveFromPending(EncryptedConnection con) + void ServerRemoveFromPending(EncryptedConnection con) { - for (int i = 0; i < _serverPendingConnections.Count; i++) - { - if (_serverPendingConnections[i] == con) + for (int i = 0; i < serverPendingConnections.Count; i++) + if (serverPendingConnections[i] == con) { // remove by swapping with last - int lastIndex = _serverPendingConnections.Count - 1; - _serverPendingConnections[i] = _serverPendingConnections[lastIndex]; - _serverPendingConnections.RemoveAt(lastIndex); + int lastIndex = serverPendingConnections.Count - 1; + serverPendingConnections[i] = serverPendingConnections[lastIndex]; + serverPendingConnections.RemoveAt(lastIndex); break; } - } } - private void HandleInnerServerDisconnected(int connId) + void HandleInnerServerDisconnected(int connId) { - if (_serverConnections.TryGetValue(connId, out EncryptedConnection con)) + if (serverConnections.TryGetValue(connId, out EncryptedConnection con)) { ServerRemoveFromPending(con); - _serverConnections.Remove(connId); + serverConnections.Remove(connId); } OnServerDisconnected?.Invoke(connId); } - private void HandleInnerServerError(int connId, TransportError type, string msg) - { - OnServerError?.Invoke(connId, type, $"inner: {msg}"); - } + void HandleInnerServerError(int connId, TransportError type, string msg) => OnServerError?.Invoke(connId, type, $"inner: {msg}"); - private void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) + void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) { - if (_serverConnections.TryGetValue(connId, out EncryptedConnection c)) - { + if (serverConnections.TryGetValue(connId, out EncryptedConnection c)) c.OnReceiveRaw(data, channel); - } } - private void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, inner.ServerGetClientAddress(connId)); + void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId)); - private void HandleInnerServerConnected(int connId, string clientRemoteAddress) + void HandleInnerServerConnected(int connId, string clientRemoteAddress) { - Debug.Log($"[EncryptionTransport] New connection #{connId}"); + Debug.Log($"[EncryptionTransport] New connection #{connId} from {clientRemoteAddress}"); EncryptedConnection ec = null; ec = new EncryptedConnection( - _credentials, + credentials, false, - (segment, channel) => inner.ServerSend(connId, segment, channel), + (segment, channel) => Inner.ServerSend(connId, segment, channel), (segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel), () => { Debug.Log($"[EncryptionTransport] Connection #{connId} is ready"); + // ReSharper disable once AccessToModifiedClosure ServerRemoveFromPending(ec); - //OnServerConnected?.Invoke(connId); OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress); }, (type, msg) => @@ -123,32 +119,25 @@ private void HandleInnerServerConnected(int connId, string clientRemoteAddress) OnServerError?.Invoke(connId, type, msg); ServerDisconnect(connId); }); - _serverConnections.Add(connId, ec); - _serverPendingConnections.Add(ec); + serverConnections.Add(connId, ec); + serverPendingConnections.Add(ec); } - private void HandleInnerClientDisconnected() + void HandleInnerClientDisconnected() { - _client = null; + client = null; OnClientDisconnected?.Invoke(); } - private void HandleInnerClientError(TransportError arg1, string arg2) - { - OnClientError?.Invoke(arg1, $"inner: {arg2}"); - } + void HandleInnerClientError(TransportError arg1, string arg2) => OnClientError?.Invoke(arg1, $"inner: {arg2}"); - private void HandleInnerClientDataReceived(ArraySegment data, int channel) - { - _client?.OnReceiveRaw(data, channel); - } + void HandleInnerClientDataReceived(ArraySegment data, int channel) => client?.OnReceiveRaw(data, channel); - private void HandleInnerClientConnected() - { - _client = new EncryptedConnection( - _credentials, + void HandleInnerClientConnected() => + client = new EncryptedConnection( + credentials, true, - (segment, channel) => inner.ClientSend(segment, channel), + (segment, channel) => Inner.ClientSend(segment, channel), (segment, channel) => OnClientDataReceived?.Invoke(segment, channel), () => { @@ -160,25 +149,23 @@ private void HandleInnerClientConnected() ClientDisconnect(); }, HandleClientValidateServerPubKey); - } - private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) + bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) { - switch (clientValidateServerPubKey) + switch (ClientValidateServerPubKey) { case ValidationMode.Off: return true; case ValidationMode.List: - return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; + return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; case ValidationMode.Callback: - return onClientValidateServerPubKey(pubKeyInfo); + return OnClientValidateServerPubKey(pubKeyInfo); default: throw new ArgumentOutOfRangeException(); } } - void Awake() - { + void Awake() => // check if encryption via hardware acceleration is supported. // this can be useful to know for low end devices. // @@ -188,27 +175,26 @@ void Awake() // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs // which Unity does not support yet. Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}"); - } - public override bool Available() => inner.Available(); + public override bool Available() => Inner.Available(); - public override bool ClientConnected() => _client != null && _client.IsReady; + public override bool ClientConnected() => client != null && client.IsReady; public override void ClientConnect(string address) { - switch (clientValidateServerPubKey) + switch (ClientValidateServerPubKey) { case ValidationMode.Off: break; case ValidationMode.List: - if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0) + if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0) { OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); return; } break; case ValidationMode.Callback: - if (onClientValidateServerPubKey == null) + if (OnClientValidateServerPubKey == null) { OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); return; @@ -217,95 +203,79 @@ public override void ClientConnect(string address) default: throw new ArgumentOutOfRangeException(); } - _credentials = EncryptionCredentials.Generate(); - inner.OnClientConnected = HandleInnerClientConnected; - inner.OnClientDataReceived = HandleInnerClientDataReceived; - inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); - inner.OnClientError = HandleInnerClientError; - inner.OnClientDisconnected = HandleInnerClientDisconnected; - inner.ClientConnect(address); + credentials = EncryptionCredentials.Generate(); + Inner.OnClientConnected = HandleInnerClientConnected; + Inner.OnClientDataReceived = HandleInnerClientDataReceived; + Inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); + Inner.OnClientError = HandleInnerClientError; + Inner.OnClientDisconnected = HandleInnerClientDisconnected; + Inner.ClientConnect(address); } public override void ClientSend(ArraySegment segment, int channelId = Channels.Reliable) => - _client?.Send(segment, channelId); + client?.Send(segment, channelId); - public override void ClientDisconnect() => inner.ClientDisconnect(); + public override void ClientDisconnect() => Inner.ClientDisconnect(); - public override Uri ServerUri() => inner.ServerUri(); + public override Uri ServerUri() => Inner.ServerUri(); - public override bool ServerActive() => inner.ServerActive(); + public override bool ServerActive() => Inner.ServerActive(); public override void ServerStart() { - if (serverLoadKeyPairFromFile) - { - _credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath); - } + if (ServerLoadKeyPairFromFile) + credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath); else - { - _credentials = EncryptionCredentials.Generate(); - } + credentials = EncryptionCredentials.Generate(); #pragma warning disable CS0618 // Type or member is obsolete - inner.OnServerConnected = HandleInnerServerConnected; + Inner.OnServerConnected = HandleInnerServerConnected; #pragma warning restore CS0618 // Type or member is obsolete - inner.OnServerConnectedWithAddress = HandleInnerServerConnected; - inner.OnServerDataReceived = HandleInnerServerDataReceived; - inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); - inner.OnServerError = HandleInnerServerError; - inner.OnServerDisconnected = HandleInnerServerDisconnected; - inner.ServerStart(); + Inner.OnServerConnectedWithAddress = HandleInnerServerConnected; + Inner.OnServerDataReceived = HandleInnerServerDataReceived; + Inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); + Inner.OnServerError = HandleInnerServerError; + Inner.OnServerDisconnected = HandleInnerServerDisconnected; + Inner.ServerStart(); } public override void ServerSend(int connectionId, ArraySegment segment, int channelId = Channels.Reliable) { - if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) - { + if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) connection.Send(segment, channelId); - } } - public override void ServerDisconnect(int connectionId) - { + public override void ServerDisconnect(int connectionId) => // cleanup is done via inners disconnect event - inner.ServerDisconnect(connectionId); - } + Inner.ServerDisconnect(connectionId); - public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId); + public override string ServerGetClientAddress(int connectionId) => Inner.ServerGetClientAddress(connectionId); - public override void ServerStop() => inner.ServerStop(); + public override void ServerStop() => Inner.ServerStop(); public override int GetMaxPacketSize(int channelId = Channels.Reliable) => - inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; + Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; - public override void Shutdown() => inner.Shutdown(); + public override void Shutdown() => Inner.Shutdown(); - public override void ClientEarlyUpdate() - { - inner.ClientEarlyUpdate(); - } + public override void ClientEarlyUpdate() => Inner.ClientEarlyUpdate(); public override void ClientLateUpdate() { - inner.ClientLateUpdate(); + Inner.ClientLateUpdate(); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); - _client?.TickNonReady(NetworkTime.localTime); + client?.TickNonReady(NetworkTime.localTime); Profiler.EndSample(); } - public override void ServerEarlyUpdate() - { - inner.ServerEarlyUpdate(); - } + public override void ServerEarlyUpdate() => Inner.ServerEarlyUpdate(); public override void ServerLateUpdate() { - inner.ServerLateUpdate(); + Inner.ServerLateUpdate(); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); // Reverse iteration as entries can be removed while updating - for (int i = _serverPendingConnections.Count - 1; i >= 0; i--) - { - _serverPendingConnections[i].TickNonReady(NetworkTime.time); - } + for (int i = serverPendingConnections.Count - 1; i >= 0; i--) + serverPendingConnections[i].TickNonReady(NetworkTime.time); Profiler.EndSample(); } }