diff --git a/Assets/Mirror/Plugins/BouncyCastle.meta b/Assets/Mirror/Plugins/BouncyCastle.meta new file mode 100644 index 000000000..e064e80ff --- /dev/null +++ b/Assets/Mirror/Plugins/BouncyCastle.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 31ff83bf6d2e72542adcbe2c21383f4a +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll b/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll new file mode 100644 index 000000000..9d9b6ac33 Binary files /dev/null and b/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll differ diff --git a/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll.meta b/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll.meta new file mode 100644 index 000000000..69befd70b --- /dev/null +++ b/Assets/Mirror/Plugins/BouncyCastle/BouncyCastle.Cryptography.dll.meta @@ -0,0 +1,33 @@ +fileFormatVersion: 2 +guid: a67bf078294b36e4686b9912bf172010 +PluginImporter: + externalObjects: {} + serializedVersion: 2 + iconMap: {} + executionOrder: {} + defineConstraints: [] + isPreloaded: 0 + isOverridable: 0 + isExplicitlyReferenced: 0 + validateReferences: 1 + platformData: + - first: + Any: + second: + enabled: 1 + settings: {} + - first: + Editor: Editor + second: + enabled: 0 + settings: + DefaultValueInitialized: true + - first: + Windows Store Apps: WindowsStoreApps + second: + enabled: 0 + settings: + CPU: AnyCPU + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md b/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md new file mode 100644 index 000000000..277dcd1eb --- /dev/null +++ b/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md @@ -0,0 +1,13 @@ +Copyright (c) 2000-2024 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org). +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sub license, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: The above copyright notice and this +permission notice shall be included in all copies or substantial portions of the Software. + +**THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT +OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.** diff --git a/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md.meta b/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md.meta new file mode 100644 index 000000000..d0ce88350 --- /dev/null +++ b/Assets/Mirror/Plugins/BouncyCastle/LICENSE.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 2b45a99b5583cda419e1f1ec943fec4b +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Tests/Editor/Mirror.Tests.asmdef b/Assets/Mirror/Tests/Editor/Mirror.Tests.asmdef index 6e59124fd..9132b7df6 100644 --- a/Assets/Mirror/Tests/Editor/Mirror.Tests.asmdef +++ b/Assets/Mirror/Tests/Editor/Mirror.Tests.asmdef @@ -24,7 +24,8 @@ "Castle.Core.dll", "System.Threading.Tasks.Extensions.dll", "Mono.CecilX.dll", - "nunit.framework.dll" + "nunit.framework.dll", + "BouncyCastle.Cryptography.dll" ], "autoReferenced": false, "defineConstraints": [ diff --git a/Assets/Mirror/Tests/Editor/Tests.meta b/Assets/Mirror/Tests/Editor/Tests.meta new file mode 100644 index 000000000..feb8c89cc --- /dev/null +++ b/Assets/Mirror/Tests/Editor/Tests.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d5266c80d88c1ca4cb68cf0551780c3f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs new file mode 100644 index 000000000..4e505388a --- /dev/null +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs @@ -0,0 +1,549 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Mirror.Tests.NetworkServers; +using Mirror.Transports.Encryption; +using NUnit.Framework; + +namespace Mirror.Tests.Transports +{ + public class EncryptionTransportConnectionTest + { + struct Data + { + public byte[] data; + public int channel; + } + private EncryptedConnection server; + private EncryptionCredentials serverCreds; + Queue serverRecv = new Queue(); + private Action serverReady; + private Action, int> serverReceive; + private Func, int, bool> shouldServerSend; + private Func serverValidateKey; + + private EncryptedConnection client; + private EncryptionCredentials clientCreds; + Queue clientRecv = new Queue(); + private Action clientReady; + private Action, int> clientReceive; + private Func, int, bool> shouldClientSend; + private Func clientValidateKey; + + private double _time; + private double _timestep = 0.05; + class ErrorException : Exception + { + public ErrorException(string msg) : base(msg) {} + } + + [SetUp] + public void Setup() + { + serverReady = null; + serverReceive = null; + shouldServerSend = null; + serverValidateKey = null; + clientReady = null; + clientReceive = null; + shouldClientSend = null; + clientValidateKey = null; + clientRecv.Clear(); + serverRecv.Clear(); + + serverCreds = EncryptionCredentials.Generate(); + server = new EncryptedConnection(serverCreds, false, + (bytes, channel) => + { + if (shouldServerSend == null || shouldServerSend(bytes, channel)) + clientRecv.Enqueue(new Data + { + data = bytes.ToArray(), channel = channel + }); + }, + (bytes, channel) => + { + serverReceive?.Invoke(bytes, channel); + }, + () => { serverReady?.Invoke(); }, + (error, s) => throw new ErrorException($"{error}: {s}"), + info => + { + if (serverValidateKey != null) return serverValidateKey(info); + return true; + }); + + clientCreds = EncryptionCredentials.Generate(); + client = new EncryptedConnection(clientCreds, true, + (bytes, channel) => + { + if (shouldClientSend == null || shouldClientSend(bytes, channel)) + serverRecv.Enqueue(new Data + { + data = bytes.ToArray(), channel = channel + }); + }, + (bytes, channel) => + { + clientReceive?.Invoke(bytes, channel); + }, + () => { clientReady?.Invoke(); }, + (error, s) => throw new ErrorException($"{error}: {s}. t={_time}"), + info => + { + if (clientValidateKey != null) return clientValidateKey(info); + return true; + }); + } + + private void Pump() + { + _time += _timestep; + + while (clientRecv.TryDequeue(out Data data)) + { + client.OnReceiveRaw(new ArraySegment(data.data), data.channel); + } + if (!client.IsReady) + { + client.TickNonReady(_time); + } + + while (serverRecv.TryDequeue(out Data data)) + { + server.OnReceiveRaw(new ArraySegment(data.data), data.channel); + } + if (!server.IsReady) + { + server.TickNonReady(_time); + } + } + [TearDown] + public void TearDown() + { + } + + [Test] + public void TestHandshakeSuccess() + { + bool isServerReady = false; + bool isClientReady = false; + clientReady = () => + { + Assert.False(isClientReady); // only called once + Assert.True(client.IsReady); // should be set when called + isClientReady = true; + }; + serverReady = () => + { + Assert.False(isServerReady); // only called once + Assert.True(server.IsReady); // should be set when called + isServerReady = true; + server.Send(new ArraySegment(new byte[] + { + 1, 2, 3 + }), Channels.Reliable); // need to send to ready the other side + }; + + while (!isServerReady || !isClientReady) + { + if (_time > 20) + { + throw new Exception("Timeout."); + } + Pump(); + } + } + + [Test] + public void TestHandshakeSuccessWithLoss() + { + int clientCount = 0; + shouldClientSend = (data, channel) => + { + if (channel == Channels.Unreliable) + { + clientCount++; + // drop 75% of packets + return clientCount % 4 == 0; + } + return true; + }; + int serverCount = 0; + shouldServerSend = (data, channel) => + { + if (channel == Channels.Unreliable) + { + serverCount++; + // drop 75% of packets + return serverCount % 4 == 0; + } + return true; + }; + TestHandshakeSuccess(); + } + + private bool ArrayContainsSequence(ArraySegment haystack, ArraySegment needle) + { + if (needle.Count == 0) + { + return true; + } + int ni = 0; + for (int hi = 0; hi < haystack.Count; hi++) + { + if (haystack.Array[haystack.Offset + hi] == needle.Array[needle.Offset + ni]) + { + ni++; + if (ni == needle.Count) + { + return true; + } + } + else + { + ni = 0; + } + } + return false; + } + + [Test] + public void TestUtil() + { + Assert.True(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + }))); + Assert.True(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }))); + Assert.True(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + 2, 3 + }))); + Assert.True(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + 3, 4 + }))); + Assert.False(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + 1, 3 + }))); + Assert.False(ArrayContainsSequence(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), new ArraySegment(new byte[] + { + 3, 4, 5 + }))); + + } + [Test] + public void TestDataSecurity() + { + byte[] serverData = Encoding.UTF8.GetBytes("This is very important secret server data"); + byte[] clientData = Encoding.UTF8.GetBytes("Super secret data from the client is contained within."); + bool isServerDone = false; + bool isClientDone = false; + clientReady = () => + { + client.Send(new ArraySegment(clientData), Channels.Reliable); + }; + serverReady = () => + { + server.Send(new ArraySegment(serverData), Channels.Reliable); + }; + + shouldServerSend = (bytes, i) => + { + if (i == Channels.Reliable) + { + Assert.False(ArrayContainsSequence(bytes, new ArraySegment(serverData))); + } + return true; + }; + shouldClientSend = (bytes, i) => + { + if (i == Channels.Reliable) + { + Assert.False(ArrayContainsSequence(bytes, new ArraySegment(clientData))); + } + return true; + }; + + serverReceive = (bytes, channel) => + { + Assert.AreEqual(Channels.Reliable, channel); + Assert.AreEqual(bytes, new ArraySegment(clientData)); + Assert.False(isServerDone); + isServerDone = true; + }; + clientReceive = (bytes, channel) => + { + Assert.AreEqual(Channels.Reliable, channel); + Assert.AreEqual(bytes, new ArraySegment(serverData)); + Assert.False(isClientDone); + isClientDone = true; + }; + + while (!isServerDone || !isClientDone) + { + if (_time > 20) + { + throw new Exception("Timeout."); + } + Pump(); + } + } + + [Test] + public void TestBadOpCodeErrors() + { + Assert.Throws(() => + { + shouldServerSend = (bytes, i) => + { + // mess up the opcode (first byte) + bytes.Array[bytes.Offset] += 0xAA; + return true; + }; + // setup + TestHandshakeSuccess(); + }); + } + [Test] + public void TestEarlyDataOpCodeErrors() + { + Assert.Throws(() => + { + shouldServerSend = (bytes, i) => + { + // mess up the opcode (first byte) + bytes.Array[bytes.Offset] = 1; // data + return true; + }; + // setup + TestHandshakeSuccess(); + }); + } + + [Test] + public void TestUnexpectedAckOpCodeErrors() + { + Assert.Throws(() => + { + shouldServerSend = (bytes, i) => + { + // mess up the opcode (first byte) + bytes.Array[bytes.Offset] = 2; // start, client doesn't expect this + return true; + }; + // setup + TestHandshakeSuccess(); + }); + } + + [Test] + public void TestUnexpectedHandshakeOpCodeErrors() + { + Assert.Throws(() => + { + shouldClientSend = (bytes, i) => + { + // mess up the opcode (first byte) + bytes.Array[bytes.Offset] = 3; // ack, server doesn't expect this + return true; + }; + // setup + TestHandshakeSuccess(); + }); + } + [Test] + public void TestUnexpectedFinOpCodeErrors() + { + Assert.Throws(() => + { + shouldServerSend = (bytes, i) => + { + // mess up the opcode (first byte) + bytes.Array[bytes.Offset] = 4; // fin, client doesn't expect this + return true; + }; + // setup + TestHandshakeSuccess(); + }); + } + [Test] + public void TestBadDataErrors() + { + TestHandshakeSuccess(); + Assert.Throws(() => + { + // setup + shouldServerSend = (bytes, i) => + { + // mess up a byte in the data + bytes.Array[bytes.Offset + 3] += 1; + return true; + }; + server.Send(new ArraySegment(new byte[] + { + 1, 2, 3, 4 + }), Channels.Reliable); + Pump(); + }); + } + + [Test] + public void TestBadPubKeyInStartErrors() + { + shouldClientSend = (bytes, i) => + { + if (bytes.Array[bytes.Offset] == 2 /* HandshakeStart Opcode */) + { + // mess up a byte in the data + bytes.Array[bytes.Offset + 3] += 1; + } + return true; + }; + Assert.Throws(() => + { + TestHandshakeSuccess(); + }); + } + + [Test] + public void TestBadPubKeyInAckErrors() + { + shouldServerSend = (bytes, i) => + { + if (bytes.Array[bytes.Offset] == 3 /* HandshakeAck Opcode */) + { + // mess up a byte in the data + bytes.Array[bytes.Offset + 3] += 1; + } + return true; + }; + Assert.Throws(() => + { + TestHandshakeSuccess(); + }); + } + + [Test] + public void TestDataSizes() + { + List sizes = new List(); + sizes.Add(1); + sizes.Add(2); + sizes.Add(3); + sizes.Add(6); + sizes.Add(9); + sizes.Add(16); + sizes.Add(60); + sizes.Add(100); + sizes.Add(200); + sizes.Add(400); + sizes.Add(800); + sizes.Add(1024); + sizes.Add(1025); + sizes.Add(4096); + sizes.Add(1024 * 16); + sizes.Add(1024 * 64); + sizes.Add(1024 * 128); + sizes.Add(1024 * 512); + // removed for performance, these do pass though + //sizes.Add(1024 * 1024); + //sizes.Add(1024 * 1024 * 16); + //sizes.Add(1024 * 1024 * 64); // 64MiB + + TestHandshakeSuccess(); + var maxSize = sizes.Max(); + var sendByte = new byte[maxSize]; + for (uint i = 0; i < sendByte.Length; i++) + { + sendByte[i] = (byte)i; + } + int size = -1; + clientReceive = (bytes, channel) => + { + // Assert.AreEqual is super slow for larger arrays, so do it manually + Assert.AreEqual(bytes.Count, size); + for (int i = 0; i < size; i++) + { + if (bytes.Array[bytes.Offset + i] != sendByte[i]) + { + Assert.Fail($"received bytes[{i}] did not match. expected {sendByte[i]}, got {bytes.Array[bytes.Offset + i]}"); + } + } + }; + foreach (var s in sizes) + { + size = s; + server.Send(new ArraySegment(sendByte, 0, size), 1); + Pump(); + } + } + + + [Test] + public void TestPubKeyValidationIsCalled() + { + bool clientCalled = false; + clientValidateKey = info => + { + Assert.AreEqual(new ArraySegment(serverCreds.PublicKeySerialized), info.Serialized); + Assert.AreEqual(serverCreds.PublicKeyFingerprint, info.Fingerprint); + clientCalled = true; + return true; + }; + bool serverCalled = false; + serverValidateKey = info => + { + Assert.AreEqual(clientCreds.PublicKeyFingerprint, info.Fingerprint); + serverCalled = true; + return true; + }; + TestHandshakeSuccess(); + Assert.IsTrue(clientCalled); + Assert.IsTrue(serverCalled); + } + + [Test] + public void TestClientPubKeyValidationErrors() + { + clientValidateKey = info => false; + Assert.Throws(() => + { + TestHandshakeSuccess(); + }); + } + + [Test] + public void TestServerPubKeyValidationErrors() + { + serverValidateKey = info => false; + Assert.Throws(() => + { + TestHandshakeSuccess(); + }); + } + } +} diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs.meta b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs.meta new file mode 100644 index 000000000..438a37a4c --- /dev/null +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportConnectionTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 6132bc4b559a42b88bd94cc25e1390bf +timeCreated: 1708170265 \ No newline at end of file diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs new file mode 100644 index 000000000..269ce1b15 --- /dev/null +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs @@ -0,0 +1,234 @@ +using System; +using Mirror.Transports.Encryption; +using NSubstitute; +using NUnit.Framework; +using UnityEngine; + +namespace Mirror.Tests.Transports +{ + + // This is mostly a copy of MiddlewareTransport, with the stuff requiring actual connections to be setup deleted + [Description("Test to make sure inner methods are called when using Encryption Transport")] + public class EncryptionTransportTransportTest + { + Transport inner; + EncryptionTransport encryption; + + [SetUp] + public void Setup() + { + inner = Substitute.For(); + + GameObject gameObject = new GameObject(); + + encryption = gameObject.AddComponent(); + encryption.inner = inner; + } + + [TearDown] + public void TearDown() + { + GameObject.DestroyImmediate(encryption.gameObject); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void TestAvailable(bool available) + { + inner.Available().Returns(available); + + Assert.That(encryption.Available(), Is.EqualTo(available)); + + inner.Received(1).Available(); + } + + [Test] + [TestCase(Channels.Reliable, 4000)] + [TestCase(Channels.Reliable, 2000)] + [TestCase(Channels.Unreliable, 4000)] + public void TestGetMaxPacketSize(int channel, int packageSize) + { + inner.GetMaxPacketSize(Arg.Any()).Returns(packageSize); + + Assert.That(encryption.GetMaxPacketSize(channel), Is.EqualTo(packageSize - EncryptedConnection.Overhead)); + + inner.Received(1).GetMaxPacketSize(Arg.Is(x => x == channel)); + inner.Received(0).GetMaxPacketSize(Arg.Is(x => x != channel)); + } + + [Test] + public void TestShutdown() + { + encryption.Shutdown(); + + inner.Received(1).Shutdown(); + } + + [Test] + [TestCase("localhost")] + [TestCase("example.com")] + public void TestClientConnect(string address) + { + encryption.ClientConnect(address); + + inner.Received(1).ClientConnect(address); + inner.Received(0).ClientConnect(Arg.Is(x => x != address)); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void TestClientConnected(bool connected) + { + inner.ClientConnected().Returns(connected); + + Assert.That(encryption.ClientConnected(), Is.EqualTo(false)); // not testing connection handshaking here + } + + [Test] + public void TestClientDisconnect() + { + encryption.ClientDisconnect(); + + inner.Received(1).ClientDisconnect(); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void TestServerActive(bool active) + { + inner.ServerActive().Returns(active); + + Assert.That(encryption.ServerActive(), Is.EqualTo(active)); + + inner.Received(1).ServerActive(); + } + + [Test] + public void TestServerStart() + { + encryption.ServerStart(); + + inner.Received(1).ServerStart(); + } + + [Test] + public void TestServerStop() + { + encryption.ServerStop(); + + inner.Received(1).ServerStop(); + } + + [Test] + [TestCase(0, "tcp4://localhost:7777")] + [TestCase(19, "tcp4://example.com:7777")] + public void TestServerGetClientAddress(int id, string result) + { + inner.ServerGetClientAddress(id).Returns(result); + + Assert.That(encryption.ServerGetClientAddress(id), Is.EqualTo(result)); + + inner.Received(1).ServerGetClientAddress(id); + inner.Received(0).ServerGetClientAddress(Arg.Is(x => x != id)); + + } + + [Test] + [TestCase("tcp4://localhost:7777")] + [TestCase("tcp4://example.com:7777")] + public void TestServerUri(string address) + { + Uri uri = new Uri(address); + inner.ServerUri().Returns(uri); + + Assert.That(encryption.ServerUri(), Is.EqualTo(uri)); + + inner.Received(1).ServerUri(); + } + + [Test] + public void TestClientDisconnectedCallback() + { + int called = 0; + encryption.OnClientDisconnected = () => + { + called++; + }; + // connect to give callback to inner + encryption.ClientConnect("localhost"); + + inner.OnClientDisconnected.Invoke(); + Assert.That(called, Is.EqualTo(1)); + + inner.OnClientDisconnected.Invoke(); + Assert.That(called, Is.EqualTo(2)); + } + + [Test] + public void TestClientErrorCallback() + { + int called = 0; + encryption.OnClientError = (error, reason) => + { + called++; + Assert.That(error, Is.EqualTo(TransportError.Unexpected)); + }; + // connect to give callback to inner + encryption.ClientConnect("localhost"); + + inner.OnClientError.Invoke(TransportError.Unexpected, ""); + Assert.That(called, Is.EqualTo(1)); + + inner.OnClientError.Invoke(TransportError.Unexpected, ""); + Assert.That(called, Is.EqualTo(2)); + } + + [Test] + [TestCase(0)] + [TestCase(1)] + [TestCase(19)] + public void TestServerDisconnectedCallback(int id) + { + int called = 0; + encryption.OnServerDisconnected = (i) => + { + called++; + Assert.That(i, Is.EqualTo(id)); + }; + // start to give callback to inner + encryption.ServerStart(); + + inner.OnServerDisconnected.Invoke(id); + Assert.That(called, Is.EqualTo(1)); + + inner.OnServerDisconnected.Invoke(id); + Assert.That(called, Is.EqualTo(2)); + } + + [Test] + [TestCase(0)] + [TestCase(1)] + [TestCase(19)] + public void TestServerErrorCallback(int id) + { + int called = 0; + encryption.OnServerError = (i, error, reason) => + { + called++; + Assert.That(i, Is.EqualTo(id)); + Assert.That(error, Is.EqualTo(TransportError.Unexpected)); + }; + // start to give callback to inner + encryption.ServerStart(); + + inner.OnServerError.Invoke(id, TransportError.Unexpected, ""); + Assert.That(called, Is.EqualTo(1)); + + inner.OnServerError.Invoke(id, TransportError.Unexpected, ""); + Assert.That(called, Is.EqualTo(2)); + } + } +} diff --git a/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs.meta b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs.meta new file mode 100644 index 000000000..8dea8ee28 --- /dev/null +++ b/Assets/Mirror/Tests/Editor/Transports/EncryptionTransportTransportTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 24f6aad90a7a4a42bba0473d5b27ebe8 +timeCreated: 1708386085 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption.meta b/Assets/Mirror/Transports/Encryption.meta new file mode 100644 index 000000000..6c507ab49 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 741b3c7e5d0842049ff50a2f6e27ca12 +timeCreated: 1708015148 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/Editor.meta b/Assets/Mirror/Transports/Encryption/Editor.meta new file mode 100644 index 000000000..b6cf690a7 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/Editor.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 0d3cd9d7d6e84a578f7e4b384ff813f1 +timeCreated: 1708793986 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef new file mode 100644 index 000000000..0ba9c7627 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef @@ -0,0 +1,18 @@ +{ + "name": "EncryptionTransportEditor", + "rootNamespace": "", + "references": [ + "GUID:627104647b9c04b4ebb8978a92ecac63" + ], + "includePlatforms": [ + "Editor" + ], + "excludePlatforms": [], + "allowUnsafeCode": false, + "overrideReferences": false, + "precompiledReferences": [], + "autoReferenced": true, + "defineConstraints": [], + "versionDefines": [], + "noEngineReferences": false +} \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef.meta b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef.meta new file mode 100644 index 000000000..43ba20c69 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportEditor.asmdef.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 4c9c7b0ef83e6e945b276d644816a489 +AssemblyDefinitionImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs new file mode 100644 index 000000000..24557f910 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs @@ -0,0 +1,81 @@ +using UnityEditor; +using UnityEngine; + +namespace Mirror.Transports.Encryption +{ + [CustomEditor(typeof(EncryptionTransport), true)] + public class EncryptionTransportInspector : UnityEditor.Editor + { + SerializedProperty innerProperty; + SerializedProperty clientValidatesServerPubKeyProperty; + SerializedProperty clientTrustedPubKeySignaturesProperty; + SerializedProperty serverKeypairPathProperty; + SerializedProperty serverLoadKeyPairFromFileProperty; + + // Assuming proper SerializedProperty definitions for properties + // Add more SerializedProperty fields related to different modes as needed + + void OnEnable() + { + innerProperty = serializedObject.FindProperty("inner"); + clientValidatesServerPubKeyProperty = serializedObject.FindProperty("clientValidateServerPubKey"); + clientTrustedPubKeySignaturesProperty = serializedObject.FindProperty("clientTrustedPubKeySignatures"); + serverKeypairPathProperty = serializedObject.FindProperty("serverKeypairPath"); + serverLoadKeyPairFromFileProperty = serializedObject.FindProperty("serverLoadKeyPairFromFile"); + } + + public override void OnInspectorGUI() + { + serializedObject.Update(); + + EditorGUILayout.LabelField("Common", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(innerProperty); + EditorGUILayout.Separator(); + // Client Section + EditorGUILayout.LabelField("Client", EditorStyles.boldLabel); + EditorGUILayout.HelpBox("Validating the servers public key is essential for complete man-in-the-middle (MITM) safety, but might not be feasible for all modes of hosting.", MessageType.Info); + EditorGUILayout.PropertyField(clientValidatesServerPubKeyProperty, new GUIContent("Validate Server Public Key")); + + EncryptionTransport.ValidationMode validationMode = (EncryptionTransport.ValidationMode)clientValidatesServerPubKeyProperty.enumValueIndex; + + switch (validationMode) + { + case EncryptionTransport.ValidationMode.List: + EditorGUILayout.PropertyField(clientTrustedPubKeySignaturesProperty); + break; + case EncryptionTransport.ValidationMode.Callback: + EditorGUILayout.HelpBox("Please set the EncryptionTransport.onClientValidateServerPubKey at runtime.", MessageType.Info); + break; + } + + EditorGUILayout.Separator(); + // Server Section + EditorGUILayout.LabelField("Server", EditorStyles.boldLabel); + EditorGUILayout.PropertyField(serverLoadKeyPairFromFileProperty, new GUIContent("Load Keypair From File")); + if (serverLoadKeyPairFromFileProperty.boolValue) + { + EditorGUILayout.PropertyField(serverKeypairPathProperty, new GUIContent("Keypair File Path")); + } + if(GUILayout.Button("Generate Key Pair")) + { + EncryptionCredentials keyPair = EncryptionCredentials.Generate(); + string path = EditorUtility.SaveFilePanel("Select where to save the keypair", "", "server-keys.json", "json"); + if (!string.IsNullOrEmpty(path)) + { + keyPair.SaveToFile(path); + EditorUtility.DisplayDialog("KeyPair Saved", $"Successfully saved the keypair.\nThe fingerprint is {keyPair.PublicKeyFingerprint}, you can also retrieve it from the saved json file at any point.", "Ok"); + if (validationMode == EncryptionTransport.ValidationMode.List) + { + if (EditorUtility.DisplayDialog("Add key to trusted list?", "Do you also want to add the generated key to the trusted list?", "Yes", "No")) + { + clientTrustedPubKeySignaturesProperty.arraySize++; + clientTrustedPubKeySignaturesProperty.GetArrayElementAtIndex(clientTrustedPubKeySignaturesProperty.arraySize - 1).stringValue = keyPair.PublicKeyFingerprint; + } + } + } + } + + serializedObject.ApplyModifiedProperties(); + } + } +} diff --git a/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs.meta b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs.meta new file mode 100644 index 000000000..9aad40bb3 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/Editor/EncryptionTransportInspector.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 871580d2094a46139279d651cec92b5d +timeCreated: 1708794004 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs new file mode 100644 index 000000000..48a8a3cae --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs @@ -0,0 +1,555 @@ +using System; +using System.Security.Cryptography; +using Org.BouncyCastle.Crypto; +using Org.BouncyCastle.Crypto.Agreement; +using Org.BouncyCastle.Crypto.Engines; +using Org.BouncyCastle.Crypto.Modes; +using Org.BouncyCastle.Crypto.Parameters; +using UnityEngine.Profiling; + +namespace Mirror.Transports.Encryption +{ + public class EncryptedConnection + { + // fixed size of the unique per-packet nonce. Defaults to 12 bytes/96 bits (not recommended to be changed) + private const int NonceSize = 12; + + // this is the size of the "checksum" included in each encrypted payload + // 16 bytes/128 bytes is the recommended value for best security + // can be reduced to 12 bytes for a small space savings, but makes encryption slightly weaker. + // Setting it lower than 12 bytes is not recommended + private const int MacSizeBytes = 16; + + private const int MacSizeBits = MacSizeBytes * 8; + + // How much metadata overhead we have for regular packets + public const int Overhead = sizeof(OpCodes) + MacSizeBytes + NonceSize; + + // After how many seconds of not receiving a handshake packet we should time out + private const double DurationTimeout = 2; // 2s + + // After how many seconds to assume the last handshake packet got lost and to resend another one + private const double DurationResend = 0.05; // 50ms + + + // Static fields for allocation efficiency, makes this not thread safe + // It'd be as easy as using ThreadLocal though to fix that + + // Set up a global cipher instance, it is initialised/reset before use + // (AesFastEngine used to exist, but was removed due to side channel issues) + // use AesUtilities.CreateEngine here as it'll pick the hardware accelerated one if available (which is will not be unless on .net core) + private static readonly GcmBlockCipher Cipher = new GcmBlockCipher(AesUtilities.CreateEngine()); + + // Global byte array to store nonce sent by the remote side, they're used immediately after + private static readonly byte[] ReceiveNonce = new byte[NonceSize]; + + // buffer for encrypt/decrypt operations, resized larger as needed + // this is also the buffer that will be returned to mirror via ArraySegment + // so any thread safety concerns would need to take extra care here + private static byte[] _tmpCryptBuffer = new byte[2048]; + + // packet headers + enum OpCodes : byte + { + // start at 1 to maybe filter out random noise + Data = 1, + HandshakeStart = 2, + HandshakeAck = 3, + HandshakeFin = 4, + } + + enum State + { + // Waiting for a handshake to arrive + // this is for _sendsFirst: + // - false: OpCodes.HandshakeStart + // - true: Opcodes.HandshakeAck + WaitingHandshake, + + // Waiting for a handshake reply/acknowledgement to arrive + // this is for _sendsFirst: + // - false: OpCodes.HandshakeFine + // - true: Opcodes.Data (implicitly) + WaitingHandshakeReply, + + // Both sides have confirmed the keys are exchanged and data can be sent freely + Ready + } + + private State _state = State.WaitingHandshake; + + // Key exchange confirmed and data can be sent freely + public bool IsReady => _state == State.Ready; + // Callback to send off encrypted data + private Action, int> _send; + // Callback when received data has been decrypted + private Action, int> _receive; + // Callback when the connection becomes ready + private Action _ready; + // On-error callback, disconnect expected + private Action _error; + // Optional callback to validate the remotes public key, validation on one side is necessary to ensure MITM resistance + // (usually client validates the server key) + private Func _validateRemoteKey; + // Our asymmetric credentials for the initial DH exchange + private EncryptionCredentials _credentials; + + // After no handshake packet in this many seconds, the handshake fails + private double _handshakeTimeout; + // When to assume the last handshake packet got lost and to resend another one + private double _nextHandshakeResend; + + + // we can reuse the _cipherParameters here since the nonce is stored as the byte[] reference we pass in + // so we can update it without creating a new AeadParameters instance + // this might break in the future! (will cause bad data) + private byte[] _nonce = new byte[NonceSize]; + private AeadParameters _cipherParametersEncrypt; + private AeadParameters _cipherParametersDecrypt; + + + /* + * Specifies if we send the first key, then receive ack, then send fin + * Or the opposite if set to false + * + * The client does this, since the fin is not acked explicitly, but by receiving data to decrypt + */ + private readonly bool _sendsFirst; + + public EncryptedConnection(EncryptionCredentials credentials, + bool isClient, + Action, int> sendAction, + Action, int> receiveAction, + Action readyAction, + Action errorAction, + Func validateRemoteKey = null) + { + _credentials = credentials; + _sendsFirst = isClient; + _send = sendAction; + _receive = receiveAction; + _ready = readyAction; + _error = errorAction; + _validateRemoteKey = validateRemoteKey; + } + + // Generates a random starting nonce + private static byte[] GenerateStartingNonce() + { + byte[] nonce = new byte[NonceSize]; + using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) + { + rng.GetBytes(nonce); + } + + return nonce; + } + + public void OnReceiveRaw(ArraySegment data, int channel) + { + if (data.Count < 1) + { + _error(TransportError.Unexpected, "Received empty packet"); + return; + } + + using (NetworkReaderPooled reader = NetworkReaderPool.Get(data)) + { + OpCodes opcode = (OpCodes)reader.ReadByte(); + switch (opcode) + { + case OpCodes.Data: + // first sender ready is implicit when data is received + if (_sendsFirst && _state == State.WaitingHandshakeReply) + { + SetReady(); + } + else if (!IsReady) + { + _error(TransportError.Unexpected, "Unexpected data while not ready."); + } + + if (reader.Remaining < Overhead) + { + _error(TransportError.Unexpected, "received data packet smaller than metadata size"); + return; + } + + ArraySegment ciphertext = reader.ReadBytesSegment(reader.Remaining - NonceSize); + reader.ReadBytes(ReceiveNonce, NonceSize); + + Profiler.BeginSample("EncryptedConnection.Decrypt"); + ArraySegment plaintext = Decrypt(ciphertext); + Profiler.EndSample(); + if (plaintext.Count == 0) + { + // error + return; + } + _receive(plaintext, channel); + break; + case OpCodes.HandshakeStart: + if (_sendsFirst) + { + _error(TransportError.Unexpected, "Received HandshakeStart packet, we don't expect this."); + return; + } + + if (_state == State.WaitingHandshakeReply) + { + // this is fine, packets may arrive out of order + return; + } + + _state = State.WaitingHandshakeReply; + ResetTimeouts(); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining)); + SendHandshakeAndPubKey(OpCodes.HandshakeAck); + break; + case OpCodes.HandshakeAck: + if (!_sendsFirst) + { + _error(TransportError.Unexpected, "Received HandshakeAck packet, we don't expect this."); + return; + } + + if (IsReady) + { + // this is fine, packets may arrive out of order + return; + } + + if (_state == State.WaitingHandshakeReply) + { + // this is fine, packets may arrive out of order + return; + } + + + _state = State.WaitingHandshakeReply; + ResetTimeouts(); + CompleteExchange(reader.ReadBytesSegment(reader.Remaining)); + SendHandshakeFin(); + break; + case OpCodes.HandshakeFin: + if (_sendsFirst) + { + _error(TransportError.Unexpected, "Received HandshakeFin packet, we don't expect this."); + return; + } + + if (IsReady) + { + // this is fine, packets may arrive out of order + return; + } + + if (_state != State.WaitingHandshakeReply) + { + _error(TransportError.Unexpected, + "Received HandshakeFin packet, we didn't expect this yet."); + return; + } + + SetReady(); + + break; + default: + _error(TransportError.InvalidReceive, $"Unhandled opcode {(byte)opcode:x}"); + break; + } + } + } + private void SetReady() + { + // done with credentials, null out the reference + _credentials = null; + + _state = State.Ready; + _ready(); + } + + private void ResetTimeouts() + { + _handshakeTimeout = 0; + _nextHandshakeResend = -1; + } + + public void Send(ArraySegment data, int channel) + { + using (NetworkWriterPooled writer = NetworkWriterPool.Get()) + { + writer.WriteByte((byte)OpCodes.Data); + Profiler.BeginSample("EncryptedConnection.Encrypt"); + ArraySegment encrypted = Encrypt(data); + Profiler.EndSample(); + + if (encrypted.Count == 0) + { + // error + return; + } + writer.WriteBytes(encrypted.Array, 0, encrypted.Count); + // write nonce after since Encrypt will update it + writer.WriteBytes(_nonce, 0, NonceSize); + _send(writer.ToArraySegment(), channel); + } + } + + private ArraySegment Encrypt(ArraySegment plaintext) + { + if (plaintext.Count == 0) + { + // Invalid + return new ArraySegment(); + } + // Need to make the nonce unique again before encrypting another message + UpdateNonce(); + // Re-initialize the cipher with our cached parameters + Cipher.Init(true, _cipherParametersEncrypt); + + // Calculate the expected output size, this should always be input size + mac size + int outSize = Cipher.GetOutputSize(plaintext.Count); +#if UNITY_EDITOR + // expecting the outSize to be input size + MacSize + if (outSize != plaintext.Count + MacSizeBytes) + { + throw new Exception($"Encrypt: Unexpected output size (Expected {plaintext.Count + MacSizeBytes}, got {outSize}"); + } +#endif + // Resize the static buffer to fit + EnsureSize(ref _tmpCryptBuffer, outSize); + int resultLen; + try + { + // Run the plain text through the cipher, ProcessBytes will only process full blocks + resultLen = + Cipher.ProcessBytes(plaintext.Array, plaintext.Offset, plaintext.Count, _tmpCryptBuffer, 0); + // Then run any potentially remaining partial blocks through with DoFinal (and calculate the mac) + resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + } + // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types + // + catch (Exception e) + { + _error(TransportError.Unexpected, $"Unexpected exception while encrypting {e.GetType()}: {e.Message}"); + return new ArraySegment(); + } +#if UNITY_EDITOR + // expecting the result length to match the previously calculated input size + MacSize + if (resultLen != outSize) + { + throw new Exception($"Encrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); + } +#endif + return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + } + + private ArraySegment Decrypt(ArraySegment ciphertext) + { + if (ciphertext.Count <= MacSizeBytes) + { + _error(TransportError.Unexpected, $"Received too short data packet (min {{MacSizeBytes + 1}}, got {ciphertext.Count})"); + // Invalid + return new ArraySegment(); + } + // Re-initialize the cipher with our cached parameters + Cipher.Init(false, _cipherParametersDecrypt); + + // Calculate the expected output size, this should always be input size - mac size + int outSize = Cipher.GetOutputSize(ciphertext.Count); +#if UNITY_EDITOR + // expecting the outSize to be input size - MacSize + if (outSize != ciphertext.Count - MacSizeBytes) + { + throw new Exception($"Decrypt: Unexpected output size (Expected {ciphertext.Count - MacSizeBytes}, got {outSize}"); + } +#endif + // Resize the static buffer to fit + EnsureSize(ref _tmpCryptBuffer, outSize); + int resultLen; + try + { + // Run the ciphertext through the cipher, ProcessBytes will only process full blocks + resultLen = + Cipher.ProcessBytes(ciphertext.Array, ciphertext.Offset, ciphertext.Count, _tmpCryptBuffer, 0); + // Then run any potentially remaining partial blocks through with DoFinal (and calculate/check the mac) + resultLen += Cipher.DoFinal(_tmpCryptBuffer, resultLen); + } + // catch all Exception's since BouncyCastle is fairly noisy with both standard and their own exception types + catch (Exception e) + { + _error(TransportError.Unexpected, $"Unexpected exception while decrypting {e.GetType()}: {e.Message}. This usually signifies corrupt data"); + return new ArraySegment(); + } +#if UNITY_EDITOR + // expecting the result length to match the previously calculated input size + MacSize + if (resultLen != outSize) + { + throw new Exception($"Decrypt: resultLen did not match outSize (expected {outSize}, got {resultLen})"); + } +#endif + return new ArraySegment(_tmpCryptBuffer, 0, resultLen); + } + + private void UpdateNonce() + { + // increment the nonce by one + // we need to ensure the nonce is *always* unique and not reused + // easiest way to do this is by simply incrementing it + for (int i = 0; i < NonceSize; i++) + { + _nonce[i]++; + if (_nonce[i] != 0) + { + break; + } + } + } + + private static void EnsureSize(ref byte[] buffer, int size) + { + if (buffer.Length < size) + { + // double buffer to avoid constantly resizing by a few bytes + Array.Resize(ref buffer, Math.Max(size, buffer.Length * 2)); + } + } + + private void SendHandshakeAndPubKey(OpCodes opcode) + { + using (NetworkWriterPooled writer = NetworkWriterPool.Get()) + { + writer.WriteByte((byte)opcode); + writer.WriteBytes(_credentials.PublicKeySerialized, 0, _credentials.PublicKeySerialized.Length); + _send(writer.ToArraySegment(), Channels.Unreliable); + } + } + + private void SendHandshakeFin() + { + using (NetworkWriterPooled writer = NetworkWriterPool.Get()) + { + writer.WriteByte((byte)OpCodes.HandshakeFin); + _send(writer.ToArraySegment(), Channels.Unreliable); + } + } + + private void CompleteExchange(ArraySegment remotePubKeyRaw) + { + AsymmetricKeyParameter remotePubKey; + try + { + remotePubKey = EncryptionCredentials.DeserializePublicKey(remotePubKeyRaw); + } + catch (Exception e) + { + _error(TransportError.Unexpected, $"Failed to deserialize public key of remote. {e.GetType()}: {e.Message}"); + return; + } + + if (_validateRemoteKey != null) + { + PubKeyInfo info = new PubKeyInfo + { + Fingerprint = EncryptionCredentials.PubKeyFingerprint(remotePubKeyRaw), + Serialized = remotePubKeyRaw, + Key = remotePubKey + }; + if (!_validateRemoteKey(info)) + { + _error(TransportError.Unexpected, $"Remote public key (fingerprint: {info.Fingerprint}) failed validation. "); + return; + } + } + + // Calculate a common symmetric key from our private key and the remotes public key + // This gives us the same key on the other side, with our public key and their remote + // It's like magic, but with math! + ECDHBasicAgreement ecdh = new ECDHBasicAgreement(); + ecdh.Init(_credentials.PrivateKey); + byte[] keyRaw; + try + { + keyRaw = ecdh.CalculateAgreement(remotePubKey).ToByteArrayUnsigned(); + } + catch + (Exception e) + { + _error(TransportError.Unexpected, $"Failed to calculate the ECDH key exchange. {e.GetType()}: {e.Message}"); + return; + } + + KeyParameter key = new KeyParameter(keyRaw); + + // generate a starting nonce + _nonce = GenerateStartingNonce(); + + // we pass in the nonce array once (as it's stored by reference) so we can cache the AeadParameters instance + // instead of creating a new one each encrypt/decrypt + _cipherParametersEncrypt = new AeadParameters(key, MacSizeBits, _nonce); + _cipherParametersDecrypt = new AeadParameters(key, MacSizeBits, ReceiveNonce); + } + + /** + * non-ready connections need to be ticked for resending key data over unreliable + */ + public void TickNonReady(double time) + { + if (IsReady) + { + return; + } + + // Timeout reset + if (_handshakeTimeout == 0) + { + _handshakeTimeout = time + DurationTimeout; + } + else if (time > _handshakeTimeout) + { + _error?.Invoke(TransportError.Timeout, $"Timed out during {_state}, this probably just means the other side went away which is fine."); + return; + } + + // Timeout reset + if (_nextHandshakeResend < 0) + { + _nextHandshakeResend = time + DurationResend; + return; + } + + if (time < _nextHandshakeResend) + { + // Resend isn't due yet + return; + } + + _nextHandshakeResend = time + DurationResend; + switch (_state) + { + case State.WaitingHandshake: + if (_sendsFirst) + { + SendHandshakeAndPubKey(OpCodes.HandshakeStart); + } + + break; + case State.WaitingHandshakeReply: + if (_sendsFirst) + { + SendHandshakeFin(); + } + else + { + SendHandshakeAndPubKey(OpCodes.HandshakeAck); + } + + break; + case State.Ready: // IsReady is checked above & early-returned + default: + throw new ArgumentOutOfRangeException(); + } + } + } +} diff --git a/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs.meta b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs.meta new file mode 100644 index 000000000..b5e52091b --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptedConnection.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 28f3ac4ff1d346a895d0b4ff714fb57b +timeCreated: 1708111337 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs new file mode 100644 index 000000000..2e0b04259 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs @@ -0,0 +1,125 @@ +using System; +using System.IO; +using Org.BouncyCastle.Asn1.Pkcs; +using Org.BouncyCastle.Asn1.X509; +using Org.BouncyCastle.Crypto; +using Org.BouncyCastle.Crypto.Digests; +using Org.BouncyCastle.Crypto.Generators; +using Org.BouncyCastle.X509; +using Org.BouncyCastle.Crypto.Parameters; +using Org.BouncyCastle.Pkcs; +using Org.BouncyCastle.Security; +using UnityEngine; + +namespace Mirror.Transports.Encryption +{ + public class EncryptionCredentials + { + const int PrivateKeyBits = 256; + // don't actually need to store this currently + // but we'll need to for loading/saving from file maybe? + // public ECPublicKeyParameters PublicKey; + + // The serialized public key, in DER format + public byte[] PublicKeySerialized; + public ECPrivateKeyParameters PrivateKey; + public string PublicKeyFingerprint; + + EncryptionCredentials() {} + + // TODO: load from file + public static EncryptionCredentials Generate() + { + var generator = new ECKeyPairGenerator(); + generator.Init(new KeyGenerationParameters(new SecureRandom(), PrivateKeyBits)); + AsymmetricCipherKeyPair keyPair = generator.GenerateKeyPair(); + var serialized = SerializePublicKey((ECPublicKeyParameters)keyPair.Public); + return new EncryptionCredentials + { + // see fields above + // PublicKey = (ECPublicKeyParameters)keyPair.Public, + PublicKeySerialized = serialized, + PublicKeyFingerprint = PubKeyFingerprint(new ArraySegment(serialized)), + PrivateKey = (ECPrivateKeyParameters)keyPair.Private + }; + } + + public static byte[] SerializePublicKey(AsymmetricKeyParameter publicKey) + { + // apparently the best way to transmit this public key over the network is to serialize it as a DER + SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfoFactory.CreateSubjectPublicKeyInfo(publicKey); + return publicKeyInfo.ToAsn1Object().GetDerEncoded(); + } + + public static AsymmetricKeyParameter DeserializePublicKey(ArraySegment pubKey) + { + // And then we do this to deserialize from the DER (from above) + // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted + // to a byte[] first and then shoved through a MemoryStream + return PublicKeyFactory.CreateKey(new MemoryStream(pubKey.Array, pubKey.Offset, pubKey.Count, false)); + } + + public static byte[] SerializePrivateKey(AsymmetricKeyParameter privateKey) + { + // Serialize privateKey as a DER + PrivateKeyInfo privateKeyInfo = PrivateKeyInfoFactory.CreatePrivateKeyInfo(privateKey); + return privateKeyInfo.ToAsn1Object().GetDerEncoded(); + } + + public static AsymmetricKeyParameter DeserializePrivateKey(ArraySegment privateKey) + { + // And then we do this to deserialize from the DER (from above) + // the "new MemoryStream" actually saves an allocation, since otherwise the ArraySegment would be converted + // to a byte[] first and then shoved through a MemoryStream + return PrivateKeyFactory.CreateKey(new MemoryStream(privateKey.Array, privateKey.Offset, privateKey.Count, false)); + } + + public static string PubKeyFingerprint(ArraySegment publicKeyBytes) + { + Sha256Digest digest = new Sha256Digest(); + byte[] hash = new byte[digest.GetDigestSize()]; + digest.BlockUpdate(publicKeyBytes.Array, publicKeyBytes.Offset, publicKeyBytes.Count); + digest.DoFinal(hash, 0); + + return BitConverter.ToString(hash).Replace("-", "").ToLowerInvariant(); + } + + public void SaveToFile(string path) + { + string json = JsonUtility.ToJson(new SerializedPair + { + PublicKeyFingerprint = PublicKeyFingerprint, + PublicKey = Convert.ToBase64String(PublicKeySerialized), + PrivateKey= Convert.ToBase64String(SerializePrivateKey(PrivateKey)), + }); + File.WriteAllText(path, json); + } + + public static EncryptionCredentials LoadFromFile(string path) + { + string json = File.ReadAllText(path); + SerializedPair serializedPair = JsonUtility.FromJson(json); + + byte[] publicKeyBytes = Convert.FromBase64String(serializedPair.PublicKey); + byte[] privateKeyBytes = Convert.FromBase64String(serializedPair.PrivateKey); + + if (serializedPair.PublicKeyFingerprint != PubKeyFingerprint(new ArraySegment(publicKeyBytes))) + { + throw new Exception("Saved public key fingerprint does not match public key."); + } + return new EncryptionCredentials + { + PublicKeySerialized = publicKeyBytes, + PublicKeyFingerprint = serializedPair.PublicKeyFingerprint, + PrivateKey = (ECPrivateKeyParameters) DeserializePrivateKey(new ArraySegment(privateKeyBytes)) + }; + } + + private class SerializedPair + { + public string PublicKeyFingerprint; + public string PublicKey; + public string PrivateKey; + } + } +} diff --git a/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs.meta b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs.meta new file mode 100644 index 000000000..38f119773 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptionCredentials.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: af6ae5f74f9548588cba5731643fabaf +timeCreated: 1708139579 \ No newline at end of file diff --git a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs new file mode 100644 index 000000000..6c55aa28f --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs @@ -0,0 +1,265 @@ +using System; +using System.Collections.Generic; +using UnityEngine; +using UnityEngine.Profiling; +using UnityEngine.Serialization; + +namespace Mirror.Transports.Encryption +{ + public class EncryptionTransport : Transport + { + public Transport inner; + + public enum ValidationMode + { + Off, + List, + Callback, + } + + public ValidationMode clientValidateServerPubKey; + [Tooltip("List of public key fingerprints the client will accept")] + public string[] clientTrustedPubKeySignatures; + public Func onClientValidateServerPubKey; + public bool serverLoadKeyPairFromFile; + public string serverKeypairPath = "./server-keys.json"; + + private EncryptedConnection _client; + + private Dictionary _serverConnections = new Dictionary(); + + private List _serverPendingConnections = + new List(); + + private EncryptionCredentials _credentials; + + private void ServerRemoveFromPending(EncryptedConnection con) + { + for (int i = 0; i < _serverPendingConnections.Count; i++) + { + if (_serverPendingConnections[i] == con) + { + // remove by swapping with last + int lastIndex = _serverPendingConnections.Count - 1; + _serverPendingConnections[i] = _serverPendingConnections[lastIndex]; + _serverPendingConnections.RemoveAt(lastIndex); + break; + } + } + } + + private void HandleInnerServerDisconnected(int connId) + { + if (_serverConnections.TryGetValue(connId, out EncryptedConnection con)) + { + ServerRemoveFromPending(con); + _serverConnections.Remove(connId); + } + OnServerDisconnected?.Invoke(connId); + } + + private void HandleInnerServerError(int connId, TransportError type, string msg) + { + OnServerError?.Invoke(connId, type, $"inner: {msg}"); + } + + private void HandleInnerServerDataReceived(int connId, ArraySegment data, int channel) + { + if (_serverConnections.TryGetValue(connId, out EncryptedConnection c)) + { + c.OnReceiveRaw(data, channel); + } + } + + private void HandleInnerServerConnected(int connId) + { + Debug.Log($"[EncryptionTransport] New connection #{connId}"); + EncryptedConnection ec = null; + ec = new EncryptedConnection( + _credentials, + false, + (segment, channel) => inner.ServerSend(connId, segment, channel), + (segment, channel) => OnServerDataReceived?.Invoke(connId, segment, channel), + () => + { + Debug.Log($"[EncryptionTransport] Connection #{connId} is ready"); + ServerRemoveFromPending(ec); + OnServerConnected?.Invoke(connId); + }, + (type, msg) => + { + OnServerError?.Invoke(connId, type, msg); + ServerDisconnect(connId); + }); + _serverConnections.Add(connId, ec); + _serverPendingConnections.Add(ec); + } + + private void HandleInnerClientDisconnected() + { + _client = null; + OnClientDisconnected?.Invoke(); + } + + private void HandleInnerClientError(TransportError arg1, string arg2) + { + OnClientError?.Invoke(arg1, $"inner: {arg2}"); + } + + private void HandleInnerClientDataReceived(ArraySegment data, int channel) + { + _client?.OnReceiveRaw(data, channel); + } + + private void HandleInnerClientConnected() + { + _client = new EncryptedConnection( + _credentials, + true, + (segment, channel) => inner.ClientSend(segment, channel), + (segment, channel) => OnClientDataReceived?.Invoke(segment, channel), + () => + { + OnClientConnected?.Invoke(); + }, + (type, msg) => + { + OnClientError?.Invoke(type, msg); + ClientDisconnect(); + }, + HandleClientValidateServerPubKey); + } + + private bool HandleClientValidateServerPubKey(PubKeyInfo pubKeyInfo) + { + switch (clientValidateServerPubKey) + { + case ValidationMode.Off: + return true; + case ValidationMode.List: + return Array.IndexOf(clientTrustedPubKeySignatures, pubKeyInfo.Fingerprint) >= 0; + case ValidationMode.Callback: + return onClientValidateServerPubKey(pubKeyInfo); + default: + throw new ArgumentOutOfRangeException(); + } + } + + public override bool Available() => inner.Available(); + + public override bool ClientConnected() => _client != null && _client.IsReady; + + public override void ClientConnect(string address) + { + switch (clientValidateServerPubKey) + { + case ValidationMode.Off: + break; + case ValidationMode.List: + if (clientTrustedPubKeySignatures == null || clientTrustedPubKeySignatures.Length == 0) + { + OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to List, but the clientTrustedPubKeySignatures list is empty."); + return; + } + break; + case ValidationMode.Callback: + if (onClientValidateServerPubKey == null) + { + OnClientError?.Invoke(TransportError.Unexpected, "Validate Server Public Key is set to Callback, but the onClientValidateServerPubKey handler is not set"); + return; + } + break; + default: + throw new ArgumentOutOfRangeException(); + } + _credentials = EncryptionCredentials.Generate(); + inner.OnClientConnected = HandleInnerClientConnected; + inner.OnClientDataReceived = HandleInnerClientDataReceived; + inner.OnClientDataSent = (bytes, channel) => OnClientDataSent?.Invoke(bytes, channel); + inner.OnClientError = HandleInnerClientError; + inner.OnClientDisconnected = HandleInnerClientDisconnected; + inner.ClientConnect(address); + } + + public override void ClientSend(ArraySegment segment, int channelId = Channels.Reliable) => + _client?.Send(segment, channelId); + + public override void ClientDisconnect() => inner.ClientDisconnect(); + + public override Uri ServerUri() => inner.ServerUri(); + + public override bool ServerActive() => inner.ServerActive(); + + public override void ServerStart() + { + if (serverLoadKeyPairFromFile) + { + _credentials = EncryptionCredentials.LoadFromFile(serverKeypairPath); + } + else + { + _credentials = EncryptionCredentials.Generate(); + } + inner.OnServerConnected = HandleInnerServerConnected; + inner.OnServerDataReceived = HandleInnerServerDataReceived; + inner.OnServerDataSent = (connId, bytes, channel) => OnServerDataSent?.Invoke(connId, bytes, channel); + inner.OnServerError = HandleInnerServerError; + inner.OnServerDisconnected = HandleInnerServerDisconnected; + inner.ServerStart(); + } + + public override void ServerSend(int connectionId, ArraySegment segment, int channelId = Channels.Reliable) + { + if (_serverConnections.TryGetValue(connectionId, out EncryptedConnection connection) && connection.IsReady) + { + connection.Send(segment, channelId); + } + } + + public override void ServerDisconnect(int connectionId) + { + // cleanup is done via inners disconnect event + inner.ServerDisconnect(connectionId); + } + + public override string ServerGetClientAddress(int connectionId) => inner.ServerGetClientAddress(connectionId); + + public override void ServerStop() => inner.ServerStop(); + + public override int GetMaxPacketSize(int channelId = Channels.Reliable) => + inner.GetMaxPacketSize(channelId) - EncryptedConnection.Overhead; + + public override void Shutdown() => inner.Shutdown(); + + public override void ClientEarlyUpdate() + { + inner.ClientEarlyUpdate(); + } + + public override void ClientLateUpdate() + { + inner.ClientLateUpdate(); + Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); + _client?.TickNonReady(NetworkTime.localTime); + Profiler.EndSample(); + } + + public override void ServerEarlyUpdate() + { + inner.ServerEarlyUpdate(); + } + + public override void ServerLateUpdate() + { + inner.ServerLateUpdate(); + Profiler.BeginSample("EncryptionTransport.ServerLateUpdate"); + // Reverse iteration as entries can be removed while updating + for (int i = _serverPendingConnections.Count - 1; i >= 0; i--) + { + _serverPendingConnections[i].TickNonReady(NetworkTime.time); + } + Profiler.EndSample(); + } + } + +} diff --git a/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs.meta b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs.meta new file mode 100644 index 000000000..98a36edd2 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/EncryptionTransport.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0aa135acc32a4383ae9a5817f018cb06 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {fileID: 2800000, guid: 7453abfe9e8b2c04a8a47eb536fe21eb, type: 3} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs b/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs new file mode 100644 index 000000000..d98906131 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs @@ -0,0 +1,9 @@ +using System; +using Org.BouncyCastle.Crypto; + +public struct PubKeyInfo +{ + public string Fingerprint; + public ArraySegment Serialized; + public AsymmetricKeyParameter Key; +} diff --git a/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs.meta b/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs.meta new file mode 100644 index 000000000..7b824c217 --- /dev/null +++ b/Assets/Mirror/Transports/Encryption/PubKeyInfo.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e1e3744418024c02acf39f44c1d1bd20 +timeCreated: 1708874062 \ No newline at end of file