fix: EncryptionTransport uses KDF to ensure fixed size key (#3773)

* reset _time

* fix: EncryptionTransport use KDF to ensure fixed size key

* Expose pub key through transport

* Old Unity compat
This commit is contained in:
Robin Rolf 2024-03-05 17:13:09 +00:00 committed by MrGadget
parent d929a9a9ea
commit 8f0a952d16
3 changed files with 55 additions and 11 deletions

View File

@ -51,6 +51,7 @@ public void Setup()
clientValidateKey = null; clientValidateKey = null;
clientRecv.Clear(); clientRecv.Clear();
serverRecv.Clear(); serverRecv.Clear();
_time = 0;
serverCreds = EncryptionCredentials.Generate(); serverCreds = EncryptionCredentials.Generate();
server = new EncryptedConnection(serverCreds, false, server = new EncryptedConnection(serverCreds, false,
@ -119,6 +120,7 @@ private void Pump()
server.TickNonReady(_time); server.TickNonReady(_time);
} }
} }
[TearDown] [TearDown]
public void TearDown() public void TearDown()
{ {

View File

@ -1,8 +1,10 @@
using System; using System;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text;
using Org.BouncyCastle.Crypto; using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Agreement; using Org.BouncyCastle.Crypto.Agreement;
using Org.BouncyCastle.Crypto.Engines; using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Crypto.Generators;
using Org.BouncyCastle.Crypto.Modes; using Org.BouncyCastle.Crypto.Modes;
using Org.BouncyCastle.Crypto.Parameters; using Org.BouncyCastle.Crypto.Parameters;
using UnityEngine.Profiling; using UnityEngine.Profiling;
@ -11,6 +13,14 @@ namespace Mirror.Transports.Encryption
{ {
public class EncryptedConnection public class EncryptedConnection
{ {
// 256-bit key
private const int KeyLength = 32;
// 512-bit salt for the key derivation function
private const int HkdfSaltSize = KeyLength * 2;
// Info tag for the HKDF, this just adds more entropy
private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed) // fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
private const int NonceSize = 12; private const int NonceSize = 12;
@ -40,9 +50,14 @@ public class EncryptedConnection
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
// Set up a global HKDF with a SHA-256 digest
private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest());
// Global byte array to store nonce sent by the remote side, they're used immediately after // Global byte array to store nonce sent by the remote side, they're used immediately after
private static readonly byte[] ReceiveNonce = new byte[NonceSize]; private static readonly byte[] ReceiveNonce = new byte[NonceSize];
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize];
// buffer for encrypt/decrypt operations, resized larger as needed // buffer for encrypt/decrypt operations, resized larger as needed
// this is also the buffer that will be returned to mirror via ArraySegment // this is also the buffer that will be returned to mirror via ArraySegment
// so any thread safety concerns would need to take extra care here // so any thread safety concerns would need to take extra care here
@ -93,6 +108,7 @@ enum State
private Func<PubKeyInfo, bool> _validateRemoteKey; private Func<PubKeyInfo, bool> _validateRemoteKey;
// Our asymmetric credentials for the initial DH exchange // Our asymmetric credentials for the initial DH exchange
private EncryptionCredentials _credentials; private EncryptionCredentials _credentials;
private byte[] _hkdfSalt;
// After no handshake packet in this many seconds, the handshake fails // After no handshake packet in this many seconds, the handshake fails
private double _handshakeTimeout; private double _handshakeTimeout;
@ -126,6 +142,11 @@ public EncryptedConnection(EncryptionCredentials credentials,
{ {
_credentials = credentials; _credentials = credentials;
_sendsFirst = isClient; _sendsFirst = isClient;
if (!_sendsFirst)
{
// salt is controlled by the server
_hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
}
_send = sendAction; _send = sendAction;
_receive = receiveAction; _receive = receiveAction;
_ready = readyAction; _ready = readyAction;
@ -134,15 +155,15 @@ public EncryptedConnection(EncryptionCredentials credentials,
} }
// Generates a random starting nonce // Generates a random starting nonce
private static byte[] GenerateStartingNonce() private static byte[] GenerateSecureBytes(int size)
{ {
byte[] nonce = new byte[NonceSize]; byte[] bytes = new byte[size];
using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
{ {
rng.GetBytes(nonce); rng.GetBytes(bytes);
} }
return nonce; return bytes;
} }
public void OnReceiveRaw(ArraySegment<byte> data, int channel) public void OnReceiveRaw(ArraySegment<byte> data, int channel)
@ -203,7 +224,7 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
_state = State.WaitingHandshakeReply; _state = State.WaitingHandshakeReply;
ResetTimeouts(); ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining)); CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt);
SendHandshakeAndPubKey(OpCodes.HandshakeAck); SendHandshakeAndPubKey(OpCodes.HandshakeAck);
break; break;
case OpCodes.HandshakeAck: case OpCodes.HandshakeAck:
@ -228,7 +249,8 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
_state = State.WaitingHandshakeReply; _state = State.WaitingHandshakeReply;
ResetTimeouts(); ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining)); reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer);
SendHandshakeFin(); SendHandshakeFin();
break; break;
case OpCodes.HandshakeFin: case OpCodes.HandshakeFin:
@ -421,6 +443,10 @@ private void SendHandshakeAndPubKey(OpCodes opcode)
using (NetworkWriterPooled writer = NetworkWriterPool.Get()) using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{ {
writer.WriteByte((byte)opcode); writer.WriteByte((byte)opcode);
if (opcode == OpCodes.HandshakeAck)
{
writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize);
}
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length); writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
_send(writer.ToArraySegment(), Channels.Unreliable); _send(writer.ToArraySegment(), Channels.Unreliable);
} }
@ -435,7 +461,7 @@ private void SendHandshakeFin()
} }
} }
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw) private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
{ {
AsymmetricKeyParameter remotePubKey; AsymmetricKeyParameter remotePubKey;
try try
@ -468,10 +494,10 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
// It's like magic, but with math! // It's like magic, but with math!
ECDHBasicAgreement ecdh = new ECDHBasicAgreement(); ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
ecdh.Init(_credentials.PrivateKey); ecdh.Init(_credentials.PrivateKey);
byte[] keyRaw; byte[] sharedSecret;
try try
{ {
keyRaw = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned(); sharedSecret = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned();
} }
catch catch
(Exception e) (Exception e)
@ -480,10 +506,24 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
return; return;
} }
if (salt.Length != HkdfSaltSize)
{
_error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
return;
}
Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo));
// Allocate a buffer for the output key
byte[] keyRaw = new byte[KeyLength];
// Generate the output keying material
Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length);
KeyParameter key = new KeyParameter(keyRaw); KeyParameter key = new KeyParameter(keyRaw);
// generate a starting nonce // generate a starting nonce
_nonce = GenerateStartingNonce(); _nonce = GenerateSecureBytes(NonceSize);
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
// instead of creating a new one each encrypt/decrypt // instead of creating a new one each encrypt/decrypt

View File

@ -32,6 +32,8 @@ public enum ValidationMode
new List<EncryptedConnection>(); new List<EncryptedConnection>();
private EncryptionCredentials _credentials; private EncryptionCredentials _credentials;
public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized;
private void ServerRemoveFromPending(EncryptedConnection con) private void ServerRemoveFromPending(EncryptedConnection con)
{ {