Compare commits

...

5 Commits

Author SHA1 Message Date
Robin Rolf
f7c38d6958
Merge c5824f39e1 into 1187a59b18 2024-11-11 04:55:40 +02:00
Robin Rolf
1187a59b18
fix: NetworkIdentity component bitmask shifting overflows (#3941)
Some checks failed
Main / Run Unity Tests (push) Has been cancelled
Main / Delete Old Workflow Runs (push) Has been cancelled
Main / Semantic Release (push) Has been cancelled
* Failing test for netid bitmask shifting

* Fix netid bitmask shifting typing

Otherwise it uses ints and overflows when shifting more than 32
2024-11-09 17:45:35 +01:00
mischa
499e4daea3
feat: NetworkTransformHybrid - Hybrid Sync Part 1 (#3937)
Some checks failed
Main / Run Unity Tests (push) Has been cancelled
Main / Delete Old Workflow Runs (push) Has been cancelled
Main / Semantic Release (push) Has been cancelled
* hybrid nt

* fix mrg grid issue; fix unreliable sending just because baseline changed

* comment onserialize baseline

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* Update Assets/Mirror/Components/NetworkTransform/NetworkTransformHybrid2022.cs

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>

* nthybrid: debug draw data points

* debug draw: drops

* nthybrid: OnServerToClient checks for host mode first to avoid noise!

* nthybrid: OnClientToServer check ordering

* fix: don't apply any hybrid rpcs in host mode, fixes overwriting client's data points

* icon

* syntax

* don't pass Vector3? and QUaternion?

* remove unused

* comments

* cleanup

* cleanups

* comments

* cleanup: remove ContructSnapshot

* syncScale

* comment

* remove custom change

---------

Co-authored-by: MrGadget <9826063+MrGadget1024@users.noreply.github.com>
2024-11-07 17:59:34 +01:00
Robin Rolf
c5824f39e1 feat: ThreadedEncryptionTransport 2024-10-26 17:58:24 +00:00
Robin Rolf
7ebbaa0319 code style 2024-10-26 17:58:20 +00:00
12 changed files with 2263 additions and 331 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 8f63ea2e505fd484193fb31c5c55ca73
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -863,7 +863,7 @@ internal void OnStopLocalPlayer()
for (int i = 0; i < components.Length; ++i) for (int i = 0; i < components.Length; ++i)
{ {
NetworkBehaviour component = components[i]; NetworkBehaviour component = components[i];
ulong nthBit = (1u << i); ulong nthBit = 1ul << i;
bool dirty = component.IsDirty(); bool dirty = component.IsDirty();
@ -910,7 +910,7 @@ ulong ClientDirtyMask()
// on client, only consider owned components with SyncDirection to server // on client, only consider owned components with SyncDirection to server
NetworkBehaviour component = components[i]; NetworkBehaviour component = components[i];
ulong nthBit = (1u << i); ulong nthBit = 1ul << i;
if (isOwned && component.syncDirection == SyncDirection.ClientToServer) if (isOwned && component.syncDirection == SyncDirection.ClientToServer)
{ {
@ -928,7 +928,7 @@ ulong ClientDirtyMask()
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool IsDirty(ulong mask, int index) internal static bool IsDirty(ulong mask, int index)
{ {
ulong nthBit = (ulong)(1 << index); ulong nthBit = 1ul << index;
return (mask & nthBit) != 0; return (mask & nthBit) != 0;
} }

View File

@ -1,4 +1,5 @@
// base class for networking tests to make things easier. // base class for networking tests to make things easier.
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using NUnit.Framework; using NUnit.Framework;
@ -81,6 +82,18 @@ protected void CreateNetworked(out GameObject go, out NetworkIdentity identity)
instantiated.Add(go); instantiated.Add(go);
} }
protected void CreateNetworked(out GameObject go, out NetworkIdentity identity, Action<NetworkIdentity> preAwake)
{
go = new GameObject();
identity = go.AddComponent<NetworkIdentity>();
preAwake(identity);
// Awake is only called in play mode.
// call manually for initialization.
identity.Awake();
// track
instantiated.Add(go);
}
// create GameObject + NetworkIdentity + NetworkBehaviour<T> // create GameObject + NetworkIdentity + NetworkBehaviour<T>
// add to tracker list if needed (useful for cleanups afterwards) // add to tracker list if needed (useful for cleanups afterwards)
protected void CreateNetworked<T>(out GameObject go, out NetworkIdentity identity, out T component) protected void CreateNetworked<T>(out GameObject go, out NetworkIdentity identity, out T component)
@ -269,6 +282,44 @@ protected void CreateNetworkedAndSpawn(
Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId)); Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId));
} }
// create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN
// => preAwake callbacks can be used to add network behaviours to the NI
// => ownerConnection can be NetworkServer.localConnection if needed.
// => returns objects from client and from server.
// will be same in host mode.
protected void CreateNetworkedAndSpawn(
out GameObject serverGO, out NetworkIdentity serverIdentity, Action<NetworkIdentity> serverPreAwake,
out GameObject clientGO, out NetworkIdentity clientIdentity, Action<NetworkIdentity> clientPreAwake,
NetworkConnectionToClient ownerConnection = null)
{
// server & client need to be active before spawning
Debug.Assert(NetworkClient.active, "NetworkClient needs to be active before spawning.");
Debug.Assert(NetworkServer.active, "NetworkServer needs to be active before spawning.");
// create one on server, one on client
// (spawning has to find it on client, it doesn't create it)
CreateNetworked(out serverGO, out serverIdentity, serverPreAwake);
CreateNetworked(out clientGO, out clientIdentity, clientPreAwake);
// give both a scene id and register it on client for spawnables
clientIdentity.sceneId = serverIdentity.sceneId = (ulong)serverGO.GetHashCode();
NetworkClient.spawnableObjects[clientIdentity.sceneId] = clientIdentity;
// spawn
NetworkServer.Spawn(serverGO, ownerConnection);
ProcessMessages();
// double check isServer/isClient. avoids debugging headaches.
Assert.That(serverIdentity.isServer, Is.True);
Assert.That(clientIdentity.isClient, Is.True);
// double check that we have authority if we passed an owner connection
if (ownerConnection != null)
Debug.Assert(clientIdentity.isOwned == true, $"Behaviour Had Wrong Authority when spawned, This means that the test is broken and will give the wrong results");
// make sure the client really spawned it.
Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId));
}
// create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN // create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN
// => ownerConnection can be NetworkServer.localConnection if needed. // => ownerConnection can be NetworkServer.localConnection if needed.
protected void CreateNetworkedAndSpawn<T>(out GameObject go, out NetworkIdentity identity, out T component, NetworkConnectionToClient ownerConnection = null) protected void CreateNetworkedAndSpawn<T>(out GameObject go, out NetworkIdentity identity, out T component, NetworkConnectionToClient ownerConnection = null)

View File

@ -1,4 +1,5 @@
// OnDe/SerializeSafely tests. // OnDe/SerializeSafely tests.
using System.Collections.Generic;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using Mirror.Tests.EditorBehaviours.NetworkIdentities; using Mirror.Tests.EditorBehaviours.NetworkIdentities;
using NUnit.Framework; using NUnit.Framework;
@ -42,7 +43,7 @@ public void SerializeAndDeserializeAll()
); );
// set sync modes // set sync modes
serverOwnerComp.syncMode = clientOwnerComp.syncMode = SyncMode.Owner; serverOwnerComp.syncMode = clientOwnerComp.syncMode = SyncMode.Owner;
serverObserversComp.syncMode = clientObserversComp.syncMode = SyncMode.Observers; serverObserversComp.syncMode = clientObserversComp.syncMode = SyncMode.Observers;
// set unique values on server components // set unique values on server components
@ -65,10 +66,127 @@ public void SerializeAndDeserializeAll()
// deserialize client object with OBSERVERS payload // deserialize client object with OBSERVERS payload
reader = new NetworkReader(observersWriter.ToArray()); reader = new NetworkReader(observersWriter.ToArray());
clientIdentity.DeserializeClient(reader, true); clientIdentity.DeserializeClient(reader, true);
Assert.That(clientOwnerComp.value, Is.EqualTo(null)); // owner mode shouldn't be in data Assert.That(clientOwnerComp.value, Is.EqualTo(null)); // owner mode shouldn't be in data
Assert.That(clientObserversComp.value, Is.EqualTo(42)); // observers mode should be in data Assert.That(clientObserversComp.value, Is.EqualTo(42)); // observers mode should be in data
} }
// test serialize -> deserialize of any supported number of components
[Test]
public void SerializeAndDeserializeN([NUnit.Framework.Range(1, 64)] int numberOfNBs)
{
List<SerializeTest1NetworkBehaviour> serverNBs = new List<SerializeTest1NetworkBehaviour>();
List<SerializeTest1NetworkBehaviour> clientNBs = new List<SerializeTest1NetworkBehaviour>();
// need two of both versions so we can serialize -> deserialize
CreateNetworkedAndSpawn(
out _, out NetworkIdentity serverIdentity, ni =>
{
for (int i = 0; i < numberOfNBs; i++)
{
SerializeTest1NetworkBehaviour nb = ni.gameObject.AddComponent<SerializeTest1NetworkBehaviour>();
nb.syncInterval = 0;
nb.syncMode = SyncMode.Observers;
serverNBs.Add(nb);
}
},
out _, out NetworkIdentity clientIdentity, ni =>
{
for (int i = 0; i < numberOfNBs; i++)
{
SerializeTest1NetworkBehaviour nb = ni.gameObject.AddComponent<SerializeTest1NetworkBehaviour>();
nb.syncInterval = 0;
nb.syncMode = SyncMode.Observers;
clientNBs.Add(nb);
}
}
);
// INITIAL SYNC
// set unique values on server components
for (int i = 0; i < serverNBs.Count; i++)
{
serverNBs[i].value = (i + 1) * 3;
serverNBs[i].SetDirty();
}
// serialize server object
serverIdentity.SerializeServer(true, ownerWriter, observersWriter);
// deserialize client object with OBSERVERS payload
NetworkReader reader = new NetworkReader(observersWriter.ToArray());
clientIdentity.DeserializeClient(reader, true);
for (int i = 0; i < clientNBs.Count; i++)
{
int expected = (i + 1) * 3;
Assert.That(clientNBs[i].value, Is.EqualTo(expected), $"Expected the clientNBs[{i}] to have a value of {expected}");
}
// clear dirty bits for incremental sync
foreach (SerializeTest1NetworkBehaviour serverNB in serverNBs)
serverNB.ClearAllDirtyBits();
// INCREMENTAL SYNC ALL
// set unique values on server components
for (int i = 0; i < serverNBs.Count; i++)
{
serverNBs[i].value = (i + 1) * 11;
serverNBs[i].SetDirty();
}
ownerWriter.Reset();
observersWriter.Reset();
// serialize server object
serverIdentity.SerializeServer(false, ownerWriter, observersWriter);
// deserialize client object with OBSERVERS payload
reader = new NetworkReader(observersWriter.ToArray());
clientIdentity.DeserializeClient(reader, false);
for (int i = 0; i < clientNBs.Count; i++)
{
int expected = (i + 1) * 11;
Assert.That(clientNBs[i].value, Is.EqualTo(expected), $"Expected the clientNBs[{i}] to have a value of {expected}");
}
// clear dirty bits for incremental sync
foreach (SerializeTest1NetworkBehaviour serverNB in serverNBs)
serverNB.ClearAllDirtyBits();
// INCREMENTAL SYNC INDIVIDUAL
for (int i = 0; i < numberOfNBs; i++)
{
// reset all client nbs
foreach (SerializeTest1NetworkBehaviour clientNB in clientNBs)
clientNB.value = 0;
int expected = (i + 1) * 7;
// set unique value on server components
serverNBs[i].value = expected;
serverNBs[i].SetDirty();
ownerWriter.Reset();
observersWriter.Reset();
// serialize server object
serverIdentity.SerializeServer(false, ownerWriter, observersWriter);
// deserialize client object with OBSERVERS payload
reader = new NetworkReader(observersWriter.ToArray());
clientIdentity.DeserializeClient(reader, false);
for (int index = 0; index < clientNBs.Count; index++)
{
SerializeTest1NetworkBehaviour clientNB = clientNBs[index];
if (index == i)
{
Assert.That(clientNB.value, Is.EqualTo(expected), $"Expected the clientNBs[{index}] to have a value of {expected}");
}
else
{
Assert.That(clientNB.value, Is.EqualTo(0), $"Expected the clientNBs[{index}] to have a value of 0 since we're not syncing that index (on sync of #{i})");
}
}
}
}
// serialization should work even if a component throws an exception. // serialization should work even if a component throws an exception.
// so if first component throws, second should still be serialized fine. // so if first component throws, second should still be serialized fine.
[Test] [Test]
@ -150,20 +268,20 @@ public void TooManyComponents()
public void ErrorCorrection() public void ErrorCorrection()
{ {
int original = 0x12345678; int original = 0x12345678;
byte safety = 0x78; // last byte byte safety = 0x78; // last byte
// correct size shouldn't be corrected // correct size shouldn't be corrected
Assert.That(NetworkBehaviour.ErrorCorrection(original + 0, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original + 0, safety), Is.EqualTo(original));
// read a little too much // read a little too much
Assert.That(NetworkBehaviour.ErrorCorrection(original + 1, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original + 1, safety), Is.EqualTo(original));
Assert.That(NetworkBehaviour.ErrorCorrection(original + 2, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original + 2, safety), Is.EqualTo(original));
Assert.That(NetworkBehaviour.ErrorCorrection(original + 42, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original + 42, safety), Is.EqualTo(original));
// read a little too less // read a little too less
Assert.That(NetworkBehaviour.ErrorCorrection(original - 1, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original - 1, safety), Is.EqualTo(original));
Assert.That(NetworkBehaviour.ErrorCorrection(original - 2, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original - 2, safety), Is.EqualTo(original));
Assert.That(NetworkBehaviour.ErrorCorrection(original - 42, safety), Is.EqualTo(original)); Assert.That(NetworkBehaviour.ErrorCorrection(original - 42, safety), Is.EqualTo(original));
// reading way too much / less is expected to fail. // reading way too much / less is expected to fail.
// we can only correct the last byte, not more. // we can only correct the last byte, not more.

View File

@ -22,7 +22,7 @@ public void Setup()
GameObject gameObject = new GameObject(); GameObject gameObject = new GameObject();
encryption = gameObject.AddComponent<EncryptionTransport>(); encryption = gameObject.AddComponent<EncryptionTransport>();
encryption.inner = inner; encryption.Inner = inner;
} }
[TearDown] [TearDown]

View File

@ -17,11 +17,11 @@ public class EncryptionTransportInspector : UnityEditor.Editor
void OnEnable() void OnEnable()
{ {
innerProperty = serializedObject.FindProperty("inner"); innerProperty = serializedObject.FindProperty("Inner");
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey"); clientValidatesServerPubKeyProperty = serializedObject.FindProperty("ClientValidateServerPubKey");
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures"); clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("ClientTrustedPubKeySignatures");
serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath"); serverKeypairPathProperty = serializedObject.FindProperty("ServerKeypairPath");
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile"); serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("ServerLoadKeyPairFromFile");
} }
public override void OnInspectorGUI() public override void OnInspectorGUI()
@ -77,5 +77,8 @@ public override void OnInspectorGUI()
serializedObject.ApplyModifiedProperties(); serializedObject.ApplyModifiedProperties();
} }
[CustomEditor(typeof(ThreadedEncryptionTransport), true)]
class EncryptionThreadedTransportInspector : EncryptionTransportInspector {}
} }
} }

View File

@ -1,6 +1,7 @@
using System; using System;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Threading;
using Mirror.BouncyCastle.Crypto; using Mirror.BouncyCastle.Crypto;
using Mirror.BouncyCastle.Crypto.Agreement; using Mirror.BouncyCastle.Crypto.Agreement;
using Mirror.BouncyCastle.Crypto.Digests; using Mirror.BouncyCastle.Crypto.Digests;
@ -14,32 +15,32 @@ namespace Mirror.Transports.Encryption
public class EncryptedConnection public class EncryptedConnection
{ {
// 256-bit key // 256-bit key
private const int KeyLength = 32; const int KeyLength = 32;
// 512-bit salt for the key derivation function // 512-bit salt for the key derivation function
private const int HkdfSaltSize = KeyLength * 2; const int HkdfSaltSize = KeyLength * 2;
// Info tag for the HKDF, this just adds more entropy // Info tag for the HKDF, this just adds more entropy
private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport"); static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed) // fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
private const int NonceSize = 12; const int NonceSize = 12;
// this is the size of the "checksum" included in each encrypted payload // this is the size of the "checksum" included in each encrypted payload
// 16 bytes/128 bytes is the recommended value for best security // 16 bytes/128 bytes is the recommended value for best security
// can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker. // can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker.
// Setting it lower than 12 bytes is not recommended // Setting it lower than 12 bytes is not recommended
private const int MacSizeBytes = 16; const int MacSizeBytes = 16;
private const int MacSizeBits = MacSizeBytes * 8; const int MacSizeBits = MacSizeBytes * 8;
// How much metadata overhead we have for regular packets // How much metadata overhead we have for regular packets
public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize; public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize;
// After how many seconds of not receiving a handshake packet we should time out // After how many seconds of not receiving a handshake packet we should time out
private const double DurationTimeout = 2; // 2s const double DurationTimeout = 2; // 2s
// After how many seconds to assume the last handshake packet got lost and to resend another one // After how many seconds to assume the last handshake packet got lost and to resend another one
private const double DurationResend = 0.05; // 50ms const double DurationResend = 0.05; // 50ms
// Static fields for allocation efficiency, makes this not thread safe // Static fields for allocation efficiency, makes this not thread safe
@ -48,20 +49,19 @@ public class EncryptedConnection
// Set up a global cipher instance, it is initialised/reset before use // Set up a global cipher instance, it is initialised/reset before use
// (AesFastEngine used to exist, but was removed due to side channel issues) // (AesFastEngine used to exist, but was removed due to side channel issues)
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); static readonly ThreadLocal<GcmBlockCipher> Cipher = new ThreadLocal<GcmBlockCipher>(() => new GcmBlockCipher(AesUtilities.CreateEngine()));
// Set up a global HKDF with a SHA-256 digest // Set up a global HKDF with a SHA-256 digest
private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest()); static readonly ThreadLocal<HkdfBytesGenerator> Hkdf = new ThreadLocal<HkdfBytesGenerator>(() => new HkdfBytesGenerator(new Sha256Digest()));
// Global byte array to store nonce sent by the remote side, they're used immediately after // Global byte array to store nonce sent by the remote side, they're used immediately after
private static readonly byte[] ReceiveNonce = new byte[NonceSize]; static readonly ThreadLocal<byte[]> ReceiveNonce = new ThreadLocal<byte[]>(() => new byte[NonceSize]);
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes* // Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize]; static readonly ThreadLocal<byte[]> TMPRemoteSaltBuffer = new ThreadLocal<byte[]>(() => new byte[HkdfSaltSize]);
// buffer for encrypt/decrypt operations, resized larger as needed // buffer for encrypt/decrypt operations, resized larger as needed
// this is also the buffer that will be returned to mirror via ArraySegment static ThreadLocal<byte[]> TMPCryptBuffer = new ThreadLocal<byte[]>(() => new byte[2048]);
// so any thread safety concerns would need to take extra care here
private static byte[] _tmpCryptBuffer = new byte[2048];
// packet headers // packet headers
enum OpCodes : byte enum OpCodes : byte
@ -70,7 +70,7 @@ enum OpCodes : byte
Data = 1, Data = 1,
HandshakeStart = 2, HandshakeStart = 2,
HandshakeAck = 3, HandshakeAck = 3,
HandshakeFin = 4, HandshakeFin = 4
} }
enum State enum State
@ -91,37 +91,37 @@ enum State
Ready Ready
} }
private State _state = State.WaitingHandshake; State state = State.WaitingHandshake;
// Key exchange confirmed and data can be sent freely // Key exchange confirmed and data can be sent freely
public bool IsReady => _state == State.Ready; public bool IsReady => state == State.Ready;
// Callback to send off encrypted data // Callback to send off encrypted data
private Action<ArraySegment<byte>, int> _send; readonly Action<ArraySegment<byte>, int> send;
// Callback when received data has been decrypted // Callback when received data has been decrypted
private Action<ArraySegment<byte>, int> _receive; readonly Action<ArraySegment<byte>, int> receive;
// Callback when the connection becomes ready // Callback when the connection becomes ready
private Action _ready; readonly Action ready;
// On-error callback, disconnect expected // On-error callback, disconnect expected
private Action<TransportError, string> _error; readonly Action<TransportError, string> error;
// Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance // Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance
// (usually client validates the server key) // (usually client validates the server key)
private Func<PubKeyInfo, bool> _validateRemoteKey; readonly Func<PubKeyInfo, bool> validateRemoteKey;
// Our asymmetric credentials for the initial DH exchange // Our asymmetric credentials for the initial DH exchange
private EncryptionCredentials _credentials; EncryptionCredentials credentials;
private byte[] _hkdfSalt; readonly byte[] hkdfSalt;
// After no handshake packet in this many seconds, the handshake fails // After no handshake packet in this many seconds, the handshake fails
private double _handshakeTimeout; double handshakeTimeout;
// When to assume the last handshake packet got lost and to resend another one // When to assume the last handshake packet got lost and to resend another one
private double _nextHandshakeResend; double nextHandshakeResend;
// we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in // we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in
// so we can update it without creating a new AeadParameters instance // so we can update it without creating a new AeadParameters instance
// this might break in the future! (will cause bad data) // this might break in the future! (will cause bad data)
private byte[] _nonce = new byte[NonceSize]; byte[] nonce = new byte[NonceSize];
private AeadParameters _cipherParametersEncrypt; AeadParameters cipherParametersEncrypt;
private AeadParameters _cipherParametersDecrypt; AeadParameters cipherParametersDecrypt;
/* /*
@ -130,7 +130,7 @@ enum State
* *
* The client does this, since the fin is not acked explicitly, but by receiving data to decrypt * The client does this, since the fin is not acked explicitly, but by receiving data to decrypt
*/ */
private readonly bool _sendsFirst; readonly bool sendsFirst;
public EncryptedConnection(EncryptionCredentials credentials, public EncryptedConnection(EncryptionCredentials credentials,
bool isClient, bool isClient,
@ -140,28 +140,24 @@ public EncryptedConnection(EncryptionCredentials credentials,
Action<TransportError, string> errorAction, Action<TransportError, string> errorAction,
Func<PubKeyInfo, bool> validateRemoteKey = null) Func<PubKeyInfo, bool> validateRemoteKey = null)
{ {
_credentials = credentials; this.credentials = credentials;
_sendsFirst = isClient; sendsFirst = isClient;
if (!_sendsFirst) if (!sendsFirst)
{
// salt is controlled by the server // salt is controlled by the server
_hkdfSalt = GenerateSecureBytes(HkdfSaltSize); hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
} send = sendAction;
_send = sendAction; receive = receiveAction;
_receive = receiveAction; ready = readyAction;
_ready = readyAction; error = errorAction;
_error = errorAction; this.validateRemoteKey = validateRemoteKey;
_validateRemoteKey = validateRemoteKey;
} }
// Generates a random starting nonce // Generates a random starting nonce
private static byte[] GenerateSecureBytes(int size) static byte[] GenerateSecureBytes(int size)
{ {
byte[] bytes = new byte[size]; byte[] bytes = new byte[size];
using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
{
rng.GetBytes(bytes); rng.GetBytes(bytes);
}
return bytes; return bytes;
} }
@ -170,7 +166,7 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
{ {
if (data.Count < 1) if (data.Count < 1)
{ {
_error(TransportError.Unexpected, "Received empty packet"); error(TransportError.Unexpected, "Received empty packet");
return; return;
} }
@ -181,94 +177,80 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
{ {
case OpCodes.Data: case OpCodes.Data:
// first sender ready is implicit when data is received // first sender ready is implicit when data is received
if (_sendsFirst && _state == State.WaitingHandshakeReply) if (sendsFirst && state == State.WaitingHandshakeReply)
{
SetReady(); SetReady();
}
else if (!IsReady) else if (!IsReady)
{ error(TransportError.Unexpected, "Unexpected data while not ready.");
_error(TransportError.Unexpected, "Unexpected data while not ready.");
}
if (reader.Remaining < Overhead) if (reader.Remaining < Overhead)
{ {
_error(TransportError.Unexpected, "received data packet smaller than metadata size"); error(TransportError.Unexpected, "received data packet smaller than metadata size");
return; return;
} }
ArraySegment<byte> ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize); ArraySegment<byte> ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize);
reader.ReadBytes(ReceiveNonce, NonceSize); reader.ReadBytes(ReceiveNonce.Value, NonceSize);
Profiler.BeginSample("EncryptedConnection.Decrypt"); Profiler.BeginSample("EncryptedConnection.Decrypt");
ArraySegment<byte> plaintext = Decrypt(ciphertext); ArraySegment<byte> plaintext = Decrypt(ciphertext);
Profiler.EndSample(); Profiler.EndSample();
if (plaintext.Count == 0) if (plaintext.Count == 0)
{
// error // error
return; return;
} receive(plaintext, channel);
_receive(plaintext, channel);
break; break;
case OpCodes.HandshakeStart: case OpCodes.HandshakeStart:
if (_sendsFirst) if (sendsFirst)
{ {
_error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
return; return;
} }
if (_state == State.WaitingHandshakeReply) if (state == State.WaitingHandshakeReply)
{
// this is fine, packets may arrive out of order // this is fine, packets may arrive out of order
return; return;
}
_state = State.WaitingHandshakeReply; state = State.WaitingHandshakeReply;
ResetTimeouts(); ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt); CompleteExchange(reader.ReadBytesSegment(reader.Remaining), hkdfSalt);
SendHandshakeAndPubKey(OpCodes.HandshakeAck); SendHandshakeAndPubKey(OpCodes.HandshakeAck);
break; break;
case OpCodes.HandshakeAck: case OpCodes.HandshakeAck:
if (!_sendsFirst) if (!sendsFirst)
{ {
_error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
return; return;
} }
if (IsReady) if (IsReady)
{
// this is fine, packets may arrive out of order // this is fine, packets may arrive out of order
return; return;
}
if (_state == State.WaitingHandshakeReply) if (state == State.WaitingHandshakeReply)
{
// this is fine, packets may arrive out of order // this is fine, packets may arrive out of order
return; return;
}
_state = State.WaitingHandshakeReply; state = State.WaitingHandshakeReply;
ResetTimeouts(); ResetTimeouts();
reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize); reader.ReadBytes(TMPRemoteSaltBuffer.Value, HkdfSaltSize);
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer); CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer.Value);
SendHandshakeFin(); SendHandshakeFin();
break; break;
case OpCodes.HandshakeFin: case OpCodes.HandshakeFin:
if (_sendsFirst) if (sendsFirst)
{ {
_error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
return; return;
} }
if (IsReady) if (IsReady)
{
// this is fine, packets may arrive out of order // this is fine, packets may arrive out of order
return; return;
}
if (_state != State.WaitingHandshakeReply) if (state != State.WaitingHandshakeReply)
{ {
_error(TransportError.Unexpected, error(TransportError.Unexpected,
"Received HandshakeFin packet, we didn't expect this yet."); "Received HandshakeFin packet, we didn't expect this yet.");
return; return;
} }
@ -277,24 +259,25 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
break; break;
default: default:
_error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
break; break;
} }
} }
} }
private void SetReady()
void SetReady()
{ {
// done with credentials, null out the reference // done with credentials, null out the reference
_credentials = null; credentials = null;
_state = State.Ready; state = State.Ready;
_ready(); ready();
} }
private void ResetTimeouts() void ResetTimeouts()
{ {
_handshakeTimeout = 0; handshakeTimeout = 0;
_nextHandshakeResend = -1; nextHandshakeResend = -1;
} }
public void Send(ArraySegment<byte> data, int channel) public void Send(ArraySegment<byte> data, int channel)
@ -307,161 +290,149 @@ public void Send(ArraySegment<byte> data, int channel)
Profiler.EndSample(); Profiler.EndSample();
if (encrypted.Count == 0) if (encrypted.Count == 0)
{
// error // error
return; return;
}
writer.WriteBytes(encrypted.Array, 0, encrypted.Count); writer.WriteBytes(encrypted.Array, 0, encrypted.Count);
// write nonce after since Encrypt will update it // write nonce after since Encrypt will update it
writer.WriteBytes(_nonce, 0, NonceSize); writer.WriteBytes(nonce, 0, NonceSize);
_send(writer.ToArraySegment(), channel); send(writer.ToArraySegment(), channel);
} }
} }
private ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext) ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
{ {
if (plaintext.Count == 0) if (plaintext.Count == 0)
{
// Invalid // Invalid
return new ArraySegment<byte>(); return new ArraySegment<byte>();
}
// Need to make the nonce unique again before encrypting another message // Need to make the nonce unique again before encrypting another message
UpdateNonce(); UpdateNonce();
// Re-initialize the cipher with our cached parameters // Re-initialize the cipher with our cached parameters
Cipher.Init(true, _cipherParametersEncrypt); Cipher.Value.Init(true, cipherParametersEncrypt);
// Calculate the expected output size, this should always be input size + mac size // Calculate the expected output size, this should always be input size + mac size
int outSize = Cipher.GetOutputSize(plaintext.Count); int outSize = Cipher.Value.GetOutputSize(plaintext.Count);
#if UNITY_EDITOR #if UNITY_EDITOR
// expecting the outSize to be input size + MacSize // expecting the outSize to be input size + MacSize
if (outSize != plaintext.Count + MacSizeBytes) if (outSize != plaintext.Count + MacSizeBytes)
{
throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}");
}
#endif #endif
// Resize the static buffer to fit // Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize); byte[] cryptBuffer = TMPCryptBuffer.Value;
EnsureSize(ref cryptBuffer, outSize);
TMPCryptBuffer.Value = cryptBuffer;
int resultLen; int resultLen;
try try
{ {
// Run the plain text through the cipher, ProcessBytes will only process full blocks // Run the plain text through the cipher, ProcessBytes will only process full blocks
resultLen = resultLen =
Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0); Cipher.Value.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, cryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen);
} }
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
// //
catch (Exception e) catch (Exception e)
{ {
_error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
return new ArraySegment<byte>(); return new ArraySegment<byte>();
} }
#if UNITY_EDITOR #if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize // expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize) if (resultLen != outSize)
{
throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif #endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen); return new ArraySegment<byte>(cryptBuffer, 0, resultLen);
} }
private ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext) ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
{ {
if (ciphertext.Count <= MacSizeBytes) if (ciphertext.Count <= MacSizeBytes)
{ {
_error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
// Invalid // Invalid
return new ArraySegment<byte>(); return new ArraySegment<byte>();
} }
// Re-initialize the cipher with our cached parameters // Re-initialize the cipher with our cached parameters
Cipher.Init(false, _cipherParametersDecrypt); Cipher.Value.Init(false, cipherParametersDecrypt);
// Calculate the expected output size, this should always be input size - mac size // Calculate the expected output size, this should always be input size - mac size
int outSize = Cipher.GetOutputSize(ciphertext.Count); int outSize = Cipher.Value.GetOutputSize(ciphertext.Count);
#if UNITY_EDITOR #if UNITY_EDITOR
// expecting the outSize to be input size - MacSize // expecting the outSize to be input size - MacSize
if (outSize != ciphertext.Count - MacSizeBytes) if (outSize != ciphertext.Count - MacSizeBytes)
{
throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}");
}
#endif #endif
// Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize); byte[] cryptBuffer = TMPCryptBuffer.Value;
EnsureSize(ref cryptBuffer, outSize);
TMPCryptBuffer.Value = cryptBuffer;
int resultLen; int resultLen;
try try
{ {
// Run the ciphertext through the cipher, ProcessBytes will only process full blocks // Run the ciphertext through the cipher, ProcessBytes will only process full blocks
resultLen = resultLen =
Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0); Cipher.Value.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, cryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen);
} }
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
catch (Exception e) catch (Exception e)
{ {
_error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
return new ArraySegment<byte>(); return new ArraySegment<byte>();
} }
#if UNITY_EDITOR #if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize // expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize) if (resultLen != outSize)
{
throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif #endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen); return new ArraySegment<byte>(cryptBuffer, 0, resultLen);
} }
private void UpdateNonce() void UpdateNonce()
{ {
// increment the nonce by one // increment the nonce by one
// we need to ensure the nonce is *always* unique and not reused // we need to ensure the nonce is *always* unique and not reused
// easiest way to do this is by simply incrementing it // easiest way to do this is by simply incrementing it
for (int i = 0; i < NonceSize; i++) for (int i = 0; i < NonceSize; i++)
{ {
_nonce[i]++; nonce[i]++;
if (_nonce[i] != 0) if (nonce[i] != 0)
{
break; break;
}
} }
} }
private static void EnsureSize(ref byte[] buffer, int size) static void EnsureSize(ref byte[] buffer, int size)
{ {
if (buffer.Length < size) if (buffer.Length < size)
{
// double buffer to avoid constantly resizing by a few bytes // double buffer to avoid constantly resizing by a few bytes
Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2)); Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2));
}
} }
private void SendHandshakeAndPubKey(OpCodes opcode) void SendHandshakeAndPubKey(OpCodes opcode)
{ {
using (NetworkWriterPooled writer = NetworkWriterPool.Get()) using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{ {
writer.WriteByte((byte)opcode); writer.WriteByte((byte)opcode);
if (opcode == OpCodes.HandshakeAck) if (opcode == OpCodes.HandshakeAck)
{ writer.WriteBytes(hkdfSalt, 0, HkdfSaltSize);
writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize); writer.WriteBytes(credentials.PublicKeySerialized, 0, credentials.PublicKeySerialized.Length);
} send(writer.ToArraySegment(), Channels.Unreliable);
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
_send(writer.ToArraySegment(), Channels.Unreliable);
} }
} }
private void SendHandshakeFin() void SendHandshakeFin()
{ {
using (NetworkWriterPooled writer = NetworkWriterPool.Get()) using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{ {
writer.WriteByte((byte)OpCodes.HandshakeFin); writer.WriteByte((byte)OpCodes.HandshakeFin);
_send(writer.ToArraySegment(), Channels.Unreliable); send(writer.ToArraySegment(), Channels.Unreliable);
} }
} }
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt) void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
{ {
AsymmetricKeyParameter remotePubKey; AsymmetricKeyParameter remotePubKey;
try try
@ -470,11 +441,11 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
} }
catch (Exception e) catch (Exception e)
{ {
_error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
return; return;
} }
if (_validateRemoteKey != null) if (validateRemoteKey != null)
{ {
PubKeyInfo info = new PubKeyInfo PubKeyInfo info = new PubKeyInfo
{ {
@ -482,9 +453,9 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
Serialized = remotePubKeyRaw, Serialized = remotePubKeyRaw,
Key = remotePubKey Key = remotePubKey
}; };
if (!_validateRemoteKey(info)) if (!validateRemoteKey(info))
{ {
_error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
return; return;
} }
} }
@ -493,7 +464,7 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
// This gives us the same key on the other side, with our public key and their remote // This gives us the same key on the other side, with our public key and their remote
// It's like magic, but with math! // It's like magic, but with math!
ECDHBasicAgreement ecdh = new ECDHBasicAgreement(); ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
ecdh.Init(_credentials.PrivateKey); ecdh.Init(credentials.PrivateKey);
byte[] sharedSecret; byte[] sharedSecret;
try try
{ {
@ -502,33 +473,33 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
catch catch
(Exception e) (Exception e)
{ {
_error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
return; return;
} }
if (salt.Length != HkdfSaltSize) if (salt.Length != HkdfSaltSize)
{ {
_error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}."); error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
return; return;
} }
Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo)); Hkdf.Value.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo));
// Allocate a buffer for the output key // Allocate a buffer for the output key
byte[] keyRaw = new byte[KeyLength]; byte[] keyRaw = new byte[KeyLength];
// Generate the output keying material // Generate the output keying material
Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length); Hkdf.Value.GenerateBytes(keyRaw, 0, keyRaw.Length);
KeyParameter key = new KeyParameter(keyRaw); KeyParameter key = new KeyParameter(keyRaw);
// generate a starting nonce // generate a starting nonce
_nonce = GenerateSecureBytes(NonceSize); nonce = GenerateSecureBytes(NonceSize);
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
// instead of creating a new one each encrypt/decrypt // instead of creating a new one each encrypt/decrypt
_cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce); cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce);
_cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce.Value);
} }
/** /**
@ -537,53 +508,41 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
public void TickNonReady(double time) public void TickNonReady(double time)
{ {
if (IsReady) if (IsReady)
return;
// Timeout reset
if (handshakeTimeout == 0)
handshakeTimeout = time + DurationTimeout;
else if (time > handshakeTimeout)
{ {
error?.Invoke(TransportError.Timeout, $"Timed out during {state}, this probably just means the other side went away which is fine.");
return; return;
} }
// Timeout reset // Timeout reset
if (_handshakeTimeout == 0) if (nextHandshakeResend < 0)
{ {
_handshakeTimeout = time + DurationTimeout; nextHandshakeResend = time + DurationResend;
}
else if (time > _handshakeTimeout)
{
_error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine.");
return; return;
} }
// Timeout reset if (time < nextHandshakeResend)
if (_nextHandshakeResend < 0)
{
_nextHandshakeResend = time + DurationResend;
return;
}
if (time < _nextHandshakeResend)
{
// Resend isn't due yet // Resend isn't due yet
return; return;
}
_nextHandshakeResend = time + DurationResend; nextHandshakeResend = time + DurationResend;
switch (_state) switch (state)
{ {
case State.WaitingHandshake: case State.WaitingHandshake:
if (_sendsFirst) if (sendsFirst)
{
SendHandshakeAndPubKey(OpCodes.HandshakeStart); SendHandshakeAndPubKey(OpCodes.HandshakeStart);
}
break; break;
case State.WaitingHandshakeReply: case State.WaitingHandshakeReply:
if (_sendsFirst) if (sendsFirst)
{
SendHandshakeFin(); SendHandshakeFin();
}
else else
{
SendHandshakeAndPubKey(OpCodes.HandshakeAck); SendHandshakeAndPubKey(OpCodes.HandshakeAck);
}
break; break;
case State.Ready: // IsReady is checked above & early-returned case State.Ready: // IsReady is checked above & early-returned

View File

@ -51,13 +51,11 @@ public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey)
return publicKeyInfo.ToAsn1Object().GetDerEncoded(); return publicKeyInfo.ToAsn1Object().GetDerEncoded();
} }
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey) public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey) =>
{
// And then we do this to deserialize from the DER (from above) // And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream // to a byte[] first and then shoved through a MemoryStream
return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
}
public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
{ {
@ -66,13 +64,11 @@ public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
return privateKeyInfo.ToAsn1Object().GetDerEncoded(); return privateKeyInfo.ToAsn1Object().GetDerEncoded();
} }
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey) public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey) =>
{
// And then we do this to deserialize from the DER (from above) // And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream // to a byte[] first and then shoved through a MemoryStream
return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
}
public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes) public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes)
{ {
@ -90,7 +86,7 @@ public void SaveToFile(string path)
{ {
PublicKeyFingerprint = PublicKeyFingerprint, PublicKeyFingerprint = PublicKeyFingerprint,
PublicKey = Convert.ToBase64String(PublicKeySerialized), PublicKey = Convert.ToBase64String(PublicKeySerialized),
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)), PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey))
}); });
File.WriteAllText(path, json); File.WriteAllText(path, json);
} }
@ -104,9 +100,7 @@ public static EncryptionCredentials LoadFromFile(string path)
byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey); byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey);
if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes))) if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes)))
{
throw new Exception("Saved public key fingerprint does not match public key."); throw new Exception("Saved public key fingerprint does not match public key.");
}
return new EncryptionCredentials return new EncryptionCredentials
{ {
PublicKeySerialized = publicKeyBytes, PublicKeySerialized = publicKeyBytes,
@ -115,7 +109,7 @@ public static EncryptionCredentials LoadFromFile(string path)
}; };
} }
private class SerializedPair class SerializedPair
{ {
public string PublicKeyFingerprint; public string PublicKeyFingerprint;
public string PublicKey; public string PublicKey;

View File

@ -12,28 +12,27 @@ public class EncryptionTransport : Transport, PortTransport
{ {
public override bool IsEncrypted => true; public override bool IsEncrypted => true;
public override string EncryptionCipher => "AES256-GCM"; public override string EncryptionCipher => "AES256-GCM";
public Transport inner; [FormerlySerializedAs("inner")]
public Transport Inner;
public ushort Port public ushort Port
{ {
get get
{ {
if (inner is PortTransport portTransport) if (Inner is PortTransport portTransport)
{
return portTransport.Port; return portTransport.Port;
}
Debug.LogError($"EncryptionTransport can't get Port because {inner} is not a PortTransport"); Debug.LogError($"EncryptionTransport can't get Port because {Inner} is not a PortTransport");
return 0; return 0;
} }
set set
{ {
if (inner is PortTransport portTransport) if (Inner is PortTransport portTransport)
{ {
portTransport.Port = value; portTransport.Port = value;
return; return;
} }
Debug.LogError($"EncryptionTransport can't set Port because {inner} is not a PortTransport"); Debug.LogError($"EncryptionTransport can't set Port because {Inner} is not a PortTransport");
} }
} }
@ -41,81 +40,78 @@ public enum ValidationMode
{ {
Off, Off,
List, List,
Callback, Callback
} }
public ValidationMode clientValidateServerPubKey; [FormerlySerializedAs("clientValidateServerPubKey")]
public ValidationMode ClientValidateServerPubKey;
[FormerlySerializedAs("clientTrustedPubKeySignatures")]
[Tooltip("List of public key fingerprints the client will accept")] [Tooltip("List of public key fingerprints the client will accept")]
public string[] clientTrustedPubKeySignatures; public string[] ClientTrustedPubKeySignatures;
public Func<PubKeyInfo, bool> onClientValidateServerPubKey; public Func<PubKeyInfo, bool> OnClientValidateServerPubKey;
public bool serverLoadKeyPairFromFile; [FormerlySerializedAs("serverLoadKeyPairFromFile")]
public string serverKeypairPath = "./server-keys.json"; public bool ServerLoadKeyPairFromFile;
[FormerlySerializedAs("serverKeypairPath")]
public string ServerKeypairPath = "./server-keys.json";
private EncryptedConnection _client; EncryptedConnection client;
private Dictionary<int, EncryptedConnection> _serverConnections = new Dictionary<int, EncryptedConnection>(); readonly Dictionary<int, EncryptedConnection> serverConnections = new Dictionary<int, EncryptedConnection>();
private List<EncryptedConnection> _serverPendingConnections = readonly List<EncryptedConnection> serverPendingConnections =
new List<EncryptedConnection>(); new List<EncryptedConnection>();
private EncryptionCredentials _credentials; EncryptionCredentials credentials;
public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint; public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized; public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized;
private void ServerRemoveFromPending(EncryptedConnection con) void ServerRemoveFromPending(EncryptedConnection con)
{ {
for (int i = 0; i < _serverPendingConnections.Count; i++) for (int i = 0; i < serverPendingConnections.Count; i++)
{ if (serverPendingConnections[i] == con)
if (_serverPendingConnections[i] == con)
{ {
// remove by swapping with last // remove by swapping with last
int lastIndex = _serverPendingConnections.Count - 1; int lastIndex = serverPendingConnections.Count - 1;
_serverPendingConnections[i] = _serverPendingConnections[lastIndex]; serverPendingConnections[i] = serverPendingConnections[lastIndex];
_serverPendingConnections.RemoveAt(lastIndex); serverPendingConnections.RemoveAt(lastIndex);
break; break;
} }
}
} }
private void HandleInnerServerDisconnected(int connId) void HandleInnerServerDisconnected(int connId)
{ {
if (_serverConnections.TryGetValue(connId, out EncryptedConnection con)) if (serverConnections.TryGetValue(connId, out EncryptedConnection con))
{ {
ServerRemoveFromPending(con); ServerRemoveFromPending(con);
_serverConnections.Remove(connId); serverConnections.Remove(connId);
} }
OnServerDisconnected?.Invoke(connId); OnServerDisconnected?.Invoke(connId);
} }
private void HandleInnerServerError(int connId, TransportError type, string msg) void HandleInnerServerError(int connId, TransportError type, string msg) => OnServerError?.Invoke(connId, type, $"inner: {msg}");
{
OnServerError?.Invoke(connId, type, $"inner: {msg}");
}
private void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel) void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
{ {
if (_serverConnections.TryGetValue(connId, out EncryptedConnection c)) if (serverConnections.TryGetValue(connId, out EncryptedConnection c))
{
c.OnReceiveRaw(data, channel); c.OnReceiveRaw(data, channel);
}
} }
private void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, inner.ServerGetClientAddress(connId)); void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId));
private void HandleInnerServerConnected(int connId, string clientRemoteAddress) void HandleInnerServerConnected(int connId, string clientRemoteAddress)
{ {
Debug.Log($"[EncryptionTransport] New connection #{connId}"); Debug.Log($"[EncryptionTransport] New connection #{connId} from {clientRemoteAddress}");
EncryptedConnection ec = null; EncryptedConnection ec = null;
ec = new EncryptedConnection( ec = new EncryptedConnection(
_credentials, credentials,
false, false,
(segment, channel) => inner.ServerSend(connId, segment, channel), (segment, channel) => Inner.ServerSend(connId, segment, channel),
(segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel), (segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel),
() => () =>
{ {
Debug.Log($"[EncryptionTransport] Connection #{connId} is ready"); Debug.Log($"[EncryptionTransport] Connection #{connId} is ready");
// ReSharper disable once AccessToModifiedClosure
ServerRemoveFromPending(ec); ServerRemoveFromPending(ec);
//OnServerConnected?.Invoke(connId);
OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress); OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress);
}, },
(type, msg) => (type, msg) =>
@ -123,32 +119,25 @@ private void HandleInnerServerConnected(int connId, string clientRemoteAddress)
OnServerError?.Invoke(connId, type, msg); OnServerError?.Invoke(connId, type, msg);
ServerDisconnect(connId); ServerDisconnect(connId);
}); });
_serverConnections.Add(connId, ec); serverConnections.Add(connId, ec);
_serverPendingConnections.Add(ec); serverPendingConnections.Add(ec);
} }
private void HandleInnerClientDisconnected() void HandleInnerClientDisconnected()
{ {
_client = null; client = null;
OnClientDisconnected?.Invoke(); OnClientDisconnected?.Invoke();
} }
private void HandleInnerClientError(TransportError arg1, string arg2) void HandleInnerClientError(TransportError arg1, string arg2) => OnClientError?.Invoke(arg1, $"inner: {arg2}");
{
OnClientError?.Invoke(arg1, $"inner: {arg2}");
}
private void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) => client?.OnReceiveRaw(data, channel);
{
_client?.OnReceiveRaw(data, channel);
}
private void HandleInnerClientConnected() void HandleInnerClientConnected() =>
{ client = new EncryptedConnection(
_client = new EncryptedConnection( credentials,
_credentials,
true, true,
(segment, channel) => inner.ClientSend(segment, channel), (segment, channel) => Inner.ClientSend(segment, channel),
(segment, channel) => OnClientDataReceived?.Invoke(segment, channel), (segment, channel) => OnClientDataReceived?.Invoke(segment, channel),
() => () =>
{ {
@ -160,25 +149,23 @@ private void HandleInnerClientConnected()
ClientDisconnect(); ClientDisconnect();
}, },
HandleClientValidateServerPubKey); HandleClientValidateServerPubKey);
}
private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
{ {
switch (clientValidateServerPubKey) switch (ClientValidateServerPubKey)
{ {
case ValidationMode.Off: case ValidationMode.Off:
return true; return true;
case ValidationMode.List: case ValidationMode.List:
return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
case ValidationMode.Callback: case ValidationMode.Callback:
return onClientValidateServerPubKey(pubKeyInfo); return OnClientValidateServerPubKey(pubKeyInfo);
default: default:
throw new ArgumentOutOfRangeException(); throw new ArgumentOutOfRangeException();
} }
} }
void Awake() void Awake() =>
{
// check if encryption via hardware acceleration is supported. // check if encryption via hardware acceleration is supported.
// this can be useful to know for low end devices. // this can be useful to know for low end devices.
// //
@ -188,27 +175,26 @@ void Awake()
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs // https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
// which Unity does not support yet. // which Unity does not support yet.
Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}"); Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
}
public override bool Available() => inner.Available(); public override bool Available() => Inner.Available();
public override bool ClientConnected() => _client != null && _client.IsReady; public override bool ClientConnected() => client != null && client.IsReady;
public override void ClientConnect(string address) public override void ClientConnect(string address)
{ {
switch (clientValidateServerPubKey) switch (ClientValidateServerPubKey)
{ {
case ValidationMode.Off: case ValidationMode.Off:
break; break;
case ValidationMode.List: case ValidationMode.List:
if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0) if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0)
{ {
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
return; return;
} }
break; break;
case ValidationMode.Callback: case ValidationMode.Callback:
if (onClientValidateServerPubKey == null) if (OnClientValidateServerPubKey == null)
{ {
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
return; return;
@ -217,95 +203,79 @@ public override void ClientConnect(string address)
default: default:
throw new ArgumentOutOfRangeException(); throw new ArgumentOutOfRangeException();
} }
_credentials = EncryptionCredentials.Generate(); credentials = EncryptionCredentials.Generate();
inner.OnClientConnected = HandleInnerClientConnected; Inner.OnClientConnected = HandleInnerClientConnected;
inner.OnClientDataReceived = HandleInnerClientDataReceived; Inner.OnClientDataReceived = HandleInnerClientDataReceived;
inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); Inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
inner.OnClientError = HandleInnerClientError; Inner.OnClientError = HandleInnerClientError;
inner.OnClientDisconnected = HandleInnerClientDisconnected; Inner.OnClientDisconnected = HandleInnerClientDisconnected;
inner.ClientConnect(address); Inner.ClientConnect(address);
} }
public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) => public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) =>
_client?.Send(segment, channelId); client?.Send(segment, channelId);
public override void ClientDisconnect() => inner.ClientDisconnect(); public override void ClientDisconnect() => Inner.ClientDisconnect();
public override Uri ServerUri() => inner.ServerUri(); public override Uri ServerUri() => Inner.ServerUri();
public override bool ServerActive() => inner.ServerActive(); public override bool ServerActive() => Inner.ServerActive();
public override void ServerStart() public override void ServerStart()
{ {
if (serverLoadKeyPairFromFile) if (ServerLoadKeyPairFromFile)
{ credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath);
_credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath);
}
else else
{ credentials = EncryptionCredentials.Generate();
_credentials = EncryptionCredentials.Generate();
}
#pragma warning disable CS0618 // Type or member is obsolete #pragma warning disable CS0618 // Type or member is obsolete
inner.OnServerConnected = HandleInnerServerConnected; Inner.OnServerConnected = HandleInnerServerConnected;
#pragma warning restore CS0618 // Type or member is obsolete #pragma warning restore CS0618 // Type or member is obsolete
inner.OnServerConnectedWithAddress = HandleInnerServerConnected; Inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
inner.OnServerDataReceived = HandleInnerServerDataReceived; Inner.OnServerDataReceived = HandleInnerServerDataReceived;
inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); Inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
inner.OnServerError = HandleInnerServerError; Inner.OnServerError = HandleInnerServerError;
inner.OnServerDisconnected = HandleInnerServerDisconnected; Inner.OnServerDisconnected = HandleInnerServerDisconnected;
inner.ServerStart(); Inner.ServerStart();
} }
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable) public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable)
{ {
if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
{
connection.Send(segment, channelId); connection.Send(segment, channelId);
}
} }
public override void ServerDisconnect(int connectionId) public override void ServerDisconnect(int connectionId) =>
{
// cleanup is done via inners disconnect event // cleanup is done via inners disconnect event
inner.ServerDisconnect(connectionId); Inner.ServerDisconnect(connectionId);
}
public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId); public override string ServerGetClientAddress(int connectionId) => Inner.ServerGetClientAddress(connectionId);
public override void ServerStop() => inner.ServerStop(); public override void ServerStop() => Inner.ServerStop();
public override int GetMaxPacketSize(int channelId = Channels.Reliable) => public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
public override void Shutdown() => inner.Shutdown(); public override void Shutdown() => Inner.Shutdown();
public override void ClientEarlyUpdate() public override void ClientEarlyUpdate() => Inner.ClientEarlyUpdate();
{
inner.ClientEarlyUpdate();
}
public override void ClientLateUpdate() public override void ClientLateUpdate()
{ {
inner.ClientLateUpdate(); Inner.ClientLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
_client?.TickNonReady(NetworkTime.localTime); client?.TickNonReady(NetworkTime.localTime);
Profiler.EndSample(); Profiler.EndSample();
} }
public override void ServerEarlyUpdate() public override void ServerEarlyUpdate() => Inner.ServerEarlyUpdate();
{
inner.ServerEarlyUpdate();
}
public override void ServerLateUpdate() public override void ServerLateUpdate()
{ {
inner.ServerLateUpdate(); Inner.ServerLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
// Reverse iteration as entries can be removed while updating // Reverse iteration as entries can be removed while updating
for (int i = _serverPendingConnections.Count - 1; i >= 0; i--) for (int i = serverPendingConnections.Count - 1; i >= 0; i--)
{ serverPendingConnections[i].TickNonReady(NetworkTime.time);
_serverPendingConnections[i].TickNonReady(NetworkTime.time);
}
Profiler.EndSample(); Profiler.EndSample();
} }
} }

View File

@ -0,0 +1,314 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using Mirror.BouncyCastle.Crypto;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Serialization;
using Debug = UnityEngine.Debug;
namespace Mirror.Transports.Encryption
{
[HelpURL("https://mirror-networking.gitbook.io/docs/manual/transports/encryption-transport")]
public class ThreadedEncryptionTransport : ThreadedTransport, PortTransport
{
public override bool IsEncrypted => true;
public override string EncryptionCipher => "AES256-GCM";
[FormerlySerializedAs("inner")]
public ThreadedTransport Inner;
public ushort Port
{
get
{
if (Inner is PortTransport portTransport)
return portTransport.Port;
Debug.LogError($"ThreadedEncryptionTransport can't get Port because {Inner} is not a PortTransport");
return 0;
}
set
{
if (Inner is PortTransport portTransport)
{
portTransport.Port = value;
return;
}
Debug.LogError($"ThreadedEncryptionTransport can't set Port because {Inner} is not a PortTransport");
}
}
public enum ValidationMode
{
Off,
List,
Callback
}
[FormerlySerializedAs("clientValidateServerPubKey")]
public ValidationMode ClientValidateServerPubKey;
[FormerlySerializedAs("clientTrustedPubKeySignatures")]
[Tooltip("List of public key fingerprints the client will accept")]
public string[] ClientTrustedPubKeySignatures;
/// <summary>
/// Called when a client connects to a server
/// ATTENTION: NOT THREAD SAFE.
/// This will be called on the worker thread.
/// </summary>
public Func<PubKeyInfo, bool> OnClientValidateServerPubKey;
[FormerlySerializedAs("serverLoadKeyPairFromFile")]
public bool ServerLoadKeyPairFromFile;
[FormerlySerializedAs("serverKeypairPath")]
public string ServerKeypairPath = "./server-keys.json";
EncryptedConnection client;
readonly Dictionary<int, EncryptedConnection> serverConnections = new Dictionary<int, EncryptedConnection>();
readonly List<EncryptedConnection> serverPendingConnections =
new List<EncryptedConnection>();
EncryptionCredentials credentials;
public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint;
public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized;
// Used for threaded time keeping as unitys Time.time is not thread safe
Stopwatch stopwatch = Stopwatch.StartNew();
void ServerRemoveFromPending(EncryptedConnection con)
{
for (int i = 0; i < serverPendingConnections.Count; i++)
if (serverPendingConnections[i] == con)
{
// remove by swapping with last
int lastIndex = serverPendingConnections.Count - 1;
serverPendingConnections[i] = serverPendingConnections[lastIndex];
serverPendingConnections.RemoveAt(lastIndex);
break;
}
}
void HandleInnerServerDisconnected(int connId)
{
if (serverConnections.TryGetValue(connId, out EncryptedConnection con))
{
ServerRemoveFromPending(con);
serverConnections.Remove(connId);
}
OnThreadedServerDisconnected(connId);
}
void HandleInnerServerError(int connId, TransportError type, string msg) => OnThreadedServerError(connId, type, $"inner: {msg}");
void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
{
if (serverConnections.TryGetValue(connId, out EncryptedConnection c))
c.OnReceiveRaw(data, channel);
}
void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId));
void HandleInnerServerConnected(int connId, string clientRemoteAddress)
{
Debug.Log($"[ThreadedEncryptionTransport] New connection #{connId} from {clientRemoteAddress}");
EncryptedConnection ec = null;
ec = new EncryptedConnection(
credentials,
false,
(segment, channel) => Inner.ServerSend(connId, segment, channel),
(segment, channel) => OnThreadedServerReceive(connId, segment, channel),
() =>
{
Debug.Log($"[ThreadedEncryptionTransport] Connection #{connId} is ready");
// ReSharper disable once AccessToModifiedClosure
ServerRemoveFromPending(ec);
OnThreadedServerConnected(connId, new IPEndPoint(IPAddress.Parse(clientRemoteAddress), 0));
},
(type, msg) =>
{
OnThreadedServerError(connId, type, msg);
ServerDisconnect(connId);
});
serverConnections.Add(connId, ec);
serverPendingConnections.Add(ec);
}
void HandleInnerClientDisconnected()
{
client = null;
OnThreadedClientDisconnected();
}
void HandleInnerClientError(TransportError arg1, string arg2) => OnThreadedClientError(arg1, $"inner: {arg2}");
void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) => client?.OnReceiveRaw(data, channel);
void HandleInnerClientConnected() =>
client = new EncryptedConnection(
credentials,
true,
(segment, channel) => Inner.ClientSend(segment, channel),
(segment, channel) => OnThreadedClientReceive(segment, channel),
() =>
{
OnThreadedClientConnected();
},
(type, msg) =>
{
OnThreadedClientError(type, msg);
ClientDisconnect();
},
HandleClientValidateServerPubKey);
bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
{
switch (ClientValidateServerPubKey)
{
case ValidationMode.Off:
return true;
case ValidationMode.List:
return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
case ValidationMode.Callback:
return OnClientValidateServerPubKey(pubKeyInfo);
default:
throw new ArgumentOutOfRangeException();
}
}
protected override void Awake()
{
base.Awake();
// check if encryption via hardware acceleration is supported.
// this can be useful to know for low end devices.
//
// hardware acceleration requires netcoreapp3.0 or later:
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/AesUtilities.cs#L18
// because AesEngine_x86 requires System.Runtime.Intrinsics.X86:
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
// which Unity does not support yet.
Debug.Log($"ThreadedEncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
}
public override bool Available() => Inner.Available();
protected override void ThreadedClientConnect(string address)
{
switch (ClientValidateServerPubKey)
{
case ValidationMode.Off:
break;
case ValidationMode.List:
if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0)
{
OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
return;
}
break;
case ValidationMode.Callback:
if (OnClientValidateServerPubKey == null)
{
OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
return;
}
break;
default:
throw new ArgumentOutOfRangeException();
}
credentials = EncryptionCredentials.Generate();
Inner.OnClientConnected = HandleInnerClientConnected;
Inner.OnClientDataReceived = HandleInnerClientDataReceived;
Inner.OnClientDataSent = (bytes, channel) => OnThreadedClientSend(bytes, channel);
Inner.OnClientError = HandleInnerClientError;
Inner.OnClientDisconnected = HandleInnerClientDisconnected;
Inner.ClientConnect(address);
}
protected override void ThreadedClientConnect(Uri address) => Inner.ClientConnect(address);
protected override void ThreadedClientSend(ArraySegment<byte> segment, int channelId) =>
client?.Send(segment, channelId);
protected override void ThreadedClientDisconnect() => Inner.ClientDisconnect();
protected override void ThreadedServerStart()
{
if (ServerLoadKeyPairFromFile)
credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath);
else
credentials = EncryptionCredentials.Generate();
#pragma warning disable CS0618 // Type or member is obsolete
Inner.OnServerConnected = HandleInnerServerConnected;
#pragma warning restore CS0618 // Type or member is obsolete
Inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
Inner.OnServerDataReceived = HandleInnerServerDataReceived;
Inner.OnServerDataSent = (connId, bytes, channel) => OnThreadedServerSend(connId, bytes, channel);
Inner.OnServerError = HandleInnerServerError;
Inner.OnServerDisconnected = HandleInnerServerDisconnected;
Inner.ServerStart();
}
protected override void ThreadedServerSend(int connectionId, ArraySegment<byte> segment, int channelId)
{
if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
connection.Send(segment, channelId);
}
protected override void ThreadedServerDisconnect(int connectionId) =>
// cleanup is done via inners disconnect event
Inner.ServerDisconnect(connectionId);
protected override void ThreadedClientEarlyUpdate() {}
protected override void ThreadedServerStop() => Inner.ServerStop();
public override Uri ServerUri() => Inner.ServerUri();
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
protected override void ThreadedShutdown() => Inner.Shutdown();
public override void ClientEarlyUpdate()
{
base.ClientEarlyUpdate();
Inner.ClientEarlyUpdate();
}
public override void ClientLateUpdate()
{
base.ClientLateUpdate();
Inner.ClientLateUpdate();
}
protected override void ThreadedClientLateUpdate()
{
Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate");
client?.TickNonReady(stopwatch.Elapsed.TotalSeconds);
Profiler.EndSample();
}
protected override void ThreadedServerEarlyUpdate() {}
public override void ServerEarlyUpdate()
{
base.ServerEarlyUpdate();
Inner.ServerEarlyUpdate();
}
public override void ServerLateUpdate()
{
base.ServerLateUpdate();
Inner.ServerLateUpdate();
}
protected override void ThreadedServerLateUpdate()
{
Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate");
// Reverse iteration as entries can be removed while updating
for (int i = serverPendingConnections.Count - 1; i >= 0; i--)
serverPendingConnections[i].TickNonReady(stopwatch.Elapsed.TotalSeconds);
Profiler.EndSample();
}
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 5d3e310924fb49c195391b9699f20809
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3}
userData:
assetBundleName:
assetBundleVariant: