code style

This commit is contained in:
Robin Rolf 2024-10-26 16:29:19 +00:00
parent d347600468
commit 7ebbaa0319
5 changed files with 227 additions and 310 deletions

View File

@ -22,7 +22,7 @@ public void Setup()
GameObject gameObject = new GameObject();
encryption = gameObject.AddComponent<EncryptionTransport>();
encryption.inner = inner;
encryption.Inner = inner;
}
[TearDown]

View File

@ -17,11 +17,11 @@ public class EncryptionTransportInspector : UnityEditor.Editor
void OnEnable()
{
innerProperty = serializedObject.FindProperty("inner");
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey");
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures");
serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath");
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile");
innerProperty = serializedObject.FindProperty("Inner");
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("ClientValidateServerPubKey");
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("ClientTrustedPubKeySignatures");
serverKeypairPathProperty = serializedObject.FindProperty("ServerKeypairPath");
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("ServerLoadKeyPairFromFile");
}
public override void OnInspectorGUI()

View File

@ -14,32 +14,32 @@ namespace Mirror.Transports.Encryption
public class EncryptedConnection
{
// 256-bit key
private const int KeyLength = 32;
const int KeyLength = 32;
// 512-bit salt for the key derivation function
private const int HkdfSaltSize = KeyLength * 2;
const int HkdfSaltSize = KeyLength * 2;
// Info tag for the HKDF, this just adds more entropy
private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
private const int NonceSize = 12;
const int NonceSize = 12;
// this is the size of the "checksum" included in each encrypted payload
// 16 bytes/128 bytes is the recommended value for best security
// can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker.
// Setting it lower than 12 bytes is not recommended
private const int MacSizeBytes = 16;
const int MacSizeBytes = 16;
private const int MacSizeBits = MacSizeBytes * 8;
const int MacSizeBits = MacSizeBytes * 8;
// How much metadata overhead we have for regular packets
public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize;
// After how many seconds of not receiving a handshake packet we should time out
private const double DurationTimeout = 2; // 2s
const double DurationTimeout = 2; // 2s
// After how many seconds to assume the last handshake packet got lost and to resend another one
private const double DurationResend = 0.05; // 50ms
const double DurationResend = 0.05; // 50ms
// Static fields for allocation efficiency, makes this not thread safe
@ -48,20 +48,20 @@ public class EncryptedConnection
// Set up a global cipher instance, it is initialised/reset before use
// (AesFastEngine used to exist, but was removed due to side channel issues)
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
// Set up a global HKDF with a SHA-256 digest
private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest());
static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest());
// Global byte array to store nonce sent by the remote side, they're used immediately after
private static readonly byte[] ReceiveNonce = new byte[NonceSize];
static readonly byte[] ReceiveNonce = new byte[NonceSize];
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize];
static readonly byte[] TMPRemoteSaltBuffer = new byte[HkdfSaltSize];
// buffer for encrypt/decrypt operations, resized larger as needed
// this is also the buffer that will be returned to mirror via ArraySegment
// so any thread safety concerns would need to take extra care here
private static byte[] _tmpCryptBuffer = new byte[2048];
static byte[] TMPCryptBuffer = new byte[2048];
// packet headers
enum OpCodes : byte
@ -70,7 +70,7 @@ enum OpCodes : byte
Data = 1,
HandshakeStart = 2,
HandshakeAck = 3,
HandshakeFin = 4,
HandshakeFin = 4
}
enum State
@ -91,37 +91,37 @@ enum State
Ready
}
private State _state = State.WaitingHandshake;
State state = State.WaitingHandshake;
// Key exchange confirmed and data can be sent freely
public bool IsReady => _state == State.Ready;
public bool IsReady => state == State.Ready;
// Callback to send off encrypted data
private Action<ArraySegment<byte>, int> _send;
readonly Action<ArraySegment<byte>, int> send;
// Callback when received data has been decrypted
private Action<ArraySegment<byte>, int> _receive;
readonly Action<ArraySegment<byte>, int> receive;
// Callback when the connection becomes ready
private Action _ready;
readonly Action ready;
// On-error callback, disconnect expected
private Action<TransportError, string> _error;
readonly Action<TransportError, string> error;
// Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance
// (usually client validates the server key)
private Func<PubKeyInfo, bool> _validateRemoteKey;
readonly Func<PubKeyInfo, bool> validateRemoteKey;
// Our asymmetric credentials for the initial DH exchange
private EncryptionCredentials _credentials;
private byte[] _hkdfSalt;
EncryptionCredentials credentials;
readonly byte[] hkdfSalt;
// After no handshake packet in this many seconds, the handshake fails
private double _handshakeTimeout;
double handshakeTimeout;
// When to assume the last handshake packet got lost and to resend another one
private double _nextHandshakeResend;
double nextHandshakeResend;
// we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in
// so we can update it without creating a new AeadParameters instance
// this might break in the future! (will cause bad data)
private byte[] _nonce = new byte[NonceSize];
private AeadParameters _cipherParametersEncrypt;
private AeadParameters _cipherParametersDecrypt;
byte[] nonce = new byte[NonceSize];
AeadParameters cipherParametersEncrypt;
AeadParameters cipherParametersDecrypt;
/*
@ -130,7 +130,7 @@ enum State
*
* The client does this, since the fin is not acked explicitly, but by receiving data to decrypt
*/
private readonly bool _sendsFirst;
readonly bool sendsFirst;
public EncryptedConnection(EncryptionCredentials credentials,
bool isClient,
@ -140,28 +140,24 @@ public EncryptedConnection(EncryptionCredentials credentials,
Action<TransportError, string> errorAction,
Func<PubKeyInfo, bool> validateRemoteKey = null)
{
_credentials = credentials;
_sendsFirst = isClient;
if (!_sendsFirst)
{
this.credentials = credentials;
sendsFirst = isClient;
if (!sendsFirst)
// salt is controlled by the server
_hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
}
_send = sendAction;
_receive = receiveAction;
_ready = readyAction;
_error = errorAction;
_validateRemoteKey = validateRemoteKey;
hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
send = sendAction;
receive = receiveAction;
ready = readyAction;
error = errorAction;
this.validateRemoteKey = validateRemoteKey;
}
// Generates a random starting nonce
private static byte[] GenerateSecureBytes(int size)
static byte[] GenerateSecureBytes(int size)
{
byte[] bytes = new byte[size];
using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
{
rng.GetBytes(bytes);
}
return bytes;
}
@ -170,7 +166,7 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
{
if (data.Count < 1)
{
_error(TransportError.Unexpected, "Received empty packet");
error(TransportError.Unexpected, "Received empty packet");
return;
}
@ -181,18 +177,14 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
{
case OpCodes.Data:
// first sender ready is implicit when data is received
if (_sendsFirst && _state == State.WaitingHandshakeReply)
{
if (sendsFirst && state == State.WaitingHandshakeReply)
SetReady();
}
else if (!IsReady)
{
_error(TransportError.Unexpected, "Unexpected data while not ready.");
}
error(TransportError.Unexpected, "Unexpected data while not ready.");
if (reader.Remaining < Overhead)
{
_error(TransportError.Unexpected, "received data packet smaller than metadata size");
error(TransportError.Unexpected, "received data packet smaller than metadata size");
return;
}
@ -203,72 +195,62 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
ArraySegment<byte> plaintext = Decrypt(ciphertext);
Profiler.EndSample();
if (plaintext.Count == 0)
{
// error
return;
}
_receive(plaintext, channel);
receive(plaintext, channel);
break;
case OpCodes.HandshakeStart:
if (_sendsFirst)
if (sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
return;
}
if (_state == State.WaitingHandshakeReply)
{
if (state == State.WaitingHandshakeReply)
// this is fine, packets may arrive out of order
return;
}
_state = State.WaitingHandshakeReply;
state = State.WaitingHandshakeReply;
ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), hkdfSalt);
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
break;
case OpCodes.HandshakeAck:
if (!_sendsFirst)
if (!sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
return;
}
if (IsReady)
{
// this is fine, packets may arrive out of order
return;
}
if (_state == State.WaitingHandshakeReply)
{
if (state == State.WaitingHandshakeReply)
// this is fine, packets may arrive out of order
return;
}
_state = State.WaitingHandshakeReply;
state = State.WaitingHandshakeReply;
ResetTimeouts();
reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer);
reader.ReadBytes(TMPRemoteSaltBuffer, HkdfSaltSize);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer);
SendHandshakeFin();
break;
case OpCodes.HandshakeFin:
if (_sendsFirst)
if (sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
return;
}
if (IsReady)
{
// this is fine, packets may arrive out of order
return;
}
if (_state != State.WaitingHandshakeReply)
if (state != State.WaitingHandshakeReply)
{
_error(TransportError.Unexpected,
error(TransportError.Unexpected,
"Received HandshakeFin packet, we didn't expect this yet.");
return;
}
@ -277,24 +259,25 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
break;
default:
_error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
break;
}
}
}
private void SetReady()
void SetReady()
{
// done with credentials, null out the reference
_credentials = null;
credentials = null;
_state = State.Ready;
_ready();
state = State.Ready;
ready();
}
private void ResetTimeouts()
void ResetTimeouts()
{
_handshakeTimeout = 0;
_nextHandshakeResend = -1;
handshakeTimeout = 0;
nextHandshakeResend = -1;
}
public void Send(ArraySegment<byte> data, int channel)
@ -307,161 +290,143 @@ public void Send(ArraySegment<byte> data, int channel)
Profiler.EndSample();
if (encrypted.Count == 0)
{
// error
return;
}
writer.WriteBytes(encrypted.Array, 0, encrypted.Count);
// write nonce after since Encrypt will update it
writer.WriteBytes(_nonce, 0, NonceSize);
_send(writer.ToArraySegment(), channel);
writer.WriteBytes(nonce, 0, NonceSize);
send(writer.ToArraySegment(), channel);
}
}
private ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
{
if (plaintext.Count == 0)
{
// Invalid
return new ArraySegment<byte>();
}
// Need to make the nonce unique again before encrypting another message
UpdateNonce();
// Re-initialize the cipher with our cached parameters
Cipher.Init(true, _cipherParametersEncrypt);
Cipher.Init(true, cipherParametersEncrypt);
// Calculate the expected output size, this should always be input size + mac size
int outSize = Cipher.GetOutputSize(plaintext.Count);
#if UNITY_EDITOR
// expecting the outSize to be input size + MacSize
if (outSize != plaintext.Count + MacSizeBytes)
{
throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}");
}
#endif
// Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize);
EnsureSize(ref TMPCryptBuffer, outSize);
int resultLen;
try
{
// Run the plain text through the cipher, ProcessBytes will only process full blocks
resultLen =
Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0);
Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, TMPCryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen);
}
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
//
catch (Exception e)
{
_error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
return new ArraySegment<byte>();
}
#if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize)
{
throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
return new ArraySegment<byte>(TMPCryptBuffer, 0, resultLen);
}
private ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
{
if (ciphertext.Count <= MacSizeBytes)
{
_error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
// Invalid
return new ArraySegment<byte>();
}
// Re-initialize the cipher with our cached parameters
Cipher.Init(false, _cipherParametersDecrypt);
Cipher.Init(false, cipherParametersDecrypt);
// Calculate the expected output size, this should always be input size - mac size
int outSize = Cipher.GetOutputSize(ciphertext.Count);
#if UNITY_EDITOR
// expecting the outSize to be input size - MacSize
if (outSize != ciphertext.Count - MacSizeBytes)
{
throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}");
}
#endif
// Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize);
EnsureSize(ref TMPCryptBuffer, outSize);
int resultLen;
try
{
// Run the ciphertext through the cipher, ProcessBytes will only process full blocks
resultLen =
Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0);
Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, TMPCryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
resultLen += Cipher.DoFinal(TMPCryptBuffer, resultLen);
}
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
catch (Exception e)
{
_error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
return new ArraySegment<byte>();
}
#if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize)
{
throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
return new ArraySegment<byte>(TMPCryptBuffer, 0, resultLen);
}
private void UpdateNonce()
void UpdateNonce()
{
// increment the nonce by one
// we need to ensure the nonce is *always* unique and not reused
// easiest way to do this is by simply incrementing it
for (int i = 0; i < NonceSize; i++)
{
_nonce[i]++;
if (_nonce[i] != 0)
{
nonce[i]++;
if (nonce[i] != 0)
break;
}
}
}
private static void EnsureSize(ref byte[] buffer, int size)
static void EnsureSize(ref byte[] buffer, int size)
{
if (buffer.Length < size)
{
// double buffer to avoid constantly resizing by a few bytes
Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2));
}
}
private void SendHandshakeAndPubKey(OpCodes opcode)
void SendHandshakeAndPubKey(OpCodes opcode)
{
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)opcode);
if (opcode == OpCodes.HandshakeAck)
{
writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize);
}
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
_send(writer.ToArraySegment(), Channels.Unreliable);
writer.WriteBytes(hkdfSalt, 0, HkdfSaltSize);
writer.WriteBytes(credentials.PublicKeySerialized, 0, credentials.PublicKeySerialized.Length);
send(writer.ToArraySegment(), Channels.Unreliable);
}
}
private void SendHandshakeFin()
void SendHandshakeFin()
{
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)OpCodes.HandshakeFin);
_send(writer.ToArraySegment(), Channels.Unreliable);
send(writer.ToArraySegment(), Channels.Unreliable);
}
}
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
{
AsymmetricKeyParameter remotePubKey;
try
@ -470,11 +435,11 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
}
catch (Exception e)
{
_error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
return;
}
if (_validateRemoteKey != null)
if (validateRemoteKey != null)
{
PubKeyInfo info = new PubKeyInfo
{
@ -482,9 +447,9 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
Serialized = remotePubKeyRaw,
Key = remotePubKey
};
if (!_validateRemoteKey(info))
if (!validateRemoteKey(info))
{
_error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
return;
}
}
@ -493,7 +458,7 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
// This gives us the same key on the other side, with our public key and their remote
// It's like magic, but with math!
ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
ecdh.Init(_credentials.PrivateKey);
ecdh.Init(credentials.PrivateKey);
byte[] sharedSecret;
try
{
@ -502,13 +467,13 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
catch
(Exception e)
{
_error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
return;
}
if (salt.Length != HkdfSaltSize)
{
_error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
return;
}
@ -523,12 +488,12 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
KeyParameter key = new KeyParameter(keyRaw);
// generate a starting nonce
_nonce = GenerateSecureBytes(NonceSize);
nonce = GenerateSecureBytes(NonceSize);
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
// instead of creating a new one each encrypt/decrypt
_cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce);
_cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce);
cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce);
cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce);
}
/**
@ -537,53 +502,41 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
public void TickNonReady(double time)
{
if (IsReady)
return;
// Timeout reset
if (handshakeTimeout == 0)
handshakeTimeout = time + DurationTimeout;
else if (time > handshakeTimeout)
{
error?.Invoke(TransportError.Timeout, $"Timed out during {state}, this probably just means the other side went away which is fine.");
return;
}
// Timeout reset
if (_handshakeTimeout == 0)
if (nextHandshakeResend < 0)
{
_handshakeTimeout = time + DurationTimeout;
}
else if (time > _handshakeTimeout)
{
_error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine.");
nextHandshakeResend = time + DurationResend;
return;
}
// Timeout reset
if (_nextHandshakeResend < 0)
{
_nextHandshakeResend = time + DurationResend;
return;
}
if (time < _nextHandshakeResend)
{
if (time < nextHandshakeResend)
// Resend isn't due yet
return;
}
_nextHandshakeResend = time + DurationResend;
switch (_state)
nextHandshakeResend = time + DurationResend;
switch (state)
{
case State.WaitingHandshake:
if (_sendsFirst)
{
if (sendsFirst)
SendHandshakeAndPubKey(OpCodes.HandshakeStart);
}
break;
case State.WaitingHandshakeReply:
if (_sendsFirst)
{
if (sendsFirst)
SendHandshakeFin();
}
else
{
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
}
break;
case State.Ready: // IsReady is checked above & early-returned

View File

@ -51,13 +51,11 @@ public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey)
return publicKeyInfo.ToAsn1Object().GetDerEncoded();
}
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey)
{
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey) =>
// And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream
return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
}
PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
{
@ -66,13 +64,11 @@ public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
return privateKeyInfo.ToAsn1Object().GetDerEncoded();
}
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey)
{
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey) =>
// And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream
return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
}
PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes)
{
@ -90,7 +86,7 @@ public void SaveToFile(string path)
{
PublicKeyFingerprint = PublicKeyFingerprint,
PublicKey = Convert.ToBase64String(PublicKeySerialized),
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)),
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey))
});
File.WriteAllText(path, json);
}
@ -104,9 +100,7 @@ public static EncryptionCredentials LoadFromFile(string path)
byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey);
if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes)))
{
throw new Exception("Saved public key fingerprint does not match public key.");
}
return new EncryptionCredentials
{
PublicKeySerialized = publicKeyBytes,
@ -115,7 +109,7 @@ public static EncryptionCredentials LoadFromFile(string path)
};
}
private class SerializedPair
class SerializedPair
{
public string PublicKeyFingerprint;
public string PublicKey;

View File

@ -12,28 +12,27 @@ public class EncryptionTransport : Transport, PortTransport
{
public override bool IsEncrypted => true;
public override string EncryptionCipher => "AES256-GCM";
public Transport inner;
[FormerlySerializedAs("inner")]
public Transport Inner;
public ushort Port
{
get
{
if (inner is PortTransport portTransport)
{
if (Inner is PortTransport portTransport)
return portTransport.Port;
}
Debug.LogError($"EncryptionTransport can't get Port because {inner} is not a PortTransport");
Debug.LogError($"EncryptionTransport can't get Port because {Inner} is not a PortTransport");
return 0;
}
set
{
if (inner is PortTransport portTransport)
if (Inner is PortTransport portTransport)
{
portTransport.Port = value;
return;
}
Debug.LogError($"EncryptionTransport can't set Port because {inner} is not a PortTransport");
Debug.LogError($"EncryptionTransport can't set Port because {Inner} is not a PortTransport");
}
}
@ -41,81 +40,78 @@ public enum ValidationMode
{
Off,
List,
Callback,
Callback
}
public ValidationMode clientValidateServerPubKey;
[FormerlySerializedAs("clientValidateServerPubKey")]
public ValidationMode ClientValidateServerPubKey;
[FormerlySerializedAs("clientTrustedPubKeySignatures")]
[Tooltip("List of public key fingerprints the client will accept")]
public string[] clientTrustedPubKeySignatures;
public Func<PubKeyInfo, bool> onClientValidateServerPubKey;
public bool serverLoadKeyPairFromFile;
public string serverKeypairPath = "./server-keys.json";
public string[] ClientTrustedPubKeySignatures;
public Func<PubKeyInfo, bool> OnClientValidateServerPubKey;
[FormerlySerializedAs("serverLoadKeyPairFromFile")]
public bool ServerLoadKeyPairFromFile;
[FormerlySerializedAs("serverKeypairPath")]
public string ServerKeypairPath = "./server-keys.json";
private EncryptedConnection _client;
EncryptedConnection client;
private Dictionary<int, EncryptedConnection> _serverConnections = new Dictionary<int, EncryptedConnection>();
readonly Dictionary<int, EncryptedConnection> serverConnections = new Dictionary<int, EncryptedConnection>();
private List<EncryptedConnection> _serverPendingConnections =
readonly List<EncryptedConnection> serverPendingConnections =
new List<EncryptedConnection>();
private EncryptionCredentials _credentials;
public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized;
EncryptionCredentials credentials;
public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized;
private void ServerRemoveFromPending(EncryptedConnection con)
void ServerRemoveFromPending(EncryptedConnection con)
{
for (int i = 0; i < _serverPendingConnections.Count; i++)
{
if (_serverPendingConnections[i] == con)
for (int i = 0; i < serverPendingConnections.Count; i++)
if (serverPendingConnections[i] == con)
{
// remove by swapping with last
int lastIndex = _serverPendingConnections.Count - 1;
_serverPendingConnections[i] = _serverPendingConnections[lastIndex];
_serverPendingConnections.RemoveAt(lastIndex);
int lastIndex = serverPendingConnections.Count - 1;
serverPendingConnections[i] = serverPendingConnections[lastIndex];
serverPendingConnections.RemoveAt(lastIndex);
break;
}
}
}
private void HandleInnerServerDisconnected(int connId)
void HandleInnerServerDisconnected(int connId)
{
if (_serverConnections.TryGetValue(connId, out EncryptedConnection con))
if (serverConnections.TryGetValue(connId, out EncryptedConnection con))
{
ServerRemoveFromPending(con);
_serverConnections.Remove(connId);
serverConnections.Remove(connId);
}
OnServerDisconnected?.Invoke(connId);
}
private void HandleInnerServerError(int connId, TransportError type, string msg)
{
OnServerError?.Invoke(connId, type, $"inner: {msg}");
}
void HandleInnerServerError(int connId, TransportError type, string msg) => OnServerError?.Invoke(connId, type, $"inner: {msg}");
private void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
{
if (_serverConnections.TryGetValue(connId, out EncryptedConnection c))
{
if (serverConnections.TryGetValue(connId, out EncryptedConnection c))
c.OnReceiveRaw(data, channel);
}
}
private void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, inner.ServerGetClientAddress(connId));
void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId));
private void HandleInnerServerConnected(int connId, string clientRemoteAddress)
void HandleInnerServerConnected(int connId, string clientRemoteAddress)
{
Debug.Log($"[EncryptionTransport] New connection #{connId}");
Debug.Log($"[EncryptionTransport] New connection #{connId} from {clientRemoteAddress}");
EncryptedConnection ec = null;
ec = new EncryptedConnection(
_credentials,
credentials,
false,
(segment, channel) => inner.ServerSend(connId, segment, channel),
(segment, channel) => Inner.ServerSend(connId, segment, channel),
(segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel),
() =>
{
Debug.Log($"[EncryptionTransport] Connection #{connId} is ready");
// ReSharper disable once AccessToModifiedClosure
ServerRemoveFromPending(ec);
//OnServerConnected?.Invoke(connId);
OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress);
},
(type, msg) =>
@ -123,32 +119,25 @@ private void HandleInnerServerConnected(int connId, string clientRemoteAddress)
OnServerError?.Invoke(connId, type, msg);
ServerDisconnect(connId);
});
_serverConnections.Add(connId, ec);
_serverPendingConnections.Add(ec);
serverConnections.Add(connId, ec);
serverPendingConnections.Add(ec);
}
private void HandleInnerClientDisconnected()
void HandleInnerClientDisconnected()
{
_client = null;
client = null;
OnClientDisconnected?.Invoke();
}
private void HandleInnerClientError(TransportError arg1, string arg2)
{
OnClientError?.Invoke(arg1, $"inner: {arg2}");
}
void HandleInnerClientError(TransportError arg1, string arg2) => OnClientError?.Invoke(arg1, $"inner: {arg2}");
private void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel)
{
_client?.OnReceiveRaw(data, channel);
}
void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) => client?.OnReceiveRaw(data, channel);
private void HandleInnerClientConnected()
{
_client = new EncryptedConnection(
_credentials,
void HandleInnerClientConnected() =>
client = new EncryptedConnection(
credentials,
true,
(segment, channel) => inner.ClientSend(segment, channel),
(segment, channel) => Inner.ClientSend(segment, channel),
(segment, channel) => OnClientDataReceived?.Invoke(segment, channel),
() =>
{
@ -160,25 +149,23 @@ private void HandleInnerClientConnected()
ClientDisconnect();
},
HandleClientValidateServerPubKey);
}
private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
{
switch (clientValidateServerPubKey)
switch (ClientValidateServerPubKey)
{
case ValidationMode.Off:
return true;
case ValidationMode.List:
return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
case ValidationMode.Callback:
return onClientValidateServerPubKey(pubKeyInfo);
return OnClientValidateServerPubKey(pubKeyInfo);
default:
throw new ArgumentOutOfRangeException();
}
}
void Awake()
{
void Awake() =>
// check if encryption via hardware acceleration is supported.
// this can be useful to know for low end devices.
//
@ -188,27 +175,26 @@ void Awake()
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
// which Unity does not support yet.
Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
}
public override bool Available() => inner.Available();
public override bool Available() => Inner.Available();
public override bool ClientConnected() => _client != null && _client.IsReady;
public override bool ClientConnected() => client != null && client.IsReady;
public override void ClientConnect(string address)
{
switch (clientValidateServerPubKey)
switch (ClientValidateServerPubKey)
{
case ValidationMode.Off:
break;
case ValidationMode.List:
if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0)
if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0)
{
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
return;
}
break;
case ValidationMode.Callback:
if (onClientValidateServerPubKey == null)
if (OnClientValidateServerPubKey == null)
{
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
return;
@ -217,95 +203,79 @@ public override void ClientConnect(string address)
default:
throw new ArgumentOutOfRangeException();
}
_credentials = EncryptionCredentials.Generate();
inner.OnClientConnected = HandleInnerClientConnected;
inner.OnClientDataReceived = HandleInnerClientDataReceived;
inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
inner.OnClientError = HandleInnerClientError;
inner.OnClientDisconnected = HandleInnerClientDisconnected;
inner.ClientConnect(address);
credentials = EncryptionCredentials.Generate();
Inner.OnClientConnected = HandleInnerClientConnected;
Inner.OnClientDataReceived = HandleInnerClientDataReceived;
Inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
Inner.OnClientError = HandleInnerClientError;
Inner.OnClientDisconnected = HandleInnerClientDisconnected;
Inner.ClientConnect(address);
}
public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) =>
_client?.Send(segment, channelId);
client?.Send(segment, channelId);
public override void ClientDisconnect() => inner.ClientDisconnect();
public override void ClientDisconnect() => Inner.ClientDisconnect();
public override Uri ServerUri() => inner.ServerUri();
public override Uri ServerUri() => Inner.ServerUri();
public override bool ServerActive() => inner.ServerActive();
public override bool ServerActive() => Inner.ServerActive();
public override void ServerStart()
{
if (serverLoadKeyPairFromFile)
{
_credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath);
}
if (ServerLoadKeyPairFromFile)
credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath);
else
{
_credentials = EncryptionCredentials.Generate();
}
credentials = EncryptionCredentials.Generate();
#pragma warning disable CS0618 // Type or member is obsolete
inner.OnServerConnected = HandleInnerServerConnected;
Inner.OnServerConnected = HandleInnerServerConnected;
#pragma warning restore CS0618 // Type or member is obsolete
inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
inner.OnServerDataReceived = HandleInnerServerDataReceived;
inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
inner.OnServerError = HandleInnerServerError;
inner.OnServerDisconnected = HandleInnerServerDisconnected;
inner.ServerStart();
Inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
Inner.OnServerDataReceived = HandleInnerServerDataReceived;
Inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
Inner.OnServerError = HandleInnerServerError;
Inner.OnServerDisconnected = HandleInnerServerDisconnected;
Inner.ServerStart();
}
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable)
{
if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
{
if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
connection.Send(segment, channelId);
}
}
public override void ServerDisconnect(int connectionId)
{
public override void ServerDisconnect(int connectionId) =>
// cleanup is done via inners disconnect event
inner.ServerDisconnect(connectionId);
}
Inner.ServerDisconnect(connectionId);
public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId);
public override string ServerGetClientAddress(int connectionId) => Inner.ServerGetClientAddress(connectionId);
public override void ServerStop() => inner.ServerStop();
public override void ServerStop() => Inner.ServerStop();
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
public override void Shutdown() => inner.Shutdown();
public override void Shutdown() => Inner.Shutdown();
public override void ClientEarlyUpdate()
{
inner.ClientEarlyUpdate();
}
public override void ClientEarlyUpdate() => Inner.ClientEarlyUpdate();
public override void ClientLateUpdate()
{
inner.ClientLateUpdate();
Inner.ClientLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
_client?.TickNonReady(NetworkTime.localTime);
client?.TickNonReady(NetworkTime.localTime);
Profiler.EndSample();
}
public override void ServerEarlyUpdate()
{
inner.ServerEarlyUpdate();
}
public override void ServerEarlyUpdate() => Inner.ServerEarlyUpdate();
public override void ServerLateUpdate()
{
inner.ServerLateUpdate();
Inner.ServerLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
// Reverse iteration as entries can be removed while updating
for (int i = _serverPendingConnections.Count - 1; i >= 0; i--)
{
_serverPendingConnections[i].TickNonReady(NetworkTime.time);
}
for (int i = serverPendingConnections.Count - 1; i >= 0; i--)
serverPendingConnections[i].TickNonReady(NetworkTime.time);
Profiler.EndSample();
}
}