feat: EncryptionTransport (#3768)

* initial working transport

* code cleanup & transport wrap tests

* better connection tests

* Handle bouncycastle exceptions

* clean up usings

* Mirror icon :)

* list to allow for removing entries during loop

* Profiler sampling

* Unity 2019 compat

* code style

* pubkey validation

* use builtin aes engine selector

this is overly optimistic, as the hardware accelerated engine is only available on .net core 3 or higher

* Older unity version fix
This commit is contained in:
Robin Rolf 2024-03-05 13:51:34 +00:00 committed by MrGadget
parent f54b281af7
commit d929a9a9ea
25 changed files with 1949 additions and 1 deletions

View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 31ff83bf6d2e72542adcbe2c21383f4a
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,33 @@
fileFormatVersion: 2
guid: a67bf078294b36e4686b9912bf172010
PluginImporter:
externalObjects: {}
serializedVersion: 2
iconMap: {}
executionOrder: {}
defineConstraints: []
isPreloaded: 0
isOverridable: 0
isExplicitlyReferenced: 0
validateReferences: 1
platformData:
- first:
Any:
second:
enabled: 1
settings: {}
- first:
Editor: Editor
second:
enabled: 0
settings:
DefaultValueInitialized: true
- first:
Windows Store Apps: WindowsStoreApps
second:
enabled: 0
settings:
CPU: AnyCPU
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,13 @@
Copyright (c) 2000-2024 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org).
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sub license, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions: The above copyright notice and this
permission notice shall be included in all copies or substantial portions of the Software.
**THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.**

View File

@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 2b45a99b5583cda419e1f1ec943fec4b
TextScriptImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -24,7 +24,8 @@
"Castle.Core.dll",
"System.Threading.Tasks.Extensions.dll",
"Mono.CecilX.dll",
"nunit.framework.dll"
"nunit.framework.dll",
"BouncyCastle.Cryptography.dll"
],
"autoReferenced": false,
"defineConstraints": [

View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: d5266c80d88c1ca4cb68cf0551780c3f
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,549 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Mirror.Tests.NetworkServers;
using Mirror.Transports.Encryption;
using NUnit.Framework;
namespace Mirror.Tests.Transports
{
public class EncryptionTransportConnectionTest
{
struct Data
{
public byte[] data;
public int channel;
}
private EncryptedConnection server;
private EncryptionCredentials serverCreds;
Queue<Data> serverRecv = new Queue<Data>();
private Action serverReady;
private Action<ArraySegment<byte>, int> serverReceive;
private Func<ArraySegment<byte>, int, bool> shouldServerSend;
private Func<PubKeyInfo, bool> serverValidateKey;
private EncryptedConnection client;
private EncryptionCredentials clientCreds;
Queue<Data> clientRecv = new Queue<Data>();
private Action clientReady;
private Action<ArraySegment<byte>, int> clientReceive;
private Func<ArraySegment<byte>, int, bool> shouldClientSend;
private Func<PubKeyInfo, bool> clientValidateKey;
private double _time;
private double _timestep = 0.05;
class ErrorException : Exception
{
public ErrorException(string msg) : base(msg) {}
}
[SetUp]
public void Setup()
{
serverReady = null;
serverReceive = null;
shouldServerSend = null;
serverValidateKey = null;
clientReady = null;
clientReceive = null;
shouldClientSend = null;
clientValidateKey = null;
clientRecv.Clear();
serverRecv.Clear();
serverCreds = EncryptionCredentials.Generate();
server = new EncryptedConnection(serverCreds, false,
(bytes, channel) =>
{
if (shouldServerSend == null || shouldServerSend(bytes, channel))
clientRecv.Enqueue(new Data
{
data = bytes.ToArray(), channel = channel
});
},
(bytes, channel) =>
{
serverReceive?.Invoke(bytes, channel);
},
() => { serverReady?.Invoke(); },
(error, s) => throw new ErrorException($"{error}: {s}"),
info =>
{
if (serverValidateKey != null) return serverValidateKey(info);
return true;
});
clientCreds = EncryptionCredentials.Generate();
client = new EncryptedConnection(clientCreds, true,
(bytes, channel) =>
{
if (shouldClientSend == null || shouldClientSend(bytes, channel))
serverRecv.Enqueue(new Data
{
data = bytes.ToArray(), channel = channel
});
},
(bytes, channel) =>
{
clientReceive?.Invoke(bytes, channel);
},
() => { clientReady?.Invoke(); },
(error, s) => throw new ErrorException($"{error}: {s}. t={_time}"),
info =>
{
if (clientValidateKey != null) return clientValidateKey(info);
return true;
});
}
private void Pump()
{
_time += _timestep;
while (clientRecv.TryDequeue(out Data data))
{
client.OnReceiveRaw(new ArraySegment<byte>(data.data), data.channel);
}
if (!client.IsReady)
{
client.TickNonReady(_time);
}
while (serverRecv.TryDequeue(out Data data))
{
server.OnReceiveRaw(new ArraySegment<byte>(data.data), data.channel);
}
if (!server.IsReady)
{
server.TickNonReady(_time);
}
}
[TearDown]
public void TearDown()
{
}
[Test]
public void TestHandshakeSuccess()
{
bool isServerReady = false;
bool isClientReady = false;
clientReady = () =>
{
Assert.False(isClientReady); // only called once
Assert.True(client.IsReady); // should be set when called
isClientReady = true;
};
serverReady = () =>
{
Assert.False(isServerReady); // only called once
Assert.True(server.IsReady); // should be set when called
isServerReady = true;
server.Send(new ArraySegment<byte>(new byte[]
{
1, 2, 3
}), Channels.Reliable); // need to send to ready the other side
};
while (!isServerReady || !isClientReady)
{
if (_time > 20)
{
throw new Exception("Timeout.");
}
Pump();
}
}
[Test]
public void TestHandshakeSuccessWithLoss()
{
int clientCount = 0;
shouldClientSend = (data, channel) =>
{
if (channel == Channels.Unreliable)
{
clientCount++;
// drop 75% of packets
return clientCount % 4 == 0;
}
return true;
};
int serverCount = 0;
shouldServerSend = (data, channel) =>
{
if (channel == Channels.Unreliable)
{
serverCount++;
// drop 75% of packets
return serverCount % 4 == 0;
}
return true;
};
TestHandshakeSuccess();
}
private bool ArrayContainsSequence(ArraySegment<byte> haystack, ArraySegment<byte> needle)
{
if (needle.Count == 0)
{
return true;
}
int ni = 0;
for (int hi = 0; hi < haystack.Count; hi++)
{
if (haystack.Array[haystack.Offset + hi] == needle.Array[needle.Offset + ni])
{
ni++;
if (ni == needle.Count)
{
return true;
}
}
else
{
ni = 0;
}
}
return false;
}
[Test]
public void TestUtil()
{
Assert.True(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
})));
Assert.True(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
})));
Assert.True(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
2, 3
})));
Assert.True(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
3, 4
})));
Assert.False(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
1, 3
})));
Assert.False(ArrayContainsSequence(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), new ArraySegment<byte>(new byte[]
{
3, 4, 5
})));
}
[Test]
public void TestDataSecurity()
{
byte[] serverData = Encoding.UTF8.GetBytes("This is very important secret server data");
byte[] clientData = Encoding.UTF8.GetBytes("Super secret data from the client is contained within.");
bool isServerDone = false;
bool isClientDone = false;
clientReady = () =>
{
client.Send(new ArraySegment<byte>(clientData), Channels.Reliable);
};
serverReady = () =>
{
server.Send(new ArraySegment<byte>(serverData), Channels.Reliable);
};
shouldServerSend = (bytes, i) =>
{
if (i == Channels.Reliable)
{
Assert.False(ArrayContainsSequence(bytes, new ArraySegment<byte>(serverData)));
}
return true;
};
shouldClientSend = (bytes, i) =>
{
if (i == Channels.Reliable)
{
Assert.False(ArrayContainsSequence(bytes, new ArraySegment<byte>(clientData)));
}
return true;
};
serverReceive = (bytes, channel) =>
{
Assert.AreEqual(Channels.Reliable, channel);
Assert.AreEqual(bytes, new ArraySegment<byte>(clientData));
Assert.False(isServerDone);
isServerDone = true;
};
clientReceive = (bytes, channel) =>
{
Assert.AreEqual(Channels.Reliable, channel);
Assert.AreEqual(bytes, new ArraySegment<byte>(serverData));
Assert.False(isClientDone);
isClientDone = true;
};
while (!isServerDone || !isClientDone)
{
if (_time > 20)
{
throw new Exception("Timeout.");
}
Pump();
}
}
[Test]
public void TestBadOpCodeErrors()
{
Assert.Throws<ErrorException>(() =>
{
shouldServerSend = (bytes, i) =>
{
// mess up the opcode (first byte)
bytes.Array[bytes.Offset] += 0xAA;
return true;
};
// setup
TestHandshakeSuccess();
});
}
[Test]
public void TestEarlyDataOpCodeErrors()
{
Assert.Throws<ErrorException>(() =>
{
shouldServerSend = (bytes, i) =>
{
// mess up the opcode (first byte)
bytes.Array[bytes.Offset] = 1; // data
return true;
};
// setup
TestHandshakeSuccess();
});
}
[Test]
public void TestUnexpectedAckOpCodeErrors()
{
Assert.Throws<ErrorException>(() =>
{
shouldServerSend = (bytes, i) =>
{
// mess up the opcode (first byte)
bytes.Array[bytes.Offset] = 2; // start, client doesn't expect this
return true;
};
// setup
TestHandshakeSuccess();
});
}
[Test]
public void TestUnexpectedHandshakeOpCodeErrors()
{
Assert.Throws<ErrorException>(() =>
{
shouldClientSend = (bytes, i) =>
{
// mess up the opcode (first byte)
bytes.Array[bytes.Offset] = 3; // ack, server doesn't expect this
return true;
};
// setup
TestHandshakeSuccess();
});
}
[Test]
public void TestUnexpectedFinOpCodeErrors()
{
Assert.Throws<ErrorException>(() =>
{
shouldServerSend = (bytes, i) =>
{
// mess up the opcode (first byte)
bytes.Array[bytes.Offset] = 4; // fin, client doesn't expect this
return true;
};
// setup
TestHandshakeSuccess();
});
}
[Test]
public void TestBadDataErrors()
{
TestHandshakeSuccess();
Assert.Throws<ErrorException>(() =>
{
// setup
shouldServerSend = (bytes, i) =>
{
// mess up a byte in the data
bytes.Array[bytes.Offset + 3] += 1;
return true;
};
server.Send(new ArraySegment<byte>(new byte[]
{
1, 2, 3, 4
}), Channels.Reliable);
Pump();
});
}
[Test]
public void TestBadPubKeyInStartErrors()
{
shouldClientSend = (bytes, i) =>
{
if (bytes.Array[bytes.Offset] == 2 /* HandshakeStart Opcode */)
{
// mess up a byte in the data
bytes.Array[bytes.Offset + 3] += 1;
}
return true;
};
Assert.Throws<ErrorException>(() =>
{
TestHandshakeSuccess();
});
}
[Test]
public void TestBadPubKeyInAckErrors()
{
shouldServerSend = (bytes, i) =>
{
if (bytes.Array[bytes.Offset] == 3 /* HandshakeAck Opcode */)
{
// mess up a byte in the data
bytes.Array[bytes.Offset + 3] += 1;
}
return true;
};
Assert.Throws<ErrorException>(() =>
{
TestHandshakeSuccess();
});
}
[Test]
public void TestDataSizes()
{
List<int> sizes = new List<int>();
sizes.Add(1);
sizes.Add(2);
sizes.Add(3);
sizes.Add(6);
sizes.Add(9);
sizes.Add(16);
sizes.Add(60);
sizes.Add(100);
sizes.Add(200);
sizes.Add(400);
sizes.Add(800);
sizes.Add(1024);
sizes.Add(1025);
sizes.Add(4096);
sizes.Add(1024 * 16);
sizes.Add(1024 * 64);
sizes.Add(1024 * 128);
sizes.Add(1024 * 512);
// removed for performance, these do pass though
//sizes.Add(1024 * 1024);
//sizes.Add(1024 * 1024 * 16);
//sizes.Add(1024 * 1024 * 64); // 64MiB
TestHandshakeSuccess();
var maxSize = sizes.Max();
var sendByte = new byte[maxSize];
for (uint i = 0; i < sendByte.Length; i++)
{
sendByte[i] = (byte)i;
}
int size = -1;
clientReceive = (bytes, channel) =>
{
// Assert.AreEqual is super slow for larger arrays, so do it manually
Assert.AreEqual(bytes.Count, size);
for (int i = 0; i < size; i++)
{
if (bytes.Array[bytes.Offset + i] != sendByte[i])
{
Assert.Fail($"received bytes[{i}] did not match. expected {sendByte[i]}, got {bytes.Array[bytes.Offset + i]}");
}
}
};
foreach (var s in sizes)
{
size = s;
server.Send(new ArraySegment<byte>(sendByte, 0, size), 1);
Pump();
}
}
[Test]
public void TestPubKeyValidationIsCalled()
{
bool clientCalled = false;
clientValidateKey = info =>
{
Assert.AreEqual(new ArraySegment<byte>(serverCreds.PublicKeySerialized), info.Serialized);
Assert.AreEqual(serverCreds.PublicKeyFingerprint, info.Fingerprint);
clientCalled = true;
return true;
};
bool serverCalled = false;
serverValidateKey = info =>
{
Assert.AreEqual(clientCreds.PublicKeyFingerprint, info.Fingerprint);
serverCalled = true;
return true;
};
TestHandshakeSuccess();
Assert.IsTrue(clientCalled);
Assert.IsTrue(serverCalled);
}
[Test]
public void TestClientPubKeyValidationErrors()
{
clientValidateKey = info => false;
Assert.Throws<ErrorException>(() =>
{
TestHandshakeSuccess();
});
}
[Test]
public void TestServerPubKeyValidationErrors()
{
serverValidateKey = info => false;
Assert.Throws<ErrorException>(() =>
{
TestHandshakeSuccess();
});
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 6132bc4b559a42b88bd94cc25e1390bf
timeCreated: 1708170265

View File

@ -0,0 +1,234 @@
using System;
using Mirror.Transports.Encryption;
using NSubstitute;
using NUnit.Framework;
using UnityEngine;
namespace Mirror.Tests.Transports
{
// This is mostly a copy of MiddlewareTransport, with the stuff requiring actual connections to be setup deleted
[Description("Test to make sure inner methods are called when using Encryption Transport")]
public class EncryptionTransportTransportTest
{
Transport inner;
EncryptionTransport encryption;
[SetUp]
public void Setup()
{
inner = Substitute.For<Transport>();
GameObject gameObject = new GameObject();
encryption = gameObject.AddComponent<EncryptionTransport>();
encryption.inner = inner;
}
[TearDown]
public void TearDown()
{
GameObject.DestroyImmediate(encryption.gameObject);
}
[Test]
[TestCase(true)]
[TestCase(false)]
public void TestAvailable(bool available)
{
inner.Available().Returns(available);
Assert.That(encryption.Available(), Is.EqualTo(available));
inner.Received(1).Available();
}
[Test]
[TestCase(Channels.Reliable, 4000)]
[TestCase(Channels.Reliable, 2000)]
[TestCase(Channels.Unreliable, 4000)]
public void TestGetMaxPacketSize(int channel, int packageSize)
{
inner.GetMaxPacketSize(Arg.Any<int>()).Returns(packageSize);
Assert.That(encryption.GetMaxPacketSize(channel), Is.EqualTo(packageSize - EncryptedConnection.Overhead));
inner.Received(1).GetMaxPacketSize(Arg.Is<int>(x => x == channel));
inner.Received(0).GetMaxPacketSize(Arg.Is<int>(x => x != channel));
}
[Test]
public void TestShutdown()
{
encryption.Shutdown();
inner.Received(1).Shutdown();
}
[Test]
[TestCase("localhost")]
[TestCase("example.com")]
public void TestClientConnect(string address)
{
encryption.ClientConnect(address);
inner.Received(1).ClientConnect(address);
inner.Received(0).ClientConnect(Arg.Is<string>(x => x != address));
}
[Test]
[TestCase(true)]
[TestCase(false)]
public void TestClientConnected(bool connected)
{
inner.ClientConnected().Returns(connected);
Assert.That(encryption.ClientConnected(), Is.EqualTo(false)); // not testing connection handshaking here
}
[Test]
public void TestClientDisconnect()
{
encryption.ClientDisconnect();
inner.Received(1).ClientDisconnect();
}
[Test]
[TestCase(true)]
[TestCase(false)]
public void TestServerActive(bool active)
{
inner.ServerActive().Returns(active);
Assert.That(encryption.ServerActive(), Is.EqualTo(active));
inner.Received(1).ServerActive();
}
[Test]
public void TestServerStart()
{
encryption.ServerStart();
inner.Received(1).ServerStart();
}
[Test]
public void TestServerStop()
{
encryption.ServerStop();
inner.Received(1).ServerStop();
}
[Test]
[TestCase(0, "tcp4://localhost:7777")]
[TestCase(19, "tcp4://example.com:7777")]
public void TestServerGetClientAddress(int id, string result)
{
inner.ServerGetClientAddress(id).Returns(result);
Assert.That(encryption.ServerGetClientAddress(id), Is.EqualTo(result));
inner.Received(1).ServerGetClientAddress(id);
inner.Received(0).ServerGetClientAddress(Arg.Is<int>(x => x != id));
}
[Test]
[TestCase("tcp4://localhost:7777")]
[TestCase("tcp4://example.com:7777")]
public void TestServerUri(string address)
{
Uri uri = new Uri(address);
inner.ServerUri().Returns(uri);
Assert.That(encryption.ServerUri(), Is.EqualTo(uri));
inner.Received(1).ServerUri();
}
[Test]
public void TestClientDisconnectedCallback()
{
int called = 0;
encryption.OnClientDisconnected = () =>
{
called++;
};
// connect to give callback to inner
encryption.ClientConnect("localhost");
inner.OnClientDisconnected.Invoke();
Assert.That(called, Is.EqualTo(1));
inner.OnClientDisconnected.Invoke();
Assert.That(called, Is.EqualTo(2));
}
[Test]
public void TestClientErrorCallback()
{
int called = 0;
encryption.OnClientError = (error, reason) =>
{
called++;
Assert.That(error, Is.EqualTo(TransportError.Unexpected));
};
// connect to give callback to inner
encryption.ClientConnect("localhost");
inner.OnClientError.Invoke(TransportError.Unexpected, "");
Assert.That(called, Is.EqualTo(1));
inner.OnClientError.Invoke(TransportError.Unexpected, "");
Assert.That(called, Is.EqualTo(2));
}
[Test]
[TestCase(0)]
[TestCase(1)]
[TestCase(19)]
public void TestServerDisconnectedCallback(int id)
{
int called = 0;
encryption.OnServerDisconnected = (i) =>
{
called++;
Assert.That(i, Is.EqualTo(id));
};
// start to give callback to inner
encryption.ServerStart();
inner.OnServerDisconnected.Invoke(id);
Assert.That(called, Is.EqualTo(1));
inner.OnServerDisconnected.Invoke(id);
Assert.That(called, Is.EqualTo(2));
}
[Test]
[TestCase(0)]
[TestCase(1)]
[TestCase(19)]
public void TestServerErrorCallback(int id)
{
int called = 0;
encryption.OnServerError = (i, error, reason) =>
{
called++;
Assert.That(i, Is.EqualTo(id));
Assert.That(error, Is.EqualTo(TransportError.Unexpected));
};
// start to give callback to inner
encryption.ServerStart();
inner.OnServerError.Invoke(id, TransportError.Unexpected, "");
Assert.That(called, Is.EqualTo(1));
inner.OnServerError.Invoke(id, TransportError.Unexpected, "");
Assert.That(called, Is.EqualTo(2));
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 24f6aad90a7a4a42bba0473d5b27ebe8
timeCreated: 1708386085

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 741b3c7e5d0842049ff50a2f6e27ca12
timeCreated: 1708015148

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 0d3cd9d7d6e84a578f7e4b384ff813f1
timeCreated: 1708793986

View File

@ -0,0 +1,18 @@
{
"name": "EncryptionTransportEditor",
"rootNamespace": "",
"references": [
"GUID:627104647b9c04b4ebb8978a92ecac63"
],
"includePlatforms": [
"Editor"
],
"excludePlatforms": [],
"allowUnsafeCode": false,
"overrideReferences": false,
"precompiledReferences": [],
"autoReferenced": true,
"defineConstraints": [],
"versionDefines": [],
"noEngineReferences": false
}

View File

@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 4c9c7b0ef83e6e945b276d644816a489
AssemblyDefinitionImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,81 @@
using UnityEditor;
using UnityEngine;
namespace Mirror.Transports.Encryption
{
[CustomEditor(typeof(EncryptionTransport), true)]
public class EncryptionTransportInspector : UnityEditor.Editor
{
SerializedProperty innerProperty;
SerializedProperty clientValidatesServerPubKeyProperty;
SerializedProperty clientTrustedPubKeySignaturesProperty;
SerializedProperty serverKeypairPathProperty;
SerializedProperty serverLoadKeyPairFromFileProperty;
// Assuming proper SerializedProperty definitions for properties
// Add more SerializedProperty fields related to different modes as needed
void OnEnable()
{
innerProperty = serializedObject.FindProperty("inner");
clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey");
clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures");
serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath");
serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile");
}
public override void OnInspectorGUI()
{
serializedObject.Update();
EditorGUILayout.LabelField("Common", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(innerProperty);
EditorGUILayout.Separator();
// Client Section
EditorGUILayout.LabelField("Client", EditorStyles.boldLabel);
EditorGUILayout.HelpBox("Validating the servers public key is essential for complete man-in-the-middle (MITM) safety, but might not be feasible for all modes of hosting.", MessageType.Info);
EditorGUILayout.PropertyField(clientValidatesServerPubKeyProperty, new GUIContent("Validate Server Public Key"));
EncryptionTransport.ValidationMode validationMode = (EncryptionTransport.ValidationMode)clientValidatesServerPubKeyProperty.enumValueIndex;
switch (validationMode)
{
case EncryptionTransport.ValidationMode.List:
EditorGUILayout.PropertyField(clientTrustedPubKeySignaturesProperty);
break;
case EncryptionTransport.ValidationMode.Callback:
EditorGUILayout.HelpBox("Please set the EncryptionTransport.onClientValidateServerPubKey at runtime.", MessageType.Info);
break;
}
EditorGUILayout.Separator();
// Server Section
EditorGUILayout.LabelField("Server", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(serverLoadKeyPairFromFileProperty, new GUIContent("Load Keypair From File"));
if (serverLoadKeyPairFromFileProperty.boolValue)
{
EditorGUILayout.PropertyField(serverKeypairPathProperty, new GUIContent("Keypair File Path"));
}
if(GUILayout.Button("Generate Key Pair"))
{
EncryptionCredentials keyPair = EncryptionCredentials.Generate();
string path = EditorUtility.SaveFilePanel("Select where to save the keypair", "", "server-keys.json", "json");
if (!string.IsNullOrEmpty(path))
{
keyPair.SaveToFile(path);
EditorUtility.DisplayDialog("KeyPair Saved", $"Successfully saved the keypair.\nThe fingerprint is {keyPair.PublicKeyFingerprint}, you can also retrieve it from the saved json file at any point.", "Ok");
if (validationMode == EncryptionTransport.ValidationMode.List)
{
if (EditorUtility.DisplayDialog("Add key to trusted list?", "Do you also want to add the generated key to the trusted list?", "Yes", "No"))
{
clientTrustedPubKeySignaturesProperty.arraySize++;
clientTrustedPubKeySignaturesProperty.GetArrayElementAtIndex(clientTrustedPubKeySignaturesProperty.arraySize - 1).stringValue = keyPair.PublicKeyFingerprint;
}
}
}
}
serializedObject.ApplyModifiedProperties();
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 871580d2094a46139279d651cec92b5d
timeCreated: 1708794004

View File

@ -0,0 +1,555 @@
using System;
using System.Security.Cryptography;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Agreement;
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Modes;
using Org.BouncyCastle.Crypto.Parameters;
using UnityEngine.Profiling;
namespace Mirror.Transports.Encryption
{
public class EncryptedConnection
{
// fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed)
private const int NonceSize = 12;
// this is the size of the "checksum" included in each encrypted payload
// 16 bytes/128 bytes is the recommended value for best security
// can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker.
// Setting it lower than 12 bytes is not recommended
private const int MacSizeBytes = 16;
private const int MacSizeBits = MacSizeBytes * 8;
// How much metadata overhead we have for regular packets
public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize;
// After how many seconds of not receiving a handshake packet we should time out
private const double DurationTimeout = 2; // 2s
// After how many seconds to assume the last handshake packet got lost and to resend another one
private const double DurationResend = 0.05; // 50ms
// Static fields for allocation efficiency, makes this not thread safe
// It'd be as easy as using ThreadLocal though to fix that
// Set up a global cipher instance, it is initialised/reset before use
// (AesFastEngine used to exist, but was removed due to side channel issues)
// use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core)
private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine());
// Global byte array to store nonce sent by the remote side, they're used immediately after
private static readonly byte[] ReceiveNonce = new byte[NonceSize];
// buffer for encrypt/decrypt operations, resized larger as needed
// this is also the buffer that will be returned to mirror via ArraySegment
// so any thread safety concerns would need to take extra care here
private static byte[] _tmpCryptBuffer = new byte[2048];
// packet headers
enum OpCodes : byte
{
// start at 1 to maybe filter out random noise
Data = 1,
HandshakeStart = 2,
HandshakeAck = 3,
HandshakeFin = 4,
}
enum State
{
// Waiting for a handshake to arrive
// this is for _sendsFirst:
// - false: OpCodes.HandshakeStart
// - true: Opcodes.HandshakeAck
WaitingHandshake,
// Waiting for a handshake reply/acknowledgement to arrive
// this is for _sendsFirst:
// - false: OpCodes.HandshakeFine
// - true: Opcodes.Data (implicitly)
WaitingHandshakeReply,
// Both sides have confirmed the keys are exchanged and data can be sent freely
Ready
}
private State _state = State.WaitingHandshake;
// Key exchange confirmed and data can be sent freely
public bool IsReady => _state == State.Ready;
// Callback to send off encrypted data
private Action<ArraySegment<byte>, int> _send;
// Callback when received data has been decrypted
private Action<ArraySegment<byte>, int> _receive;
// Callback when the connection becomes ready
private Action _ready;
// On-error callback, disconnect expected
private Action<TransportError, string> _error;
// Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance
// (usually client validates the server key)
private Func<PubKeyInfo, bool> _validateRemoteKey;
// Our asymmetric credentials for the initial DH exchange
private EncryptionCredentials _credentials;
// After no handshake packet in this many seconds, the handshake fails
private double _handshakeTimeout;
// When to assume the last handshake packet got lost and to resend another one
private double _nextHandshakeResend;
// we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in
// so we can update it without creating a new AeadParameters instance
// this might break in the future! (will cause bad data)
private byte[] _nonce = new byte[NonceSize];
private AeadParameters _cipherParametersEncrypt;
private AeadParameters _cipherParametersDecrypt;
/*
* Specifies if we send the first key, then receive ack, then send fin
* Or the opposite if set to false
*
* The client does this, since the fin is not acked explicitly, but by receiving data to decrypt
*/
private readonly bool _sendsFirst;
public EncryptedConnection(EncryptionCredentials credentials,
bool isClient,
Action<ArraySegment<byte>, int> sendAction,
Action<ArraySegment<byte>, int> receiveAction,
Action readyAction,
Action<TransportError, string> errorAction,
Func<PubKeyInfo, bool> validateRemoteKey = null)
{
_credentials = credentials;
_sendsFirst = isClient;
_send = sendAction;
_receive = receiveAction;
_ready = readyAction;
_error = errorAction;
_validateRemoteKey = validateRemoteKey;
}
// Generates a random starting nonce
private static byte[] GenerateStartingNonce()
{
byte[] nonce = new byte[NonceSize];
using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
{
rng.GetBytes(nonce);
}
return nonce;
}
public void OnReceiveRaw(ArraySegment<byte> data, int channel)
{
if (data.Count < 1)
{
_error(TransportError.Unexpected, "Received empty packet");
return;
}
using (NetworkReaderPooled reader = NetworkReaderPool.Get(data))
{
OpCodes opcode = (OpCodes)reader.ReadByte();
switch (opcode)
{
case OpCodes.Data:
// first sender ready is implicit when data is received
if (_sendsFirst && _state == State.WaitingHandshakeReply)
{
SetReady();
}
else if (!IsReady)
{
_error(TransportError.Unexpected, "Unexpected data while not ready.");
}
if (reader.Remaining < Overhead)
{
_error(TransportError.Unexpected, "received data packet smaller than metadata size");
return;
}
ArraySegment<byte> ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize);
reader.ReadBytes(ReceiveNonce, NonceSize);
Profiler.BeginSample("EncryptedConnection.Decrypt");
ArraySegment<byte> plaintext = Decrypt(ciphertext);
Profiler.EndSample();
if (plaintext.Count == 0)
{
// error
return;
}
_receive(plaintext, channel);
break;
case OpCodes.HandshakeStart:
if (_sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this.");
return;
}
if (_state == State.WaitingHandshakeReply)
{
// this is fine, packets may arrive out of order
return;
}
_state = State.WaitingHandshakeReply;
ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining));
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
break;
case OpCodes.HandshakeAck:
if (!_sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this.");
return;
}
if (IsReady)
{
// this is fine, packets may arrive out of order
return;
}
if (_state == State.WaitingHandshakeReply)
{
// this is fine, packets may arrive out of order
return;
}
_state = State.WaitingHandshakeReply;
ResetTimeouts();
CompleteExchange(reader.ReadBytesSegment(reader.Remaining));
SendHandshakeFin();
break;
case OpCodes.HandshakeFin:
if (_sendsFirst)
{
_error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this.");
return;
}
if (IsReady)
{
// this is fine, packets may arrive out of order
return;
}
if (_state != State.WaitingHandshakeReply)
{
_error(TransportError.Unexpected,
"Received HandshakeFin packet, we didn't expect this yet.");
return;
}
SetReady();
break;
default:
_error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}");
break;
}
}
}
private void SetReady()
{
// done with credentials, null out the reference
_credentials = null;
_state = State.Ready;
_ready();
}
private void ResetTimeouts()
{
_handshakeTimeout = 0;
_nextHandshakeResend = -1;
}
public void Send(ArraySegment<byte> data, int channel)
{
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)OpCodes.Data);
Profiler.BeginSample("EncryptedConnection.Encrypt");
ArraySegment<byte> encrypted = Encrypt(data);
Profiler.EndSample();
if (encrypted.Count == 0)
{
// error
return;
}
writer.WriteBytes(encrypted.Array, 0, encrypted.Count);
// write nonce after since Encrypt will update it
writer.WriteBytes(_nonce, 0, NonceSize);
_send(writer.ToArraySegment(), channel);
}
}
private ArraySegment<byte> Encrypt(ArraySegment<byte> plaintext)
{
if (plaintext.Count == 0)
{
// Invalid
return new ArraySegment<byte>();
}
// Need to make the nonce unique again before encrypting another message
UpdateNonce();
// Re-initialize the cipher with our cached parameters
Cipher.Init(true, _cipherParametersEncrypt);
// Calculate the expected output size, this should always be input size + mac size
int outSize = Cipher.GetOutputSize(plaintext.Count);
#if UNITY_EDITOR
// expecting the outSize to be input size + MacSize
if (outSize != plaintext.Count + MacSizeBytes)
{
throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}");
}
#endif
// Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize);
int resultLen;
try
{
// Run the plain text through the cipher, ProcessBytes will only process full blocks
resultLen =
Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
}
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
//
catch (Exception e)
{
_error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}");
return new ArraySegment<byte>();
}
#if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize)
{
throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
}
private ArraySegment<byte> Decrypt(ArraySegment<byte> ciphertext)
{
if (ciphertext.Count <= MacSizeBytes)
{
_error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})");
// Invalid
return new ArraySegment<byte>();
}
// Re-initialize the cipher with our cached parameters
Cipher.Init(false, _cipherParametersDecrypt);
// Calculate the expected output size, this should always be input size - mac size
int outSize = Cipher.GetOutputSize(ciphertext.Count);
#if UNITY_EDITOR
// expecting the outSize to be input size - MacSize
if (outSize != ciphertext.Count - MacSizeBytes)
{
throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}");
}
#endif
// Resize the static buffer to fit
EnsureSize(ref _tmpCryptBuffer, outSize);
int resultLen;
try
{
// Run the ciphertext through the cipher, ProcessBytes will only process full blocks
resultLen =
Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0);
// Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac)
resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen);
}
// catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types
catch (Exception e)
{
_error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data");
return new ArraySegment<byte>();
}
#if UNITY_EDITOR
// expecting the result length to match the previously calculated input size + MacSize
if (resultLen != outSize)
{
throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})");
}
#endif
return new ArraySegment<byte>(_tmpCryptBuffer, 0, resultLen);
}
private void UpdateNonce()
{
// increment the nonce by one
// we need to ensure the nonce is *always* unique and not reused
// easiest way to do this is by simply incrementing it
for (int i = 0; i < NonceSize; i++)
{
_nonce[i]++;
if (_nonce[i] != 0)
{
break;
}
}
}
private static void EnsureSize(ref byte[] buffer, int size)
{
if (buffer.Length < size)
{
// double buffer to avoid constantly resizing by a few bytes
Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2));
}
}
private void SendHandshakeAndPubKey(OpCodes opcode)
{
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)opcode);
writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length);
_send(writer.ToArraySegment(), Channels.Unreliable);
}
}
private void SendHandshakeFin()
{
using (NetworkWriterPooled writer = NetworkWriterPool.Get())
{
writer.WriteByte((byte)OpCodes.HandshakeFin);
_send(writer.ToArraySegment(), Channels.Unreliable);
}
}
private void CompleteExchange(ArraySegment<byte> remotePubKeyRaw)
{
AsymmetricKeyParameter remotePubKey;
try
{
remotePubKey = EncryptionCredentials.DeserializePublicKey(remotePubKeyRaw);
}
catch (Exception e)
{
_error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}");
return;
}
if (_validateRemoteKey != null)
{
PubKeyInfo info = new PubKeyInfo
{
Fingerprint = EncryptionCredentials.PubKeyFingerprint(remotePubKeyRaw),
Serialized = remotePubKeyRaw,
Key = remotePubKey
};
if (!_validateRemoteKey(info))
{
_error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. ");
return;
}
}
// Calculate a common symmetric key from our private key and the remotes public key
// This gives us the same key on the other side, with our public key and their remote
// It's like magic, but with math!
ECDHBasicAgreement ecdh = new ECDHBasicAgreement();
ecdh.Init(_credentials.PrivateKey);
byte[] keyRaw;
try
{
keyRaw = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned();
}
catch
(Exception e)
{
_error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}");
return;
}
KeyParameter key = new KeyParameter(keyRaw);
// generate a starting nonce
_nonce = GenerateStartingNonce();
// we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance
// instead of creating a new one each encrypt/decrypt
_cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce);
_cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce);
}
/**
* non-ready connections need to be ticked for resending key data over unreliable
*/
public void TickNonReady(double time)
{
if (IsReady)
{
return;
}
// Timeout reset
if (_handshakeTimeout == 0)
{
_handshakeTimeout = time + DurationTimeout;
}
else if (time > _handshakeTimeout)
{
_error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine.");
return;
}
// Timeout reset
if (_nextHandshakeResend < 0)
{
_nextHandshakeResend = time + DurationResend;
return;
}
if (time < _nextHandshakeResend)
{
// Resend isn't due yet
return;
}
_nextHandshakeResend = time + DurationResend;
switch (_state)
{
case State.WaitingHandshake:
if (_sendsFirst)
{
SendHandshakeAndPubKey(OpCodes.HandshakeStart);
}
break;
case State.WaitingHandshakeReply:
if (_sendsFirst)
{
SendHandshakeFin();
}
else
{
SendHandshakeAndPubKey(OpCodes.HandshakeAck);
}
break;
case State.Ready: // IsReady is checked above & early-returned
default:
throw new ArgumentOutOfRangeException();
}
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 28f3ac4ff1d346a895d0b4ff714fb57b
timeCreated: 1708111337

View File

@ -0,0 +1,125 @@
using System;
using System.IO;
using Org.BouncyCastle.Asn1.Pkcs;
using Org.BouncyCastle.Asn1.X509;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Crypto.Generators;
using Org.BouncyCastle.X509;
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Pkcs;
using Org.BouncyCastle.Security;
using UnityEngine;
namespace Mirror.Transports.Encryption
{
public class EncryptionCredentials
{
const int PrivateKeyBits = 256;
// don't actually need to store this currently
// but we'll need to for loading/saving from file maybe?
// public ECPublicKeyParameters PublicKey;
// The serialized public key, in DER format
public byte[] PublicKeySerialized;
public ECPrivateKeyParameters PrivateKey;
public string PublicKeyFingerprint;
EncryptionCredentials() {}
// TODO: load from file
public static EncryptionCredentials Generate()
{
var generator = new ECKeyPairGenerator();
generator.Init(new KeyGenerationParameters(new SecureRandom(), PrivateKeyBits));
AsymmetricCipherKeyPair keyPair = generator.GenerateKeyPair();
var serialized = SerializePublicKey((ECPublicKeyParameters)keyPair.Public);
return new EncryptionCredentials
{
// see fields above
// PublicKey = (ECPublicKeyParameters)keyPair.Public,
PublicKeySerialized = serialized,
PublicKeyFingerprint = PubKeyFingerprint(new ArraySegment<byte>(serialized)),
PrivateKey = (ECPrivateKeyParameters)keyPair.Private
};
}
public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey)
{
// apparently the best way to transmit this public key over the network is to serialize it as a DER
SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfoFactory.CreateSubjectPublicKeyInfo(publicKey);
return publicKeyInfo.ToAsn1Object().GetDerEncoded();
}
public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment<byte> pubKey)
{
// And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream
return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false));
}
public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey)
{
// Serialize privateKey as a DER
PrivateKeyInfo privateKeyInfo = PrivateKeyInfoFactory.CreatePrivateKeyInfo(privateKey);
return privateKeyInfo.ToAsn1Object().GetDerEncoded();
}
public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment<byte> privateKey)
{
// And then we do this to deserialize from the DER (from above)
// the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted
// to a byte[] first and then shoved through a MemoryStream
return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false));
}
public static string PubKeyFingerprint(ArraySegment<byte> publicKeyBytes)
{
Sha256Digest digest = new Sha256Digest();
byte[] hash = new byte[digest.GetDigestSize()];
digest.BlockUpdate(publicKeyBytes.Array, publicKeyBytes.Offset, publicKeyBytes.Count);
digest.DoFinal(hash, 0);
return BitConverter.ToString(hash).Replace("-", "").ToLowerInvariant();
}
public void SaveToFile(string path)
{
string json = JsonUtility.ToJson(new SerializedPair
{
PublicKeyFingerprint = PublicKeyFingerprint,
PublicKey = Convert.ToBase64String(PublicKeySerialized),
PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)),
});
File.WriteAllText(path, json);
}
public static EncryptionCredentials LoadFromFile(string path)
{
string json = File.ReadAllText(path);
SerializedPair serializedPair = JsonUtility.FromJson<SerializedPair>(json);
byte[] publicKeyBytes = Convert.FromBase64String(serializedPair.PublicKey);
byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey);
if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment<byte>(publicKeyBytes)))
{
throw new Exception("Saved public key fingerprint does not match public key.");
}
return new EncryptionCredentials
{
PublicKeySerialized = publicKeyBytes,
PublicKeyFingerprint = serializedPair.PublicKeyFingerprint,
PrivateKey = (ECPrivateKeyParameters) DeserializePrivateKey(new ArraySegment<byte>(privateKeyBytes))
};
}
private class SerializedPair
{
public string PublicKeyFingerprint;
public string PublicKey;
public string PrivateKey;
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: af6ae5f74f9548588cba5731643fabaf
timeCreated: 1708139579

View File

@ -0,0 +1,265 @@
using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Serialization;
namespace Mirror.Transports.Encryption
{
public class EncryptionTransport : Transport
{
public Transport inner;
public enum ValidationMode
{
Off,
List,
Callback,
}
public ValidationMode clientValidateServerPubKey;
[Tooltip("List of public key fingerprints the client will accept")]
public string[] clientTrustedPubKeySignatures;
public Func<PubKeyInfo, bool> onClientValidateServerPubKey;
public bool serverLoadKeyPairFromFile;
public string serverKeypairPath = "./server-keys.json";
private EncryptedConnection _client;
private Dictionary<int, EncryptedConnection> _serverConnections = new Dictionary<int, EncryptedConnection>();
private List<EncryptedConnection> _serverPendingConnections =
new List<EncryptedConnection>();
private EncryptionCredentials _credentials;
private void ServerRemoveFromPending(EncryptedConnection con)
{
for (int i = 0; i < _serverPendingConnections.Count; i++)
{
if (_serverPendingConnections[i] == con)
{
// remove by swapping with last
int lastIndex = _serverPendingConnections.Count - 1;
_serverPendingConnections[i] = _serverPendingConnections[lastIndex];
_serverPendingConnections.RemoveAt(lastIndex);
break;
}
}
}
private void HandleInnerServerDisconnected(int connId)
{
if (_serverConnections.TryGetValue(connId, out EncryptedConnection con))
{
ServerRemoveFromPending(con);
_serverConnections.Remove(connId);
}
OnServerDisconnected?.Invoke(connId);
}
private void HandleInnerServerError(int connId, TransportError type, string msg)
{
OnServerError?.Invoke(connId, type, $"inner: {msg}");
}
private void HandleInnerServerDataReceived(int connId, ArraySegment<byte> data, int channel)
{
if (_serverConnections.TryGetValue(connId, out EncryptedConnection c))
{
c.OnReceiveRaw(data, channel);
}
}
private void HandleInnerServerConnected(int connId)
{
Debug.Log($"[EncryptionTransport] New connection #{connId}");
EncryptedConnection ec = null;
ec = new EncryptedConnection(
_credentials,
false,
(segment, channel) => inner.ServerSend(connId, segment, channel),
(segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel),
() =>
{
Debug.Log($"[EncryptionTransport] Connection #{connId} is ready");
ServerRemoveFromPending(ec);
OnServerConnected?.Invoke(connId);
},
(type, msg) =>
{
OnServerError?.Invoke(connId, type, msg);
ServerDisconnect(connId);
});
_serverConnections.Add(connId, ec);
_serverPendingConnections.Add(ec);
}
private void HandleInnerClientDisconnected()
{
_client = null;
OnClientDisconnected?.Invoke();
}
private void HandleInnerClientError(TransportError arg1, string arg2)
{
OnClientError?.Invoke(arg1, $"inner: {arg2}");
}
private void HandleInnerClientDataReceived(ArraySegment<byte> data, int channel)
{
_client?.OnReceiveRaw(data, channel);
}
private void HandleInnerClientConnected()
{
_client = new EncryptedConnection(
_credentials,
true,
(segment, channel) => inner.ClientSend(segment, channel),
(segment, channel) => OnClientDataReceived?.Invoke(segment, channel),
() =>
{
OnClientConnected?.Invoke();
},
(type, msg) =>
{
OnClientError?.Invoke(type, msg);
ClientDisconnect();
},
HandleClientValidateServerPubKey);
}
private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo)
{
switch (clientValidateServerPubKey)
{
case ValidationMode.Off:
return true;
case ValidationMode.List:
return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0;
case ValidationMode.Callback:
return onClientValidateServerPubKey(pubKeyInfo);
default:
throw new ArgumentOutOfRangeException();
}
}
public override bool Available() => inner.Available();
public override bool ClientConnected() => _client != null && _client.IsReady;
public override void ClientConnect(string address)
{
switch (clientValidateServerPubKey)
{
case ValidationMode.Off:
break;
case ValidationMode.List:
if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0)
{
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty.");
return;
}
break;
case ValidationMode.Callback:
if (onClientValidateServerPubKey == null)
{
OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set");
return;
}
break;
default:
throw new ArgumentOutOfRangeException();
}
_credentials = EncryptionCredentials.Generate();
inner.OnClientConnected = HandleInnerClientConnected;
inner.OnClientDataReceived = HandleInnerClientDataReceived;
inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel);
inner.OnClientError = HandleInnerClientError;
inner.OnClientDisconnected = HandleInnerClientDisconnected;
inner.ClientConnect(address);
}
public override void ClientSend(ArraySegment<byte> segment, int channelId = Channels.Reliable) =>
_client?.Send(segment, channelId);
public override void ClientDisconnect() => inner.ClientDisconnect();
public override Uri ServerUri() => inner.ServerUri();
public override bool ServerActive() => inner.ServerActive();
public override void ServerStart()
{
if (serverLoadKeyPairFromFile)
{
_credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath);
}
else
{
_credentials = EncryptionCredentials.Generate();
}
inner.OnServerConnected = HandleInnerServerConnected;
inner.OnServerDataReceived = HandleInnerServerDataReceived;
inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel);
inner.OnServerError = HandleInnerServerError;
inner.OnServerDisconnected = HandleInnerServerDisconnected;
inner.ServerStart();
}
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId = Channels.Reliable)
{
if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady)
{
connection.Send(segment, channelId);
}
}
public override void ServerDisconnect(int connectionId)
{
// cleanup is done via inners disconnect event
inner.ServerDisconnect(connectionId);
}
public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId);
public override void ServerStop() => inner.ServerStop();
public override int GetMaxPacketSize(int channelId = Channels.Reliable) =>
inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead;
public override void Shutdown() => inner.Shutdown();
public override void ClientEarlyUpdate()
{
inner.ClientEarlyUpdate();
}
public override void ClientLateUpdate()
{
inner.ClientLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
_client?.TickNonReady(NetworkTime.localTime);
Profiler.EndSample();
}
public override void ServerEarlyUpdate()
{
inner.ServerEarlyUpdate();
}
public override void ServerLateUpdate()
{
inner.ServerLateUpdate();
Profiler.BeginSample("EncryptionTransport.ServerLateUpdate");
// Reverse iteration as entries can be removed while updating
for (int i = _serverPendingConnections.Count - 1; i >= 0; i--)
{
_serverPendingConnections[i].TickNonReady(NetworkTime.time);
}
Profiler.EndSample();
}
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 0aa135acc32a4383ae9a5817f018cb06
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,9 @@
using System;
using Org.BouncyCastle.Crypto;
public struct PubKeyInfo
{
public string Fingerprint;
public ArraySegment<byte> Serialized;
public AsymmetricKeyParameter Key;
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: e1e3744418024c02acf39f44c1d1bd20
timeCreated: 1708874062