mirror of
https://github.com/MirrorNetworking/Mirror.git
synced 2024-11-17 18:40:33 +00:00
Compare commits
5 Commits
409064db94
...
f7c38d6958
Author | SHA1 | Date | |
---|---|---|---|
|
f7c38d6958 | ||
|
1187a59b18 | ||
|
499e4daea3 | ||
|
c5824f39e1 | ||
|
7ebbaa0319 |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,11 @@
|
|||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 8f63ea2e505fd484193fb31c5c55ca73
|
||||||
|
MonoImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
serializedVersion: 2
|
||||||
|
defaultReferences: []
|
||||||
|
executionOrder: 0
|
||||||
|
icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
@ -863,7 +863,7 @@ internal void OnStopLocalPlayer()
|
|||||||
for (int i = 0; i < components.Length; ++i)
|
for (int i = 0; i < components.Length; ++i)
|
||||||
{
|
{
|
||||||
NetworkBehaviour component = components[i];
|
NetworkBehaviour component = components[i];
|
||||||
ulong nthBit = (1u << i);
|
ulong nthBit = 1ul << i;
|
||||||
|
|
||||||
bool dirty = component.IsDirty();
|
bool dirty = component.IsDirty();
|
||||||
|
|
||||||
@ -910,7 +910,7 @@ ulong ClientDirtyMask()
|
|||||||
|
|
||||||
// on client, only consider owned components with SyncDirection to server
|
// on client, only consider owned components with SyncDirection to server
|
||||||
NetworkBehaviour component = components[i];
|
NetworkBehaviour component = components[i];
|
||||||
ulong nthBit = (1u << i);
|
ulong nthBit = 1ul << i;
|
||||||
|
|
||||||
if (isOwned && component.syncDirection == SyncDirection.ClientToServer)
|
if (isOwned && component.syncDirection == SyncDirection.ClientToServer)
|
||||||
{
|
{
|
||||||
@ -928,7 +928,7 @@ ulong ClientDirtyMask()
|
|||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
internal static bool IsDirty(ulong mask, int index)
|
internal static bool IsDirty(ulong mask, int index)
|
||||||
{
|
{
|
||||||
ulong nthBit = (ulong)(1 << index);
|
ulong nthBit = 1ul << index;
|
||||||
return (mask & nthBit) != 0;
|
return (mask & nthBit) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// base class for networking tests to make things easier.
|
// base class for networking tests to make things easier.
|
||||||
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
using NUnit.Framework;
|
using NUnit.Framework;
|
||||||
@ -81,6 +82,18 @@ protected void CreateNetworked(out GameObject go, out NetworkIdentity identity)
|
|||||||
instantiated.Add(go);
|
instantiated.Add(go);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void CreateNetworked(out GameObject go, out NetworkIdentity identity, Action<NetworkIdentity> preAwake)
|
||||||
|
{
|
||||||
|
go = new GameObject();
|
||||||
|
identity = go.AddComponent<NetworkIdentity>();
|
||||||
|
preAwake(identity);
|
||||||
|
// Awake is only called in play mode.
|
||||||
|
// call manually for initialization.
|
||||||
|
identity.Awake();
|
||||||
|
// track
|
||||||
|
instantiated.Add(go);
|
||||||
|
}
|
||||||
|
|
||||||
// create GameObject + NetworkIdentity + NetworkBehaviour<T>
|
// create GameObject + NetworkIdentity + NetworkBehaviour<T>
|
||||||
// add to tracker list if needed (useful for cleanups afterwards)
|
// add to tracker list if needed (useful for cleanups afterwards)
|
||||||
protected void CreateNetworked<T>(out GameObject go, out NetworkIdentity identity, out T component)
|
protected void CreateNetworked<T>(out GameObject go, out NetworkIdentity identity, out T component)
|
||||||
@ -269,6 +282,44 @@ protected void CreateNetworkedAndSpawn(
|
|||||||
Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId));
|
Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN
|
||||||
|
// => preAwake callbacks can be used to add network behaviours to the NI
|
||||||
|
// => ownerConnection can be NetworkServer.localConnection if needed.
|
||||||
|
// => returns objects from client and from server.
|
||||||
|
// will be same in host mode.
|
||||||
|
protected void CreateNetworkedAndSpawn(
|
||||||
|
out GameObject serverGO, out NetworkIdentity serverIdentity, Action<NetworkIdentity> serverPreAwake,
|
||||||
|
out GameObject clientGO, out NetworkIdentity clientIdentity, Action<NetworkIdentity> clientPreAwake,
|
||||||
|
NetworkConnectionToClient ownerConnection = null)
|
||||||
|
{
|
||||||
|
// server & client need to be active before spawning
|
||||||
|
Debug.Assert(NetworkClient.active, "NetworkClient needs to be active before spawning.");
|
||||||
|
Debug.Assert(NetworkServer.active, "NetworkServer needs to be active before spawning.");
|
||||||
|
|
||||||
|
// create one on server, one on client
|
||||||
|
// (spawning has to find it on client, it doesn't create it)
|
||||||
|
CreateNetworked(out serverGO, out serverIdentity, serverPreAwake);
|
||||||
|
CreateNetworked(out clientGO, out clientIdentity, clientPreAwake);
|
||||||
|
|
||||||
|
// give both a scene id and register it on client for spawnables
|
||||||
|
clientIdentity.sceneId = serverIdentity.sceneId = (ulong)serverGO.GetHashCode();
|
||||||
|
NetworkClient.spawnableObjects[clientIdentity.sceneId] = clientIdentity;
|
||||||
|
|
||||||
|
// spawn
|
||||||
|
NetworkServer.Spawn(serverGO, ownerConnection);
|
||||||
|
ProcessMessages();
|
||||||
|
|
||||||
|
// double check isServer/isClient. avoids debugging headaches.
|
||||||
|
Assert.That(serverIdentity.isServer, Is.True);
|
||||||
|
Assert.That(clientIdentity.isClient, Is.True);
|
||||||
|
|
||||||
|
// double check that we have authority if we passed an owner connection
|
||||||
|
if (ownerConnection != null)
|
||||||
|
Debug.Assert(clientIdentity.isOwned == true, $"Behaviour Had Wrong Authority when spawned, This means that the test is broken and will give the wrong results");
|
||||||
|
|
||||||
|
// make sure the client really spawned it.
|
||||||
|
Assert.That(NetworkClient.spawned.ContainsKey(serverIdentity.netId));
|
||||||
|
}
|
||||||
// create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN
|
// create GameObject + NetworkIdentity + NetworkBehaviour & SPAWN
|
||||||
// => ownerConnection can be NetworkServer.localConnection if needed.
|
// => ownerConnection can be NetworkServer.localConnection if needed.
|
||||||
protected void CreateNetworkedAndSpawn<T>(out GameObject go, out NetworkIdentity identity, out T component, NetworkConnectionToClient ownerConnection = null)
|
protected void CreateNetworkedAndSpawn<T>(out GameObject go, out NetworkIdentity identity, out T component, NetworkConnectionToClient ownerConnection = null)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// OnDe/SerializeSafely tests.
|
// OnDe/SerializeSafely tests.
|
||||||
|
using System.Collections.Generic;
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
using Mirror.Tests.EditorBehaviours.NetworkIdentities;
|
using Mirror.Tests.EditorBehaviours.NetworkIdentities;
|
||||||
using NUnit.Framework;
|
using NUnit.Framework;
|
||||||
@ -42,7 +43,7 @@ public void SerializeAndDeserializeAll()
|
|||||||
);
|
);
|
||||||
|
|
||||||
// set sync modes
|
// set sync modes
|
||||||
serverOwnerComp.syncMode = clientOwnerComp.syncMode = SyncMode.Owner;
|
serverOwnerComp.syncMode = clientOwnerComp.syncMode = SyncMode.Owner;
|
||||||
serverObserversComp.syncMode = clientObserversComp.syncMode = SyncMode.Observers;
|
serverObserversComp.syncMode = clientObserversComp.syncMode = SyncMode.Observers;
|
||||||
|
|
||||||
// set unique values on server components
|
// set unique values on server components
|
||||||
@ -65,10 +66,127 @@ public void SerializeAndDeserializeAll()
|
|||||||
// deserialize client object with OBSERVERS payload
|
// deserialize client object with OBSERVERS payload
|
||||||
reader = new NetworkReader(observersWriter.ToArray());
|
reader = new NetworkReader(observersWriter.ToArray());
|
||||||
clientIdentity.DeserializeClient(reader, true);
|
clientIdentity.DeserializeClient(reader, true);
|
||||||
Assert.That(clientOwnerComp.value, Is.EqualTo(null)); // owner mode shouldn't be in data
|
Assert.That(clientOwnerComp.value, Is.EqualTo(null)); // owner mode shouldn't be in data
|
||||||
Assert.That(clientObserversComp.value, Is.EqualTo(42)); // observers mode should be in data
|
Assert.That(clientObserversComp.value, Is.EqualTo(42)); // observers mode should be in data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test serialize -> deserialize of any supported number of components
|
||||||
|
[Test]
|
||||||
|
public void SerializeAndDeserializeN([NUnit.Framework.Range(1, 64)] int numberOfNBs)
|
||||||
|
{
|
||||||
|
List<SerializeTest1NetworkBehaviour> serverNBs = new List<SerializeTest1NetworkBehaviour>();
|
||||||
|
List<SerializeTest1NetworkBehaviour> clientNBs = new List<SerializeTest1NetworkBehaviour>();
|
||||||
|
// need two of both versions so we can serialize -> deserialize
|
||||||
|
CreateNetworkedAndSpawn(
|
||||||
|
out _, out NetworkIdentity serverIdentity, ni =>
|
||||||
|
{
|
||||||
|
for (int i = 0; i < numberOfNBs; i++)
|
||||||
|
{
|
||||||
|
SerializeTest1NetworkBehaviour nb = ni.gameObject.AddComponent<SerializeTest1NetworkBehaviour>();
|
||||||
|
nb.syncInterval = 0;
|
||||||
|
nb.syncMode = SyncMode.Observers;
|
||||||
|
serverNBs.Add(nb);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
out _, out NetworkIdentity clientIdentity, ni =>
|
||||||
|
{
|
||||||
|
for (int i = 0; i < numberOfNBs; i++)
|
||||||
|
{
|
||||||
|
SerializeTest1NetworkBehaviour nb = ni.gameObject.AddComponent<SerializeTest1NetworkBehaviour>();
|
||||||
|
nb.syncInterval = 0;
|
||||||
|
nb.syncMode = SyncMode.Observers;
|
||||||
|
clientNBs.Add(nb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// INITIAL SYNC
|
||||||
|
// set unique values on server components
|
||||||
|
for (int i = 0; i < serverNBs.Count; i++)
|
||||||
|
{
|
||||||
|
serverNBs[i].value = (i + 1) * 3;
|
||||||
|
serverNBs[i].SetDirty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialize server object
|
||||||
|
serverIdentity.SerializeServer(true, ownerWriter, observersWriter);
|
||||||
|
|
||||||
|
// deserialize client object with OBSERVERS payload
|
||||||
|
NetworkReader reader = new NetworkReader(observersWriter.ToArray());
|
||||||
|
clientIdentity.DeserializeClient(reader, true);
|
||||||
|
for (int i = 0; i < clientNBs.Count; i++)
|
||||||
|
{
|
||||||
|
int expected = (i + 1) * 3;
|
||||||
|
Assert.That(clientNBs[i].value, Is.EqualTo(expected), $"Expected the clientNBs[{i}] to have a value of {expected}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear dirty bits for incremental sync
|
||||||
|
foreach (SerializeTest1NetworkBehaviour serverNB in serverNBs)
|
||||||
|
serverNB.ClearAllDirtyBits();
|
||||||
|
|
||||||
|
// INCREMENTAL SYNC ALL
|
||||||
|
// set unique values on server components
|
||||||
|
for (int i = 0; i < serverNBs.Count; i++)
|
||||||
|
{
|
||||||
|
serverNBs[i].value = (i + 1) * 11;
|
||||||
|
serverNBs[i].SetDirty();
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerWriter.Reset();
|
||||||
|
observersWriter.Reset();
|
||||||
|
// serialize server object
|
||||||
|
serverIdentity.SerializeServer(false, ownerWriter, observersWriter);
|
||||||
|
|
||||||
|
// deserialize client object with OBSERVERS payload
|
||||||
|
reader = new NetworkReader(observersWriter.ToArray());
|
||||||
|
clientIdentity.DeserializeClient(reader, false);
|
||||||
|
for (int i = 0; i < clientNBs.Count; i++)
|
||||||
|
{
|
||||||
|
int expected = (i + 1) * 11;
|
||||||
|
Assert.That(clientNBs[i].value, Is.EqualTo(expected), $"Expected the clientNBs[{i}] to have a value of {expected}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear dirty bits for incremental sync
|
||||||
|
foreach (SerializeTest1NetworkBehaviour serverNB in serverNBs)
|
||||||
|
serverNB.ClearAllDirtyBits();
|
||||||
|
|
||||||
|
// INCREMENTAL SYNC INDIVIDUAL
|
||||||
|
for (int i = 0; i < numberOfNBs; i++)
|
||||||
|
{
|
||||||
|
// reset all client nbs
|
||||||
|
foreach (SerializeTest1NetworkBehaviour clientNB in clientNBs)
|
||||||
|
clientNB.value = 0;
|
||||||
|
|
||||||
|
int expected = (i + 1) * 7;
|
||||||
|
|
||||||
|
// set unique value on server components
|
||||||
|
serverNBs[i].value = expected;
|
||||||
|
serverNBs[i].SetDirty();
|
||||||
|
|
||||||
|
ownerWriter.Reset();
|
||||||
|
observersWriter.Reset();
|
||||||
|
// serialize server object
|
||||||
|
serverIdentity.SerializeServer(false, ownerWriter, observersWriter);
|
||||||
|
|
||||||
|
// deserialize client object with OBSERVERS payload
|
||||||
|
reader = new NetworkReader(observersWriter.ToArray());
|
||||||
|
clientIdentity.DeserializeClient(reader, false);
|
||||||
|
for (int index = 0; index < clientNBs.Count; index++)
|
||||||
|
{
|
||||||
|
SerializeTest1NetworkBehaviour clientNB = clientNBs[index];
|
||||||
|
if (index == i)
|
||||||
|
{
|
||||||
|
Assert.That(clientNB.value, Is.EqualTo(expected), $"Expected the clientNBs[{index}] to have a value of {expected}");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Assert.That(clientNB.value, Is.EqualTo(0), $"Expected the clientNBs[{index}] to have a value of 0 since we're not syncing that index (on sync of #{i})");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// serialization should work even if a component throws an exception.
|
// serialization should work even if a component throws an exception.
|
||||||
// so if first component throws, second should still be serialized fine.
|
// so if first component throws, second should still be serialized fine.
|
||||||
[Test]
|
[Test]
|
||||||
@ -150,20 +268,20 @@ public void TooManyComponents()
|
|||||||
public void ErrorCorrection()
|
public void ErrorCorrection()
|
||||||
{
|
{
|
||||||
int original = 0x12345678;
|
int original = 0x12345678;
|
||||||
byte safety = 0x78; // last byte
|
byte safety = 0x78; // last byte
|
||||||
|
|
||||||
// correct size shouldn't be corrected
|
// correct size shouldn't be corrected
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original + 0, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original + 0, safety), Is.EqualTo(original));
|
||||||
|
|
||||||
// read a little too much
|
// read a little too much
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original + 1, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original + 1, safety), Is.EqualTo(original));
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original + 2, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original + 2, safety), Is.EqualTo(original));
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original + 42, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original + 42, safety), Is.EqualTo(original));
|
||||||
|
|
||||||
// read a little too less
|
// read a little too less
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original - 1, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original - 1, safety), Is.EqualTo(original));
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original - 2, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original - 2, safety), Is.EqualTo(original));
|
||||||
Assert.That(NetworkBehaviour.ErrorCorrection(original - 42, safety), Is.EqualTo(original));
|
Assert.That(NetworkBehaviour.ErrorCorrection(original - 42, safety), Is.EqualTo(original));
|
||||||
|
|
||||||
// reading way too much / less is expected to fail.
|
// reading way too much / less is expected to fail.
|
||||||
// we can only correct the last byte, not more.
|
// we can only correct the last byte, not more.
|
||||||
|
@ -22,7 +22,7 @@ public void Setup()
|
|||||||
GameObject gameObject = new GameObject();
|
GameObject gameObject = new GameObject();
|
||||||
|
|
||||||
encryption = gameObject.AddComponent<EncryptionTransport>();
|
encryption = gameObject.AddComponent<EncryptionTransport>();
|
||||||
encryption.inner = inner;
|
encryption.Inner = inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
[TearDown]
|
[TearDown]
|
||||||
|
@ -17,11 +17,11 @@ public class EncryptionTransportInspector : UnityEditor.Editor
|
|||||||
|
|
||||||
void OnEnable()
|
void OnEnable()
|
||||||
{
|
{
|
||||||
innerProperty = serializedObject.FindProperty("inner");
|
innerProperty = serializedObject.FindProperty("Inner");
|
||||||
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey");
|
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("ClientValidateServerPubKey");
|
||||||
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures");
|
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("ClientTrustedPubKeySignatures");
|
||||||
serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath");
|
serverKeypairPathProperty = serializedObject.FindProperty("ServerKeypairPath");
|
||||||
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile");
|
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("ServerLoadKeyPairFromFile");
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void OnInspectorGUI()
|
public override void OnInspectorGUI()
|
||||||
@ -77,5 +77,8 @@ public override void OnInspectorGUI()
|
|||||||
|
|
||||||
serializedObject.ApplyModifiedProperties();
|
serializedObject.ApplyModifiedProperties();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[CustomEditor(typeof(ThreadedEncryptionTransport), true)]
|
||||||
|
class EncryptionThreadedTransportInspector : EncryptionTransportInspector {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
using System;
|
using System;
|
||||||
using System.Security.Cryptography;
|
using System.Security.Cryptography;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
|
using System.Threading;
|
||||||
using Mirror.BouncyCastle.Crypto;
|
using Mirror.BouncyCastle.Crypto;
|
||||||
using Mirror.BouncyCastle.Crypto.Agreement;
|
using Mirror.BouncyCastle.Crypto.Agreement;
|
||||||
using Mirror.BouncyCastle.Crypto.Digests;
|
using Mirror.BouncyCastle.Crypto.Digests;
|
||||||
@ -14,32 +15,32 @@ namespace Mirror.Transports.Encryption
|
|||||||
public class EncryptedConnection
|
public class EncryptedConnection
|
||||||
{
|
{
|
||||||
// 256-bit key
|
// 256-bit key
|
||||||
private const int KeyLength = 32;
|
const int KeyLength = 32;
|
||||||
// 512-bit salt for the key derivation function
|
// 512-bit salt for the key derivation function
|
||||||
private const int HkdfSaltSize = KeyLength * 2;
|
const int HkdfSaltSize = KeyLength * 2;
|
||||||
|
|
||||||
// Info tag for the HKDF, this just adds more entropy
|
// Info tag for the HKDF, this just adds more entropy
|
||||||
private static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
|
static readonly byte[] HkdfInfo = Encoding.UTF8.GetBytes("Mirror/EncryptionTransport");
|
||||||
|
|
||||||
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
|
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
|
||||||
private const int NonceSize = 12;
|
const int NonceSize = 12;
|
||||||
|
|
||||||
// this is the size of the "checksum" included in each encrypted payload
|
// this is the size of the "checksum" included in each encrypted payload
|
||||||
// 16 bytes/128 bytes is the recommended value for best security
|
// 16 bytes/128 bytes is the recommended value for best security
|
||||||
// can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker.
|
// can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker.
|
||||||
// Setting it lower than 12 bytes is not recommended
|
// Setting it lower than 12 bytes is not recommended
|
||||||
private const int MacSizeBytes = 16;
|
const int MacSizeBytes = 16;
|
||||||
|
|
||||||
private const int MacSizeBits = MacSizeBytes * 8;
|
const int MacSizeBits = MacSizeBytes * 8;
|
||||||
|
|
||||||
// How much metadata overhead we have for regular packets
|
// How much metadata overhead we have for regular packets
|
||||||
public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize;
|
public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize;
|
||||||
|
|
||||||
// After how many seconds of not receiving a handshake packet we should time out
|
// After how many seconds of not receiving a handshake packet we should time out
|
||||||
private const double DurationTimeout = 2; // 2s
|
const double DurationTimeout = 2; // 2s
|
||||||
|
|
||||||
// After how many seconds to assume the last handshake packet got lost and to resend another one
|
// After how many seconds to assume the last handshake packet got lost and to resend another one
|
||||||
private const double DurationResend = 0.05; // 50ms
|
const double DurationResend = 0.05; // 50ms
|
||||||
|
|
||||||
|
|
||||||
// Static fields for allocation efficiency, makes this not thread safe
|
// Static fields for allocation efficiency, makes this not thread safe
|
||||||
@ -48,20 +49,19 @@ public class EncryptedConnection
|
|||||||
// Set up a global cipher instance, it is initialised/reset before use
|
// Set up a global cipher instance, it is initialised/reset before use
|
||||||
// (AesFastEngine used to exist, but was removed due to side channel issues)
|
// (AesFastEngine used to exist, but was removed due to side channel issues)
|
||||||
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
|
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
|
||||||
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
|
static readonly ThreadLocal<GcmBlockCipher> Cipher = new ThreadLocal<GcmBlockCipher>(() => new GcmBlockCipher(AesUtilities.CreateEngine()));
|
||||||
|
|
||||||
// Set up a global HKDF with a SHA-256 digest
|
// Set up a global HKDF with a SHA-256 digest
|
||||||
private static readonly HkdfBytesGenerator Hkdf = new HkdfBytesGenerator(new Sha256Digest());
|
static readonly ThreadLocal<HkdfBytesGenerator> Hkdf = new ThreadLocal<HkdfBytesGenerator>(() => new HkdfBytesGenerator(new Sha256Digest()));
|
||||||
|
|
||||||
// Global byte array to store nonce sent by the remote side, they're used immediately after
|
// Global byte array to store nonce sent by the remote side, they're used immediately after
|
||||||
private static readonly byte[] ReceiveNonce = new byte[NonceSize];
|
static readonly ThreadLocal<byte[]> ReceiveNonce = new ThreadLocal<byte[]>(() => new byte[NonceSize]);
|
||||||
|
|
||||||
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
|
// Buffer for the remote salt, as bouncycastle needs to take a byte[] *rolls eyes*
|
||||||
private static byte[] _tmpRemoteSaltBuffer = new byte[HkdfSaltSize];
|
static readonly ThreadLocal<byte[]> TMPRemoteSaltBuffer = new ThreadLocal<byte[]>(() => new byte[HkdfSaltSize]);
|
||||||
|
|
||||||
// buffer for encrypt/decrypt operations, resized larger as needed
|
// buffer for encrypt/decrypt operations, resized larger as needed
|
||||||
// this is also the buffer that will be returned to mirror via ArraySegment
|
static ThreadLocal<byte[]> TMPCryptBuffer = new ThreadLocal<byte[]>(() => new byte[2048]);
|
||||||
// so any thread safety concerns would need to take extra care here
|
|
||||||
private static byte[] _tmpCryptBuffer = new byte[2048];
|
|
||||||
|
|
||||||
// packet headers
|
// packet headers
|
||||||
enum OpCodes : byte
|
enum OpCodes : byte
|
||||||
@ -70,7 +70,7 @@ enum OpCodes : byte
|
|||||||
Data = 1,
|
Data = 1,
|
||||||
HandshakeStart = 2,
|
HandshakeStart = 2,
|
||||||
HandshakeAck = 3,
|
HandshakeAck = 3,
|
||||||
HandshakeFin = 4,
|
HandshakeFin = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
enum State
|
enum State
|
||||||
@ -91,37 +91,37 @@ enum State
|
|||||||
Ready
|
Ready
|
||||||
}
|
}
|
||||||
|
|
||||||
private State _state = State.WaitingHandshake;
|
State state = State.WaitingHandshake;
|
||||||
|
|
||||||
// Key exchange confirmed and data can be sent freely
|
// Key exchange confirmed and data can be sent freely
|
||||||
public bool IsReady => _state == State.Ready;
|
public bool IsReady => state == State.Ready;
|
||||||
// Callback to send off encrypted data
|
// Callback to send off encrypted data
|
||||||
private Action<ArraySegment<byte>, int> _send;
|
readonly Action<ArraySegment<byte>, int> send;
|
||||||
// Callback when received data has been decrypted
|
// Callback when received data has been decrypted
|
||||||
private Action<ArraySegment<byte>, int> _receive;
|
readonly Action<ArraySegment<byte>, int> receive;
|
||||||
// Callback when the connection becomes ready
|
// Callback when the connection becomes ready
|
||||||
private Action _ready;
|
readonly Action ready;
|
||||||
// On-error callback, disconnect expected
|
// On-error callback, disconnect expected
|
||||||
private Action<TransportError, string> _error;
|
readonly Action<TransportError, string> error;
|
||||||
// Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance
|
// Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance
|
||||||
// (usually client validates the server key)
|
// (usually client validates the server key)
|
||||||
private Func<PubKeyInfo, bool> _validateRemoteKey;
|
readonly Func<PubKeyInfo, bool> validateRemoteKey;
|
||||||
// Our asymmetric credentials for the initial DH exchange
|
// Our asymmetric credentials for the initial DH exchange
|
||||||
private EncryptionCredentials _credentials;
|
EncryptionCredentials credentials;
|
||||||
private byte[] _hkdfSalt;
|
readonly byte[] hkdfSalt;
|
||||||
|
|
||||||
// After no handshake packet in this many seconds, the handshake fails
|
// After no handshake packet in this many seconds, the handshake fails
|
||||||
private double _handshakeTimeout;
|
double handshakeTimeout;
|
||||||
// When to assume the last handshake packet got lost and to resend another one
|
// When to assume the last handshake packet got lost and to resend another one
|
||||||
private double _nextHandshakeResend;
|
double nextHandshakeResend;
|
||||||
|
|
||||||
|
|
||||||
// we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in
|
// we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in
|
||||||
// so we can update it without creating a new AeadParameters instance
|
// so we can update it without creating a new AeadParameters instance
|
||||||
// this might break in the future! (will cause bad data)
|
// this might break in the future! (will cause bad data)
|
||||||
private byte[] _nonce = new byte[NonceSize];
|
byte[] nonce = new byte[NonceSize];
|
||||||
private AeadParameters _cipherParametersEncrypt;
|
AeadParameters cipherParametersEncrypt;
|
||||||
private AeadParameters _cipherParametersDecrypt;
|
AeadParameters cipherParametersDecrypt;
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -130,7 +130,7 @@ enum State
|
|||||||
*
|
*
|
||||||
* The client does this, since the fin is not acked explicitly, but by receiving data to decrypt
|
* The client does this, since the fin is not acked explicitly, but by receiving data to decrypt
|
||||||
*/
|
*/
|
||||||
private readonly bool _sendsFirst;
|
readonly bool sendsFirst;
|
||||||
|
|
||||||
public EncryptedConnection(EncryptionCredentials credentials,
|
public EncryptedConnection(EncryptionCredentials credentials,
|
||||||
bool isClient,
|
bool isClient,
|
||||||
@ -140,28 +140,24 @@ public EncryptedConnection(EncryptionCredentials credentials,
|
|||||||
Action<TransportError, string> errorAction,
|
Action<TransportError, string> errorAction,
|
||||||
Func<PubKeyInfo, bool> validateRemoteKey = null)
|
Func<PubKeyInfo, bool> validateRemoteKey = null)
|
||||||
{
|
{
|
||||||
_credentials = credentials;
|
this.credentials = credentials;
|
||||||
_sendsFirst = isClient;
|
sendsFirst = isClient;
|
||||||
if (!_sendsFirst)
|
if (!sendsFirst)
|
||||||
{
|
|
||||||
// salt is controlled by the server
|
// salt is controlled by the server
|
||||||
_hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
|
hkdfSalt = GenerateSecureBytes(HkdfSaltSize);
|
||||||
}
|
send = sendAction;
|
||||||
_send = sendAction;
|
receive = receiveAction;
|
||||||
_receive = receiveAction;
|
ready = readyAction;
|
||||||
_ready = readyAction;
|
error = errorAction;
|
||||||
_error = errorAction;
|
this.validateRemoteKey = validateRemoteKey;
|
||||||
_validateRemoteKey = validateRemoteKey;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates a random starting nonce
|
// Generates a random starting nonce
|
||||||
private static byte[] GenerateSecureBytes(int size)
|
static byte[] GenerateSecureBytes(int size)
|
||||||
{
|
{
|
||||||
byte[] bytes = new byte[size];
|
byte[] bytes = new byte[size];
|
||||||
using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
|
using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
|
||||||
{
|
|
||||||
rng.GetBytes(bytes);
|
rng.GetBytes(bytes);
|
||||||
}
|
|
||||||
|
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
@ -170,7 +166,7 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
|
|||||||
{
|
{
|
||||||
if (data.Count < 1)
|
if (data.Count < 1)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, "Received empty packet");
|
error(TransportError.Unexpected, "Received empty packet");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,94 +177,80 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
|
|||||||
{
|
{
|
||||||
case OpCodes.Data:
|
case OpCodes.Data:
|
||||||
// first sender ready is implicit when data is received
|
// first sender ready is implicit when data is received
|
||||||
if (_sendsFirst && _state == State.WaitingHandshakeReply)
|
if (sendsFirst && state == State.WaitingHandshakeReply)
|
||||||
{
|
|
||||||
SetReady();
|
SetReady();
|
||||||
}
|
|
||||||
else if (!IsReady)
|
else if (!IsReady)
|
||||||
{
|
error(TransportError.Unexpected, "Unexpected data while not ready.");
|
||||||
_error(TransportError.Unexpected, "Unexpected data while not ready.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (reader.Remaining < Overhead)
|
if (reader.Remaining < Overhead)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, "received data packet smaller than metadata size");
|
error(TransportError.Unexpected, "received data packet smaller than metadata size");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArraySegment<byte> ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize);
|
ArraySegment<byte> ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize);
|
||||||
reader.ReadBytes(ReceiveNonce, NonceSize);
|
reader.ReadBytes(ReceiveNonce.Value, NonceSize);
|
||||||
|
|
||||||
Profiler.BeginSample("EncryptedConnection.Decrypt");
|
Profiler.BeginSample("EncryptedConnection.Decrypt");
|
||||||
ArraySegment<byte> plaintext = Decrypt(ciphertext);
|
ArraySegment<byte> plaintext = Decrypt(ciphertext);
|
||||||
Profiler.EndSample();
|
Profiler.EndSample();
|
||||||
if (plaintext.Count == 0)
|
if (plaintext.Count == 0)
|
||||||
{
|
|
||||||
// error
|
// error
|
||||||
return;
|
return;
|
||||||
}
|
receive(plaintext, channel);
|
||||||
_receive(plaintext, channel);
|
|
||||||
break;
|
break;
|
||||||
case OpCodes.HandshakeStart:
|
case OpCodes.HandshakeStart:
|
||||||
if (_sendsFirst)
|
if (sendsFirst)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
|
error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_state == State.WaitingHandshakeReply)
|
if (state == State.WaitingHandshakeReply)
|
||||||
{
|
|
||||||
// this is fine, packets may arrive out of order
|
// this is fine, packets may arrive out of order
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
_state = State.WaitingHandshakeReply;
|
state = State.WaitingHandshakeReply;
|
||||||
ResetTimeouts();
|
ResetTimeouts();
|
||||||
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _hkdfSalt);
|
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), hkdfSalt);
|
||||||
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
|
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
|
||||||
break;
|
break;
|
||||||
case OpCodes.HandshakeAck:
|
case OpCodes.HandshakeAck:
|
||||||
if (!_sendsFirst)
|
if (!sendsFirst)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
|
error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsReady)
|
if (IsReady)
|
||||||
{
|
|
||||||
// this is fine, packets may arrive out of order
|
// this is fine, packets may arrive out of order
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
if (_state == State.WaitingHandshakeReply)
|
if (state == State.WaitingHandshakeReply)
|
||||||
{
|
|
||||||
// this is fine, packets may arrive out of order
|
// this is fine, packets may arrive out of order
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_state = State.WaitingHandshakeReply;
|
state = State.WaitingHandshakeReply;
|
||||||
ResetTimeouts();
|
ResetTimeouts();
|
||||||
reader.ReadBytes(_tmpRemoteSaltBuffer, HkdfSaltSize);
|
reader.ReadBytes(TMPRemoteSaltBuffer.Value, HkdfSaltSize);
|
||||||
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), _tmpRemoteSaltBuffer);
|
CompleteExchange(reader.ReadBytesSegment(reader.Remaining), TMPRemoteSaltBuffer.Value);
|
||||||
SendHandshakeFin();
|
SendHandshakeFin();
|
||||||
break;
|
break;
|
||||||
case OpCodes.HandshakeFin:
|
case OpCodes.HandshakeFin:
|
||||||
if (_sendsFirst)
|
if (sendsFirst)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
|
error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsReady)
|
if (IsReady)
|
||||||
{
|
|
||||||
// this is fine, packets may arrive out of order
|
// this is fine, packets may arrive out of order
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
if (_state != State.WaitingHandshakeReply)
|
if (state != State.WaitingHandshakeReply)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected,
|
error(TransportError.Unexpected,
|
||||||
"Received HandshakeFin packet, we didn't expect this yet.");
|
"Received HandshakeFin packet, we didn't expect this yet.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -277,24 +259,25 @@ public void OnReceiveRaw(ArraySegment<byte> data, int channel)
|
|||||||
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
_error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
|
error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private void SetReady()
|
|
||||||
|
void SetReady()
|
||||||
{
|
{
|
||||||
// done with credentials, null out the reference
|
// done with credentials, null out the reference
|
||||||
_credentials = null;
|
credentials = null;
|
||||||
|
|
||||||
_state = State.Ready;
|
state = State.Ready;
|
||||||
_ready();
|
ready();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void ResetTimeouts()
|
void ResetTimeouts()
|
||||||
{
|
{
|
||||||
_handshakeTimeout = 0;
|
handshakeTimeout = 0;
|
||||||
_nextHandshakeResend = -1;
|
nextHandshakeResend = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void Send(ArraySegment<byte> data, int channel)
|
public void Send(ArraySegment<byte> data, int channel)
|
||||||
@ -307,161 +290,149 @@ public void Send(ArraySegment<byte> data, int channel)
|
|||||||
Profiler.EndSample();
|
Profiler.EndSample();
|
||||||
|
|
||||||
if (encrypted.Count == 0)
|
if (encrypted.Count == 0)
|
||||||
{
|
|
||||||
// error
|
// error
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
writer.WriteBytes(encrypted.Array, 0, encrypted.Count);
|
writer.WriteBytes(encrypted.Array, 0, encrypted.Count);
|
||||||
// write nonce after since Encrypt will update it
|
// write nonce after since Encrypt will update it
|
||||||
writer.WriteBytes(_nonce, 0, NonceSize);
|
writer.WriteBytes(nonce, 0, NonceSize);
|
||||||
_send(writer.ToArraySegment(), channel);
|
send(writer.ToArraySegment(), channel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
|
ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
|
||||||
{
|
{
|
||||||
if (plaintext.Count == 0)
|
if (plaintext.Count == 0)
|
||||||
{
|
|
||||||
// Invalid
|
// Invalid
|
||||||
return new ArraySegment<byte>();
|
return new ArraySegment<byte>();
|
||||||
}
|
|
||||||
// Need to make the nonce unique again before encrypting another message
|
// Need to make the nonce unique again before encrypting another message
|
||||||
UpdateNonce();
|
UpdateNonce();
|
||||||
// Re-initialize the cipher with our cached parameters
|
// Re-initialize the cipher with our cached parameters
|
||||||
Cipher.Init(true, _cipherParametersEncrypt);
|
Cipher.Value.Init(true, cipherParametersEncrypt);
|
||||||
|
|
||||||
// Calculate the expected output size, this should always be input size + mac size
|
// Calculate the expected output size, this should always be input size + mac size
|
||||||
int outSize = Cipher.GetOutputSize(plaintext.Count);
|
int outSize = Cipher.Value.GetOutputSize(plaintext.Count);
|
||||||
#if UNITY_EDITOR
|
#if UNITY_EDITOR
|
||||||
// expecting the outSize to be input size + MacSize
|
// expecting the outSize to be input size + MacSize
|
||||||
if (outSize != plaintext.Count + MacSizeBytes)
|
if (outSize != plaintext.Count + MacSizeBytes)
|
||||||
{
|
|
||||||
throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}");
|
throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}");
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
// Resize the static buffer to fit
|
// Resize the static buffer to fit
|
||||||
EnsureSize(ref _tmpCryptBuffer, outSize);
|
byte[] cryptBuffer = TMPCryptBuffer.Value;
|
||||||
|
EnsureSize(ref cryptBuffer, outSize);
|
||||||
|
TMPCryptBuffer.Value = cryptBuffer;
|
||||||
|
|
||||||
int resultLen;
|
int resultLen;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Run the plain text through the cipher, ProcessBytes will only process full blocks
|
// Run the plain text through the cipher, ProcessBytes will only process full blocks
|
||||||
resultLen =
|
resultLen =
|
||||||
Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0);
|
Cipher.Value.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, cryptBuffer, 0);
|
||||||
// Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac)
|
// Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac)
|
||||||
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
|
resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen);
|
||||||
}
|
}
|
||||||
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
|
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
|
||||||
//
|
//
|
||||||
catch (Exception e)
|
catch (Exception e)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
|
error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
|
||||||
return new ArraySegment<byte>();
|
return new ArraySegment<byte>();
|
||||||
}
|
}
|
||||||
#if UNITY_EDITOR
|
#if UNITY_EDITOR
|
||||||
// expecting the result length to match the previously calculated input size + MacSize
|
// expecting the result length to match the previously calculated input size + MacSize
|
||||||
if (resultLen != outSize)
|
if (resultLen != outSize)
|
||||||
{
|
|
||||||
throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
|
throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
|
return new ArraySegment<byte>(cryptBuffer, 0, resultLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
|
ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
|
||||||
{
|
{
|
||||||
if (ciphertext.Count <= MacSizeBytes)
|
if (ciphertext.Count <= MacSizeBytes)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
|
error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
|
||||||
// Invalid
|
// Invalid
|
||||||
return new ArraySegment<byte>();
|
return new ArraySegment<byte>();
|
||||||
}
|
}
|
||||||
// Re-initialize the cipher with our cached parameters
|
// Re-initialize the cipher with our cached parameters
|
||||||
Cipher.Init(false, _cipherParametersDecrypt);
|
Cipher.Value.Init(false, cipherParametersDecrypt);
|
||||||
|
|
||||||
// Calculate the expected output size, this should always be input size - mac size
|
// Calculate the expected output size, this should always be input size - mac size
|
||||||
int outSize = Cipher.GetOutputSize(ciphertext.Count);
|
int outSize = Cipher.Value.GetOutputSize(ciphertext.Count);
|
||||||
#if UNITY_EDITOR
|
#if UNITY_EDITOR
|
||||||
// expecting the outSize to be input size - MacSize
|
// expecting the outSize to be input size - MacSize
|
||||||
if (outSize != ciphertext.Count - MacSizeBytes)
|
if (outSize != ciphertext.Count - MacSizeBytes)
|
||||||
{
|
|
||||||
throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}");
|
throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}");
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
// Resize the static buffer to fit
|
|
||||||
EnsureSize(ref _tmpCryptBuffer, outSize);
|
byte[] cryptBuffer = TMPCryptBuffer.Value;
|
||||||
|
EnsureSize(ref cryptBuffer, outSize);
|
||||||
|
TMPCryptBuffer.Value = cryptBuffer;
|
||||||
|
|
||||||
int resultLen;
|
int resultLen;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Run the ciphertext through the cipher, ProcessBytes will only process full blocks
|
// Run the ciphertext through the cipher, ProcessBytes will only process full blocks
|
||||||
resultLen =
|
resultLen =
|
||||||
Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0);
|
Cipher.Value.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, cryptBuffer, 0);
|
||||||
// Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac)
|
// Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac)
|
||||||
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
|
resultLen += Cipher.Value.DoFinal(cryptBuffer, resultLen);
|
||||||
}
|
}
|
||||||
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
|
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
|
||||||
catch (Exception e)
|
catch (Exception e)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
|
error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
|
||||||
return new ArraySegment<byte>();
|
return new ArraySegment<byte>();
|
||||||
}
|
}
|
||||||
#if UNITY_EDITOR
|
#if UNITY_EDITOR
|
||||||
// expecting the result length to match the previously calculated input size + MacSize
|
// expecting the result length to match the previously calculated input size + MacSize
|
||||||
if (resultLen != outSize)
|
if (resultLen != outSize)
|
||||||
{
|
|
||||||
throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
|
throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
|
return new ArraySegment<byte>(cryptBuffer, 0, resultLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void UpdateNonce()
|
void UpdateNonce()
|
||||||
{
|
{
|
||||||
// increment the nonce by one
|
// increment the nonce by one
|
||||||
// we need to ensure the nonce is *always* unique and not reused
|
// we need to ensure the nonce is *always* unique and not reused
|
||||||
// easiest way to do this is by simply incrementing it
|
// easiest way to do this is by simply incrementing it
|
||||||
for (int i = 0; i < NonceSize; i++)
|
for (int i = 0; i < NonceSize; i++)
|
||||||
{
|
{
|
||||||
_nonce[i]++;
|
nonce[i]++;
|
||||||
if (_nonce[i] != 0)
|
if (nonce[i] != 0)
|
||||||
{
|
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void EnsureSize(ref byte[] buffer, int size)
|
static void EnsureSize(ref byte[] buffer, int size)
|
||||||
{
|
{
|
||||||
if (buffer.Length < size)
|
if (buffer.Length < size)
|
||||||
{
|
|
||||||
// double buffer to avoid constantly resizing by a few bytes
|
// double buffer to avoid constantly resizing by a few bytes
|
||||||
Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2));
|
Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void SendHandshakeAndPubKey(OpCodes opcode)
|
void SendHandshakeAndPubKey(OpCodes opcode)
|
||||||
{
|
{
|
||||||
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
|
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
|
||||||
{
|
{
|
||||||
writer.WriteByte((byte)opcode);
|
writer.WriteByte((byte)opcode);
|
||||||
if (opcode == OpCodes.HandshakeAck)
|
if (opcode == OpCodes.HandshakeAck)
|
||||||
{
|
writer.WriteBytes(hkdfSalt, 0, HkdfSaltSize);
|
||||||
writer.WriteBytes(_hkdfSalt, 0, HkdfSaltSize);
|
writer.WriteBytes(credentials.PublicKeySerialized, 0, credentials.PublicKeySerialized.Length);
|
||||||
}
|
send(writer.ToArraySegment(), Channels.Unreliable);
|
||||||
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
|
|
||||||
_send(writer.ToArraySegment(), Channels.Unreliable);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void SendHandshakeFin()
|
void SendHandshakeFin()
|
||||||
{
|
{
|
||||||
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
|
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
|
||||||
{
|
{
|
||||||
writer.WriteByte((byte)OpCodes.HandshakeFin);
|
writer.WriteByte((byte)OpCodes.HandshakeFin);
|
||||||
_send(writer.ToArraySegment(), Channels.Unreliable);
|
send(writer.ToArraySegment(), Channels.Unreliable);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
||||||
{
|
{
|
||||||
AsymmetricKeyParameter remotePubKey;
|
AsymmetricKeyParameter remotePubKey;
|
||||||
try
|
try
|
||||||
@ -470,11 +441,11 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
|||||||
}
|
}
|
||||||
catch (Exception e)
|
catch (Exception e)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
|
error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_validateRemoteKey != null)
|
if (validateRemoteKey != null)
|
||||||
{
|
{
|
||||||
PubKeyInfo info = new PubKeyInfo
|
PubKeyInfo info = new PubKeyInfo
|
||||||
{
|
{
|
||||||
@ -482,9 +453,9 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
|||||||
Serialized = remotePubKeyRaw,
|
Serialized = remotePubKeyRaw,
|
||||||
Key = remotePubKey
|
Key = remotePubKey
|
||||||
};
|
};
|
||||||
if (!_validateRemoteKey(info))
|
if (!validateRemoteKey(info))
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
|
error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -493,7 +464,7 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
|||||||
// This gives us the same key on the other side, with our public key and their remote
|
// This gives us the same key on the other side, with our public key and their remote
|
||||||
// It's like magic, but with math!
|
// It's like magic, but with math!
|
||||||
ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
|
ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
|
||||||
ecdh.Init(_credentials.PrivateKey);
|
ecdh.Init(credentials.PrivateKey);
|
||||||
byte[] sharedSecret;
|
byte[] sharedSecret;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
@ -502,33 +473,33 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
|||||||
catch
|
catch
|
||||||
(Exception e)
|
(Exception e)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
|
error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (salt.Length != HkdfSaltSize)
|
if (salt.Length != HkdfSaltSize)
|
||||||
{
|
{
|
||||||
_error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
|
error(TransportError.Unexpected, $"Salt is expected to be {HkdfSaltSize} bytes long, got {salt.Length}.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Hkdf.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo));
|
Hkdf.Value.Init(new HkdfParameters(sharedSecret, salt, HkdfInfo));
|
||||||
|
|
||||||
// Allocate a buffer for the output key
|
// Allocate a buffer for the output key
|
||||||
byte[] keyRaw = new byte[KeyLength];
|
byte[] keyRaw = new byte[KeyLength];
|
||||||
|
|
||||||
// Generate the output keying material
|
// Generate the output keying material
|
||||||
Hkdf.GenerateBytes(keyRaw, 0, keyRaw.Length);
|
Hkdf.Value.GenerateBytes(keyRaw, 0, keyRaw.Length);
|
||||||
|
|
||||||
KeyParameter key = new KeyParameter(keyRaw);
|
KeyParameter key = new KeyParameter(keyRaw);
|
||||||
|
|
||||||
// generate a starting nonce
|
// generate a starting nonce
|
||||||
_nonce = GenerateSecureBytes(NonceSize);
|
nonce = GenerateSecureBytes(NonceSize);
|
||||||
|
|
||||||
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
|
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
|
||||||
// instead of creating a new one each encrypt/decrypt
|
// instead of creating a new one each encrypt/decrypt
|
||||||
_cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce);
|
cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, nonce);
|
||||||
_cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce);
|
cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce.Value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -537,53 +508,41 @@ private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw, byte[] salt)
|
|||||||
public void TickNonReady(double time)
|
public void TickNonReady(double time)
|
||||||
{
|
{
|
||||||
if (IsReady)
|
if (IsReady)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Timeout reset
|
||||||
|
if (handshakeTimeout == 0)
|
||||||
|
handshakeTimeout = time + DurationTimeout;
|
||||||
|
else if (time > handshakeTimeout)
|
||||||
{
|
{
|
||||||
|
error?.Invoke(TransportError.Timeout, $"Timed out during {state}, this probably just means the other side went away which is fine.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout reset
|
// Timeout reset
|
||||||
if (_handshakeTimeout == 0)
|
if (nextHandshakeResend < 0)
|
||||||
{
|
{
|
||||||
_handshakeTimeout = time + DurationTimeout;
|
nextHandshakeResend = time + DurationResend;
|
||||||
}
|
|
||||||
else if (time > _handshakeTimeout)
|
|
||||||
{
|
|
||||||
_error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine.");
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout reset
|
if (time < nextHandshakeResend)
|
||||||
if (_nextHandshakeResend < 0)
|
|
||||||
{
|
|
||||||
_nextHandshakeResend = time + DurationResend;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (time < _nextHandshakeResend)
|
|
||||||
{
|
|
||||||
// Resend isn't due yet
|
// Resend isn't due yet
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
_nextHandshakeResend = time + DurationResend;
|
nextHandshakeResend = time + DurationResend;
|
||||||
switch (_state)
|
switch (state)
|
||||||
{
|
{
|
||||||
case State.WaitingHandshake:
|
case State.WaitingHandshake:
|
||||||
if (_sendsFirst)
|
if (sendsFirst)
|
||||||
{
|
|
||||||
SendHandshakeAndPubKey(OpCodes.HandshakeStart);
|
SendHandshakeAndPubKey(OpCodes.HandshakeStart);
|
||||||
}
|
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case State.WaitingHandshakeReply:
|
case State.WaitingHandshakeReply:
|
||||||
if (_sendsFirst)
|
if (sendsFirst)
|
||||||
{
|
|
||||||
SendHandshakeFin();
|
SendHandshakeFin();
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
|
||||||
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
|
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
|
||||||
}
|
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case State.Ready: // IsReady is checked above & early-returned
|
case State.Ready: // IsReady is checked above & early-returned
|
||||||
|
@ -51,13 +51,11 @@ public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey)
|
|||||||
return publicKeyInfo.ToAsn1Object().GetDerEncoded();
|
return publicKeyInfo.ToAsn1Object().GetDerEncoded();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey)
|
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey) =>
|
||||||
{
|
|
||||||
// And then we do this to deserialize from the DER (from above)
|
// And then we do this to deserialize from the DER (from above)
|
||||||
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
|
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
|
||||||
// to a byte[] first and then shoved through a MemoryStream
|
// to a byte[] first and then shoved through a MemoryStream
|
||||||
return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
|
PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
|
||||||
}
|
|
||||||
|
|
||||||
public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
|
public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
|
||||||
{
|
{
|
||||||
@ -66,13 +64,11 @@ public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
|
|||||||
return privateKeyInfo.ToAsn1Object().GetDerEncoded();
|
return privateKeyInfo.ToAsn1Object().GetDerEncoded();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey)
|
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey) =>
|
||||||
{
|
|
||||||
// And then we do this to deserialize from the DER (from above)
|
// And then we do this to deserialize from the DER (from above)
|
||||||
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
|
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
|
||||||
// to a byte[] first and then shoved through a MemoryStream
|
// to a byte[] first and then shoved through a MemoryStream
|
||||||
return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
|
PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
|
||||||
}
|
|
||||||
|
|
||||||
public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes)
|
public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes)
|
||||||
{
|
{
|
||||||
@ -90,7 +86,7 @@ public void SaveToFile(string path)
|
|||||||
{
|
{
|
||||||
PublicKeyFingerprint = PublicKeyFingerprint,
|
PublicKeyFingerprint = PublicKeyFingerprint,
|
||||||
PublicKey = Convert.ToBase64String(PublicKeySerialized),
|
PublicKey = Convert.ToBase64String(PublicKeySerialized),
|
||||||
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)),
|
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey))
|
||||||
});
|
});
|
||||||
File.WriteAllText(path, json);
|
File.WriteAllText(path, json);
|
||||||
}
|
}
|
||||||
@ -104,9 +100,7 @@ public static EncryptionCredentials LoadFromFile(string path)
|
|||||||
byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey);
|
byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey);
|
||||||
|
|
||||||
if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes)))
|
if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes)))
|
||||||
{
|
|
||||||
throw new Exception("Saved public key fingerprint does not match public key.");
|
throw new Exception("Saved public key fingerprint does not match public key.");
|
||||||
}
|
|
||||||
return new EncryptionCredentials
|
return new EncryptionCredentials
|
||||||
{
|
{
|
||||||
PublicKeySerialized = publicKeyBytes,
|
PublicKeySerialized = publicKeyBytes,
|
||||||
@ -115,7 +109,7 @@ public static EncryptionCredentials LoadFromFile(string path)
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private class SerializedPair
|
class SerializedPair
|
||||||
{
|
{
|
||||||
public string PublicKeyFingerprint;
|
public string PublicKeyFingerprint;
|
||||||
public string PublicKey;
|
public string PublicKey;
|
||||||
|
@ -12,28 +12,27 @@ public class EncryptionTransport : Transport, PortTransport
|
|||||||
{
|
{
|
||||||
public override bool IsEncrypted => true;
|
public override bool IsEncrypted => true;
|
||||||
public override string EncryptionCipher => "AES256-GCM";
|
public override string EncryptionCipher => "AES256-GCM";
|
||||||
public Transport inner;
|
[FormerlySerializedAs("inner")]
|
||||||
|
public Transport Inner;
|
||||||
|
|
||||||
public ushort Port
|
public ushort Port
|
||||||
{
|
{
|
||||||
get
|
get
|
||||||
{
|
{
|
||||||
if (inner is PortTransport portTransport)
|
if (Inner is PortTransport portTransport)
|
||||||
{
|
|
||||||
return portTransport.Port;
|
return portTransport.Port;
|
||||||
}
|
|
||||||
|
|
||||||
Debug.LogError($"EncryptionTransport can't get Port because {inner} is not a PortTransport");
|
Debug.LogError($"EncryptionTransport can't get Port because {Inner} is not a PortTransport");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
set
|
set
|
||||||
{
|
{
|
||||||
if (inner is PortTransport portTransport)
|
if (Inner is PortTransport portTransport)
|
||||||
{
|
{
|
||||||
portTransport.Port = value;
|
portTransport.Port = value;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Debug.LogError($"EncryptionTransport can't set Port because {inner} is not a PortTransport");
|
Debug.LogError($"EncryptionTransport can't set Port because {Inner} is not a PortTransport");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,81 +40,78 @@ public enum ValidationMode
|
|||||||
{
|
{
|
||||||
Off,
|
Off,
|
||||||
List,
|
List,
|
||||||
Callback,
|
Callback
|
||||||
}
|
}
|
||||||
|
|
||||||
public ValidationMode clientValidateServerPubKey;
|
[FormerlySerializedAs("clientValidateServerPubKey")]
|
||||||
|
public ValidationMode ClientValidateServerPubKey;
|
||||||
|
[FormerlySerializedAs("clientTrustedPubKeySignatures")]
|
||||||
[Tooltip("List of public key fingerprints the client will accept")]
|
[Tooltip("List of public key fingerprints the client will accept")]
|
||||||
public string[] clientTrustedPubKeySignatures;
|
public string[] ClientTrustedPubKeySignatures;
|
||||||
public Func<PubKeyInfo, bool> onClientValidateServerPubKey;
|
public Func<PubKeyInfo, bool> OnClientValidateServerPubKey;
|
||||||
public bool serverLoadKeyPairFromFile;
|
[FormerlySerializedAs("serverLoadKeyPairFromFile")]
|
||||||
public string serverKeypairPath = "./server-keys.json";
|
public bool ServerLoadKeyPairFromFile;
|
||||||
|
[FormerlySerializedAs("serverKeypairPath")]
|
||||||
|
public string ServerKeypairPath = "./server-keys.json";
|
||||||
|
|
||||||
private EncryptedConnection _client;
|
EncryptedConnection client;
|
||||||
|
|
||||||
private Dictionary<int, EncryptedConnection> _serverConnections = new Dictionary<int, EncryptedConnection>();
|
readonly Dictionary<int, EncryptedConnection> serverConnections = new Dictionary<int, EncryptedConnection>();
|
||||||
|
|
||||||
private List<EncryptedConnection> _serverPendingConnections =
|
readonly List<EncryptedConnection> serverPendingConnections =
|
||||||
new List<EncryptedConnection>();
|
new List<EncryptedConnection>();
|
||||||
|
|
||||||
private EncryptionCredentials _credentials;
|
EncryptionCredentials credentials;
|
||||||
public string EncryptionPublicKeyFingerprint => _credentials?.PublicKeyFingerprint;
|
public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint;
|
||||||
public byte[] EncryptionPublicKey => _credentials?.PublicKeySerialized;
|
public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized;
|
||||||
|
|
||||||
private void ServerRemoveFromPending(EncryptedConnection con)
|
void ServerRemoveFromPending(EncryptedConnection con)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < _serverPendingConnections.Count; i++)
|
for (int i = 0; i < serverPendingConnections.Count; i++)
|
||||||
{
|
if (serverPendingConnections[i] == con)
|
||||||
if (_serverPendingConnections[i] == con)
|
|
||||||
{
|
{
|
||||||
// remove by swapping with last
|
// remove by swapping with last
|
||||||
int lastIndex = _serverPendingConnections.Count - 1;
|
int lastIndex = serverPendingConnections.Count - 1;
|
||||||
_serverPendingConnections[i] = _serverPendingConnections[lastIndex];
|
serverPendingConnections[i] = serverPendingConnections[lastIndex];
|
||||||
_serverPendingConnections.RemoveAt(lastIndex);
|
serverPendingConnections.RemoveAt(lastIndex);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void HandleInnerServerDisconnected(int connId)
|
void HandleInnerServerDisconnected(int connId)
|
||||||
{
|
{
|
||||||
if (_serverConnections.TryGetValue(connId, out EncryptedConnection con))
|
if (serverConnections.TryGetValue(connId, out EncryptedConnection con))
|
||||||
{
|
{
|
||||||
ServerRemoveFromPending(con);
|
ServerRemoveFromPending(con);
|
||||||
_serverConnections.Remove(connId);
|
serverConnections.Remove(connId);
|
||||||
}
|
}
|
||||||
OnServerDisconnected?.Invoke(connId);
|
OnServerDisconnected?.Invoke(connId);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void HandleInnerServerError(int connId, TransportError type, string msg)
|
void HandleInnerServerError(int connId, TransportError type, string msg) => OnServerError?.Invoke(connId, type, $"inner: {msg}");
|
||||||
{
|
|
||||||
OnServerError?.Invoke(connId, type, $"inner: {msg}");
|
|
||||||
}
|
|
||||||
|
|
||||||
private void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
|
void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
|
||||||
{
|
{
|
||||||
if (_serverConnections.TryGetValue(connId, out EncryptedConnection c))
|
if (serverConnections.TryGetValue(connId, out EncryptedConnection c))
|
||||||
{
|
|
||||||
c.OnReceiveRaw(data, channel);
|
c.OnReceiveRaw(data, channel);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, inner.ServerGetClientAddress(connId));
|
void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId));
|
||||||
|
|
||||||
private void HandleInnerServerConnected(int connId, string clientRemoteAddress)
|
void HandleInnerServerConnected(int connId, string clientRemoteAddress)
|
||||||
{
|
{
|
||||||
Debug.Log($"[EncryptionTransport] New connection #{connId}");
|
Debug.Log($"[EncryptionTransport] New connection #{connId} from {clientRemoteAddress}");
|
||||||
EncryptedConnection ec = null;
|
EncryptedConnection ec = null;
|
||||||
ec = new EncryptedConnection(
|
ec = new EncryptedConnection(
|
||||||
_credentials,
|
credentials,
|
||||||
false,
|
false,
|
||||||
(segment, channel) => inner.ServerSend(connId, segment, channel),
|
(segment, channel) => Inner.ServerSend(connId, segment, channel),
|
||||||
(segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel),
|
(segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel),
|
||||||
() =>
|
() =>
|
||||||
{
|
{
|
||||||
Debug.Log($"[EncryptionTransport] Connection #{connId} is ready");
|
Debug.Log($"[EncryptionTransport] Connection #{connId} is ready");
|
||||||
|
// ReSharper disable once AccessToModifiedClosure
|
||||||
ServerRemoveFromPending(ec);
|
ServerRemoveFromPending(ec);
|
||||||
//OnServerConnected?.Invoke(connId);
|
|
||||||
OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress);
|
OnServerConnectedWithAddress?.Invoke(connId, clientRemoteAddress);
|
||||||
},
|
},
|
||||||
(type, msg) =>
|
(type, msg) =>
|
||||||
@ -123,32 +119,25 @@ private void HandleInnerServerConnected(int connId, string clientRemoteAddress)
|
|||||||
OnServerError?.Invoke(connId, type, msg);
|
OnServerError?.Invoke(connId, type, msg);
|
||||||
ServerDisconnect(connId);
|
ServerDisconnect(connId);
|
||||||
});
|
});
|
||||||
_serverConnections.Add(connId, ec);
|
serverConnections.Add(connId, ec);
|
||||||
_serverPendingConnections.Add(ec);
|
serverPendingConnections.Add(ec);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void HandleInnerClientDisconnected()
|
void HandleInnerClientDisconnected()
|
||||||
{
|
{
|
||||||
_client = null;
|
client = null;
|
||||||
OnClientDisconnected?.Invoke();
|
OnClientDisconnected?.Invoke();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void HandleInnerClientError(TransportError arg1, string arg2)
|
void HandleInnerClientError(TransportError arg1, string arg2) => OnClientError?.Invoke(arg1, $"inner: {arg2}");
|
||||||
{
|
|
||||||
OnClientError?.Invoke(arg1, $"inner: {arg2}");
|
|
||||||
}
|
|
||||||
|
|
||||||
private void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel)
|
void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) => client?.OnReceiveRaw(data, channel);
|
||||||
{
|
|
||||||
_client?.OnReceiveRaw(data, channel);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void HandleInnerClientConnected()
|
void HandleInnerClientConnected() =>
|
||||||
{
|
client = new EncryptedConnection(
|
||||||
_client = new EncryptedConnection(
|
credentials,
|
||||||
_credentials,
|
|
||||||
true,
|
true,
|
||||||
(segment, channel) => inner.ClientSend(segment, channel),
|
(segment, channel) => Inner.ClientSend(segment, channel),
|
||||||
(segment, channel) => OnClientDataReceived?.Invoke(segment, channel),
|
(segment, channel) => OnClientDataReceived?.Invoke(segment, channel),
|
||||||
() =>
|
() =>
|
||||||
{
|
{
|
||||||
@ -160,25 +149,23 @@ private void HandleInnerClientConnected()
|
|||||||
ClientDisconnect();
|
ClientDisconnect();
|
||||||
},
|
},
|
||||||
HandleClientValidateServerPubKey);
|
HandleClientValidateServerPubKey);
|
||||||
}
|
|
||||||
|
|
||||||
private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
|
bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
|
||||||
{
|
{
|
||||||
switch (clientValidateServerPubKey)
|
switch (ClientValidateServerPubKey)
|
||||||
{
|
{
|
||||||
case ValidationMode.Off:
|
case ValidationMode.Off:
|
||||||
return true;
|
return true;
|
||||||
case ValidationMode.List:
|
case ValidationMode.List:
|
||||||
return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
|
return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
|
||||||
case ValidationMode.Callback:
|
case ValidationMode.Callback:
|
||||||
return onClientValidateServerPubKey(pubKeyInfo);
|
return OnClientValidateServerPubKey(pubKeyInfo);
|
||||||
default:
|
default:
|
||||||
throw new ArgumentOutOfRangeException();
|
throw new ArgumentOutOfRangeException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Awake()
|
void Awake() =>
|
||||||
{
|
|
||||||
// check if encryption via hardware acceleration is supported.
|
// check if encryption via hardware acceleration is supported.
|
||||||
// this can be useful to know for low end devices.
|
// this can be useful to know for low end devices.
|
||||||
//
|
//
|
||||||
@ -188,27 +175,26 @@ void Awake()
|
|||||||
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
|
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
|
||||||
// which Unity does not support yet.
|
// which Unity does not support yet.
|
||||||
Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
|
Debug.Log($"EncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
|
||||||
}
|
|
||||||
|
|
||||||
public override bool Available() => inner.Available();
|
public override bool Available() => Inner.Available();
|
||||||
|
|
||||||
public override bool ClientConnected() => _client != null && _client.IsReady;
|
public override bool ClientConnected() => client != null && client.IsReady;
|
||||||
|
|
||||||
public override void ClientConnect(string address)
|
public override void ClientConnect(string address)
|
||||||
{
|
{
|
||||||
switch (clientValidateServerPubKey)
|
switch (ClientValidateServerPubKey)
|
||||||
{
|
{
|
||||||
case ValidationMode.Off:
|
case ValidationMode.Off:
|
||||||
break;
|
break;
|
||||||
case ValidationMode.List:
|
case ValidationMode.List:
|
||||||
if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0)
|
if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0)
|
||||||
{
|
{
|
||||||
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
|
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case ValidationMode.Callback:
|
case ValidationMode.Callback:
|
||||||
if (onClientValidateServerPubKey == null)
|
if (OnClientValidateServerPubKey == null)
|
||||||
{
|
{
|
||||||
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
|
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
|
||||||
return;
|
return;
|
||||||
@ -217,95 +203,79 @@ public override void ClientConnect(string address)
|
|||||||
default:
|
default:
|
||||||
throw new ArgumentOutOfRangeException();
|
throw new ArgumentOutOfRangeException();
|
||||||
}
|
}
|
||||||
_credentials = EncryptionCredentials.Generate();
|
credentials = EncryptionCredentials.Generate();
|
||||||
inner.OnClientConnected = HandleInnerClientConnected;
|
Inner.OnClientConnected = HandleInnerClientConnected;
|
||||||
inner.OnClientDataReceived = HandleInnerClientDataReceived;
|
Inner.OnClientDataReceived = HandleInnerClientDataReceived;
|
||||||
inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
|
Inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
|
||||||
inner.OnClientError = HandleInnerClientError;
|
Inner.OnClientError = HandleInnerClientError;
|
||||||
inner.OnClientDisconnected = HandleInnerClientDisconnected;
|
Inner.OnClientDisconnected = HandleInnerClientDisconnected;
|
||||||
inner.ClientConnect(address);
|
Inner.ClientConnect(address);
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) =>
|
public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) =>
|
||||||
_client?.Send(segment, channelId);
|
client?.Send(segment, channelId);
|
||||||
|
|
||||||
public override void ClientDisconnect() => inner.ClientDisconnect();
|
public override void ClientDisconnect() => Inner.ClientDisconnect();
|
||||||
|
|
||||||
public override Uri ServerUri() => inner.ServerUri();
|
public override Uri ServerUri() => Inner.ServerUri();
|
||||||
|
|
||||||
public override bool ServerActive() => inner.ServerActive();
|
public override bool ServerActive() => Inner.ServerActive();
|
||||||
|
|
||||||
public override void ServerStart()
|
public override void ServerStart()
|
||||||
{
|
{
|
||||||
if (serverLoadKeyPairFromFile)
|
if (ServerLoadKeyPairFromFile)
|
||||||
{
|
credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath);
|
||||||
_credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
credentials = EncryptionCredentials.Generate();
|
||||||
_credentials = EncryptionCredentials.Generate();
|
|
||||||
}
|
|
||||||
#pragma warning disable CS0618 // Type or member is obsolete
|
#pragma warning disable CS0618 // Type or member is obsolete
|
||||||
inner.OnServerConnected = HandleInnerServerConnected;
|
Inner.OnServerConnected = HandleInnerServerConnected;
|
||||||
#pragma warning restore CS0618 // Type or member is obsolete
|
#pragma warning restore CS0618 // Type or member is obsolete
|
||||||
inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
|
Inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
|
||||||
inner.OnServerDataReceived = HandleInnerServerDataReceived;
|
Inner.OnServerDataReceived = HandleInnerServerDataReceived;
|
||||||
inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
|
Inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
|
||||||
inner.OnServerError = HandleInnerServerError;
|
Inner.OnServerError = HandleInnerServerError;
|
||||||
inner.OnServerDisconnected = HandleInnerServerDisconnected;
|
Inner.OnServerDisconnected = HandleInnerServerDisconnected;
|
||||||
inner.ServerStart();
|
Inner.ServerStart();
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable)
|
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable)
|
||||||
{
|
{
|
||||||
if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
|
if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
|
||||||
{
|
|
||||||
connection.Send(segment, channelId);
|
connection.Send(segment, channelId);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void ServerDisconnect(int connectionId)
|
public override void ServerDisconnect(int connectionId) =>
|
||||||
{
|
|
||||||
// cleanup is done via inners disconnect event
|
// cleanup is done via inners disconnect event
|
||||||
inner.ServerDisconnect(connectionId);
|
Inner.ServerDisconnect(connectionId);
|
||||||
}
|
|
||||||
|
|
||||||
public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId);
|
public override string ServerGetClientAddress(int connectionId) => Inner.ServerGetClientAddress(connectionId);
|
||||||
|
|
||||||
public override void ServerStop() => inner.ServerStop();
|
public override void ServerStop() => Inner.ServerStop();
|
||||||
|
|
||||||
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
|
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
|
||||||
inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
|
Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
|
||||||
|
|
||||||
public override void Shutdown() => inner.Shutdown();
|
public override void Shutdown() => Inner.Shutdown();
|
||||||
|
|
||||||
public override void ClientEarlyUpdate()
|
public override void ClientEarlyUpdate() => Inner.ClientEarlyUpdate();
|
||||||
{
|
|
||||||
inner.ClientEarlyUpdate();
|
|
||||||
}
|
|
||||||
|
|
||||||
public override void ClientLateUpdate()
|
public override void ClientLateUpdate()
|
||||||
{
|
{
|
||||||
inner.ClientLateUpdate();
|
Inner.ClientLateUpdate();
|
||||||
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
|
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
|
||||||
_client?.TickNonReady(NetworkTime.localTime);
|
client?.TickNonReady(NetworkTime.localTime);
|
||||||
Profiler.EndSample();
|
Profiler.EndSample();
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void ServerEarlyUpdate()
|
public override void ServerEarlyUpdate() => Inner.ServerEarlyUpdate();
|
||||||
{
|
|
||||||
inner.ServerEarlyUpdate();
|
|
||||||
}
|
|
||||||
|
|
||||||
public override void ServerLateUpdate()
|
public override void ServerLateUpdate()
|
||||||
{
|
{
|
||||||
inner.ServerLateUpdate();
|
Inner.ServerLateUpdate();
|
||||||
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
|
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
|
||||||
// Reverse iteration as entries can be removed while updating
|
// Reverse iteration as entries can be removed while updating
|
||||||
for (int i = _serverPendingConnections.Count - 1; i >= 0; i--)
|
for (int i = serverPendingConnections.Count - 1; i >= 0; i--)
|
||||||
{
|
serverPendingConnections[i].TickNonReady(NetworkTime.time);
|
||||||
_serverPendingConnections[i].TickNonReady(NetworkTime.time);
|
|
||||||
}
|
|
||||||
Profiler.EndSample();
|
Profiler.EndSample();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,314 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Net;
|
||||||
|
using Mirror.BouncyCastle.Crypto;
|
||||||
|
using UnityEngine;
|
||||||
|
using UnityEngine.Profiling;
|
||||||
|
using UnityEngine.Serialization;
|
||||||
|
using Debug = UnityEngine.Debug;
|
||||||
|
|
||||||
|
namespace Mirror.Transports.Encryption
|
||||||
|
{
|
||||||
|
[HelpURL("https://mirror-networking.gitbook.io/docs/manual/transports/encryption-transport")]
|
||||||
|
public class ThreadedEncryptionTransport : ThreadedTransport, PortTransport
|
||||||
|
{
|
||||||
|
public override bool IsEncrypted => true;
|
||||||
|
public override string EncryptionCipher => "AES256-GCM";
|
||||||
|
[FormerlySerializedAs("inner")]
|
||||||
|
public ThreadedTransport Inner;
|
||||||
|
|
||||||
|
public ushort Port
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
if (Inner is PortTransport portTransport)
|
||||||
|
return portTransport.Port;
|
||||||
|
|
||||||
|
Debug.LogError($"ThreadedEncryptionTransport can't get Port because {Inner} is not a PortTransport");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
set
|
||||||
|
{
|
||||||
|
if (Inner is PortTransport portTransport)
|
||||||
|
{
|
||||||
|
portTransport.Port = value;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Debug.LogError($"ThreadedEncryptionTransport can't set Port because {Inner} is not a PortTransport");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum ValidationMode
|
||||||
|
{
|
||||||
|
Off,
|
||||||
|
List,
|
||||||
|
Callback
|
||||||
|
}
|
||||||
|
|
||||||
|
[FormerlySerializedAs("clientValidateServerPubKey")]
|
||||||
|
public ValidationMode ClientValidateServerPubKey;
|
||||||
|
[FormerlySerializedAs("clientTrustedPubKeySignatures")]
|
||||||
|
[Tooltip("List of public key fingerprints the client will accept")]
|
||||||
|
public string[] ClientTrustedPubKeySignatures;
|
||||||
|
/// <summary>
|
||||||
|
/// Called when a client connects to a server
|
||||||
|
/// ATTENTION: NOT THREAD SAFE.
|
||||||
|
/// This will be called on the worker thread.
|
||||||
|
/// </summary>
|
||||||
|
public Func<PubKeyInfo, bool> OnClientValidateServerPubKey;
|
||||||
|
[FormerlySerializedAs("serverLoadKeyPairFromFile")]
|
||||||
|
public bool ServerLoadKeyPairFromFile;
|
||||||
|
[FormerlySerializedAs("serverKeypairPath")]
|
||||||
|
public string ServerKeypairPath = "./server-keys.json";
|
||||||
|
|
||||||
|
EncryptedConnection client;
|
||||||
|
|
||||||
|
readonly Dictionary<int, EncryptedConnection> serverConnections = new Dictionary<int, EncryptedConnection>();
|
||||||
|
|
||||||
|
readonly List<EncryptedConnection> serverPendingConnections =
|
||||||
|
new List<EncryptedConnection>();
|
||||||
|
|
||||||
|
EncryptionCredentials credentials;
|
||||||
|
public string EncryptionPublicKeyFingerprint => credentials?.PublicKeyFingerprint;
|
||||||
|
public byte[] EncryptionPublicKey => credentials?.PublicKeySerialized;
|
||||||
|
|
||||||
|
// Used for threaded time keeping as unitys Time.time is not thread safe
|
||||||
|
Stopwatch stopwatch = Stopwatch.StartNew();
|
||||||
|
|
||||||
|
void ServerRemoveFromPending(EncryptedConnection con)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < serverPendingConnections.Count; i++)
|
||||||
|
if (serverPendingConnections[i] == con)
|
||||||
|
{
|
||||||
|
// remove by swapping with last
|
||||||
|
int lastIndex = serverPendingConnections.Count - 1;
|
||||||
|
serverPendingConnections[i] = serverPendingConnections[lastIndex];
|
||||||
|
serverPendingConnections.RemoveAt(lastIndex);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void HandleInnerServerDisconnected(int connId)
|
||||||
|
{
|
||||||
|
if (serverConnections.TryGetValue(connId, out EncryptedConnection con))
|
||||||
|
{
|
||||||
|
ServerRemoveFromPending(con);
|
||||||
|
serverConnections.Remove(connId);
|
||||||
|
}
|
||||||
|
OnThreadedServerDisconnected(connId);
|
||||||
|
}
|
||||||
|
|
||||||
|
void HandleInnerServerError(int connId, TransportError type, string msg) => OnThreadedServerError(connId, type, $"inner: {msg}");
|
||||||
|
|
||||||
|
void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
|
||||||
|
{
|
||||||
|
if (serverConnections.TryGetValue(connId, out EncryptedConnection c))
|
||||||
|
c.OnReceiveRaw(data, channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void HandleInnerServerConnected(int connId) => HandleInnerServerConnected(connId, Inner.ServerGetClientAddress(connId));
|
||||||
|
|
||||||
|
void HandleInnerServerConnected(int connId, string clientRemoteAddress)
|
||||||
|
{
|
||||||
|
Debug.Log($"[ThreadedEncryptionTransport] New connection #{connId} from {clientRemoteAddress}");
|
||||||
|
EncryptedConnection ec = null;
|
||||||
|
ec = new EncryptedConnection(
|
||||||
|
credentials,
|
||||||
|
false,
|
||||||
|
(segment, channel) => Inner.ServerSend(connId, segment, channel),
|
||||||
|
(segment, channel) => OnThreadedServerReceive(connId, segment, channel),
|
||||||
|
() =>
|
||||||
|
{
|
||||||
|
Debug.Log($"[ThreadedEncryptionTransport] Connection #{connId} is ready");
|
||||||
|
// ReSharper disable once AccessToModifiedClosure
|
||||||
|
ServerRemoveFromPending(ec);
|
||||||
|
OnThreadedServerConnected(connId, new IPEndPoint(IPAddress.Parse(clientRemoteAddress), 0));
|
||||||
|
},
|
||||||
|
(type, msg) =>
|
||||||
|
{
|
||||||
|
OnThreadedServerError(connId, type, msg);
|
||||||
|
ServerDisconnect(connId);
|
||||||
|
});
|
||||||
|
serverConnections.Add(connId, ec);
|
||||||
|
serverPendingConnections.Add(ec);
|
||||||
|
}
|
||||||
|
|
||||||
|
void HandleInnerClientDisconnected()
|
||||||
|
{
|
||||||
|
client = null;
|
||||||
|
OnThreadedClientDisconnected();
|
||||||
|
}
|
||||||
|
|
||||||
|
void HandleInnerClientError(TransportError arg1, string arg2) => OnThreadedClientError(arg1, $"inner: {arg2}");
|
||||||
|
|
||||||
|
void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel) => client?.OnReceiveRaw(data, channel);
|
||||||
|
|
||||||
|
void HandleInnerClientConnected() =>
|
||||||
|
client = new EncryptedConnection(
|
||||||
|
credentials,
|
||||||
|
true,
|
||||||
|
(segment, channel) => Inner.ClientSend(segment, channel),
|
||||||
|
(segment, channel) => OnThreadedClientReceive(segment, channel),
|
||||||
|
() =>
|
||||||
|
{
|
||||||
|
OnThreadedClientConnected();
|
||||||
|
},
|
||||||
|
(type, msg) =>
|
||||||
|
{
|
||||||
|
OnThreadedClientError(type, msg);
|
||||||
|
ClientDisconnect();
|
||||||
|
},
|
||||||
|
HandleClientValidateServerPubKey);
|
||||||
|
|
||||||
|
bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
|
||||||
|
{
|
||||||
|
switch (ClientValidateServerPubKey)
|
||||||
|
{
|
||||||
|
case ValidationMode.Off:
|
||||||
|
return true;
|
||||||
|
case ValidationMode.List:
|
||||||
|
return Array.IndexOf(ClientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
|
||||||
|
case ValidationMode.Callback:
|
||||||
|
return OnClientValidateServerPubKey(pubKeyInfo);
|
||||||
|
default:
|
||||||
|
throw new ArgumentOutOfRangeException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void Awake()
|
||||||
|
{
|
||||||
|
base.Awake();
|
||||||
|
// check if encryption via hardware acceleration is supported.
|
||||||
|
// this can be useful to know for low end devices.
|
||||||
|
//
|
||||||
|
// hardware acceleration requires netcoreapp3.0 or later:
|
||||||
|
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/AesUtilities.cs#L18
|
||||||
|
// because AesEngine_x86 requires System.Runtime.Intrinsics.X86:
|
||||||
|
// https://github.com/bcgit/bc-csharp/blob/449940429c57686a6fcf6bfbb4d368dec19d906e/crypto/src/crypto/engines/AesEngine_X86.cs
|
||||||
|
// which Unity does not support yet.
|
||||||
|
Debug.Log($"ThreadedEncryptionTransport: IsHardwareAccelerated={AesUtilities.IsHardwareAccelerated}");
|
||||||
|
}
|
||||||
|
|
||||||
|
public override bool Available() => Inner.Available();
|
||||||
|
|
||||||
|
protected override void ThreadedClientConnect(string address)
|
||||||
|
{
|
||||||
|
switch (ClientValidateServerPubKey)
|
||||||
|
{
|
||||||
|
case ValidationMode.Off:
|
||||||
|
break;
|
||||||
|
case ValidationMode.List:
|
||||||
|
if (ClientTrustedPubKeySignatures == null || ClientTrustedPubKeySignatures.Length == 0)
|
||||||
|
{
|
||||||
|
OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case ValidationMode.Callback:
|
||||||
|
if (OnClientValidateServerPubKey == null)
|
||||||
|
{
|
||||||
|
OnThreadedClientError(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new ArgumentOutOfRangeException();
|
||||||
|
}
|
||||||
|
credentials = EncryptionCredentials.Generate();
|
||||||
|
Inner.OnClientConnected = HandleInnerClientConnected;
|
||||||
|
Inner.OnClientDataReceived = HandleInnerClientDataReceived;
|
||||||
|
Inner.OnClientDataSent = (bytes, channel) => OnThreadedClientSend(bytes, channel);
|
||||||
|
Inner.OnClientError = HandleInnerClientError;
|
||||||
|
Inner.OnClientDisconnected = HandleInnerClientDisconnected;
|
||||||
|
Inner.ClientConnect(address);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedClientConnect(Uri address) => Inner.ClientConnect(address);
|
||||||
|
|
||||||
|
protected override void ThreadedClientSend(ArraySegment<byte> segment, int channelId) =>
|
||||||
|
client?.Send(segment, channelId);
|
||||||
|
|
||||||
|
protected override void ThreadedClientDisconnect() => Inner.ClientDisconnect();
|
||||||
|
|
||||||
|
protected override void ThreadedServerStart()
|
||||||
|
{
|
||||||
|
if (ServerLoadKeyPairFromFile)
|
||||||
|
credentials = EncryptionCredentials.LoadFromFile(ServerKeypairPath);
|
||||||
|
else
|
||||||
|
credentials = EncryptionCredentials.Generate();
|
||||||
|
#pragma warning disable CS0618 // Type or member is obsolete
|
||||||
|
Inner.OnServerConnected = HandleInnerServerConnected;
|
||||||
|
#pragma warning restore CS0618 // Type or member is obsolete
|
||||||
|
Inner.OnServerConnectedWithAddress = HandleInnerServerConnected;
|
||||||
|
Inner.OnServerDataReceived = HandleInnerServerDataReceived;
|
||||||
|
Inner.OnServerDataSent = (connId, bytes, channel) => OnThreadedServerSend(connId, bytes, channel);
|
||||||
|
Inner.OnServerError = HandleInnerServerError;
|
||||||
|
Inner.OnServerDisconnected = HandleInnerServerDisconnected;
|
||||||
|
Inner.ServerStart();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedServerSend(int connectionId, ArraySegment<byte> segment, int channelId)
|
||||||
|
{
|
||||||
|
if (serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
|
||||||
|
connection.Send(segment, channelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedServerDisconnect(int connectionId) =>
|
||||||
|
// cleanup is done via inners disconnect event
|
||||||
|
Inner.ServerDisconnect(connectionId);
|
||||||
|
|
||||||
|
protected override void ThreadedClientEarlyUpdate() {}
|
||||||
|
|
||||||
|
protected override void ThreadedServerStop() => Inner.ServerStop();
|
||||||
|
|
||||||
|
public override Uri ServerUri() => Inner.ServerUri();
|
||||||
|
|
||||||
|
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
|
||||||
|
Inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
|
||||||
|
|
||||||
|
protected override void ThreadedShutdown() => Inner.Shutdown();
|
||||||
|
|
||||||
|
public override void ClientEarlyUpdate()
|
||||||
|
{
|
||||||
|
base.ClientEarlyUpdate();
|
||||||
|
Inner.ClientEarlyUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void ClientLateUpdate()
|
||||||
|
{
|
||||||
|
base.ClientLateUpdate();
|
||||||
|
Inner.ClientLateUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedClientLateUpdate()
|
||||||
|
{
|
||||||
|
Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate");
|
||||||
|
client?.TickNonReady(stopwatch.Elapsed.TotalSeconds);
|
||||||
|
Profiler.EndSample();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedServerEarlyUpdate() {}
|
||||||
|
|
||||||
|
public override void ServerEarlyUpdate()
|
||||||
|
{
|
||||||
|
base.ServerEarlyUpdate();
|
||||||
|
Inner.ServerEarlyUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void ServerLateUpdate()
|
||||||
|
{
|
||||||
|
base.ServerLateUpdate();
|
||||||
|
Inner.ServerLateUpdate();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override void ThreadedServerLateUpdate()
|
||||||
|
{
|
||||||
|
Profiler.BeginSample("ThreadedEncryptionTransport.ServerLateUpdate");
|
||||||
|
// Reverse iteration as entries can be removed while updating
|
||||||
|
for (int i = serverPendingConnections.Count - 1; i >= 0; i--)
|
||||||
|
serverPendingConnections[i].TickNonReady(stopwatch.Elapsed.TotalSeconds);
|
||||||
|
Profiler.EndSample();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 5d3e310924fb49c195391b9699f20809
|
||||||
|
MonoImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
serializedVersion: 2
|
||||||
|
defaultReferences: []
|
||||||
|
executionOrder: 0
|
||||||
|
icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
Loading…
Reference in New Issue
Block a user