fix: EncryptionTransport uses KDF to ensure fixed size key (#3773)

* reset _time

* fix: EncryptionTransport use KDF to ensure fixed size key

* Expose pub key through transport

* Old Unity compat
This commit is contained in:
Robin Rolf 2024-03-05 17:13:09 +00:00 committed by MrGadget
parent d929a9a9ea
commit 8f0a952d16
3 changed files with 55 additions and 11 deletions

View File

@ -51,6 +51,7 @@ public void Setup()
clientValidateKey = null;
clientRecv.Clear();
serverRecv.Clear();
_time = 0;
serverCreds = EncryptionCredentials.Generate();
server = new EncryptedConnection(serverCreds, false,
@ -119,6 +120,7 @@ private void Pump()
server.TickNonReady(_time);
}
}
[TearDown]
public void TearDown()
{

View File

@ -1,8 +1,10 @@
using System;
using System.Security.Cryptography;
using System.Text;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Agreement;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Crypto.Generators;
using Org.BouncyCastle.Crypto.Modes;
using Org.BouncyCastle.Crypto.Parameters;
using UnityEngine.Profiling;
@ -11,6 +13,14 @@ namespace Mirror.Transports.Encryption
{
public class EncryptedConnection
{
// 256-bit key
private const int KeyLength = 32;
// 512-bit salt for the key derivation function
private const int HkdfSaltSize = KeyLength * 2;
// Info tag for the HKDF, this just adds more entropy
private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
private const int NonceSize = 12;
@ -40,9 +50,14 @@ public class EncryptedConnection
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
// Set up a global HKDF with a SHA-256 digest
private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest());
// Global byte array to store nonce sent by the remote side, they're used immediately after
private static readonly byte[] ReceiveNonce = new byte[NonceSize];
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize];
// buffer for encrypt/decrypt operations, resized larger as needed
// this is also the buffer that will be returned to mirror via ArraySegment
// so any thread safety concerns would need to take extra care here
@ -93,6 +108,7 @@ enum State
private Func<PubKeyInfo, bool> _validateRemoteKey;
// Our asymmetric credentials for the initial DH exchange
private EncryptionCredentials _credentials;
private byte[] _hkdfSalt;
// After no handshake packet in this many seconds, the handshake fails
private double _handshakeTimeout;
@ -126,6 +142,11 @@ public EncryptedConnection(EncryptionCredentials credentials,
{
_credentials = credentials;
_sendsFirst = isClient;
if (!_sendsFirst)
{
// salt is controlled by the server
_hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
}
_send = sendAction;
_receive = receiveAction;
_ready = readyAction;
@ -134,15 +155,15 @@ public EncryptedConnection(EncryptionCredentials credentials,
}
// Generates a random starting nonce
private static byte[] GenerateStartingNonce()
private static byte[] GenerateSecureBytes(int size)
{
byte[] nonce = new byte[NonceSize];
byte[] bytes = new byte[size];
using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
{
rng.GetBytes(nonce);
rng.GetBytes(bytes);
}
return nonce;
return bytes;
}
public void OnReceiveRaw(ArraySegment<byte> data, int channel)
@ -203,7 +224,7 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
_state = State.WaitingHandshakeReply;
ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining));
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt);
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
break;
case OpCodes.HandshakeAck:
@ -228,7 +249,8 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
_state = State.WaitingHandshakeReply;
ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining));
reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer);
SendHandshakeFin();
break;
case OpCodes.HandshakeFin:
@ -421,6 +443,10 @@ private void SendHandshakeAndPubKey(OpCodes opcode)
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)opcode);
if (opcode == OpCodes.HandshakeAck)
{
writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize);
}
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
_send(writer.ToArraySegment(), Channels.Unreliable);
}
@ -435,7 +461,7 @@ private void SendHandshakeFin()
}
}
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
{
AsymmetricKeyParameter remotePubKey;
try
@ -468,10 +494,10 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
// It's like magic, but with math!
ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
ecdh.Init(_credentials.PrivateKey);
byte[] keyRaw;
byte[] sharedSecret;
try
{
keyRaw = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned();
sharedSecret = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned();
}
catch
(Exception e)
@ -480,10 +506,24 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
return;
}
if (salt.Length != HkdfSaltSize)
{
_error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
return;
}
Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo));
// Allocate a buffer for the output key
byte[] keyRaw = new byte[KeyLength];
// Generate the output keying material
Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length);
KeyParameter key = new KeyParameter(keyRaw);
// generate a starting nonce
_nonce = GenerateStartingNonce();
_nonce = GenerateSecureBytes(NonceSize);
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
// instead of creating a new one each encrypt/decrypt

View File

@ -32,6 +32,8 @@ public enum ValidationMode
new List<EncryptedConnection>();
private EncryptionCredentials _credentials;
public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized;
private void ServerRemoveFromPending(EncryptedConnection con)
{