From 6d60471868a83591ae8634898d40b17be832e525 Mon Sep 17 00:00:00 2001 From: mischa <16416509+vis2k@users.noreply.github.com> Date: Sun, 4 Dec 2022 13:57:51 +0100 Subject: [PATCH] fix: #3280 #3083 #3217 MultiplexTransport connectionId multiplexing out of int.max range (#3291) * lookup wip * so far * transport lookup * wip * okokok * all done * better * x * tests --- .../Tests/Editor/MultiplexTransportTest.cs | 144 +++++++----------- .../Multiplex/MultiplexTransport.cs | 119 +++++++++++---- 2 files changed, 140 insertions(+), 123 deletions(-) diff --git a/Assets/Mirror/Tests/Editor/MultiplexTransportTest.cs b/Assets/Mirror/Tests/Editor/MultiplexTransportTest.cs index 3c2b9a49e..9d5455419 100644 --- a/Assets/Mirror/Tests/Editor/MultiplexTransportTest.cs +++ b/Assets/Mirror/Tests/Editor/MultiplexTransportTest.cs @@ -28,92 +28,55 @@ public void Setup() public override void TearDown() => base.TearDown(); [Test] - public void MultiplexConnectionId() + public void ConnectionIdMapping() { - // if we have 3 transports, then - // transport 0 will produce connection ids [0, 3, 6, 9, ...] - const int transportAmount = 3; + // add a few connectionIds from transport #0 + // one large connId to prevent https://github.com/vis2k/Mirror/issues/3280 + int t0_c10 = transport.AddToLookup(10, 0); // should get multiplexId = 1 + int t0_c20 = transport.AddToLookup(20, 0); // should get multiplexId = 2 + int t0_cmax = transport.AddToLookup(int.MaxValue, 0); // should get multiplexId = 3 - Assert.That(MultiplexTransport.MultiplexConnectionId(0, 0, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.MultiplexConnectionId(1, 0, transportAmount), Is.EqualTo(3)); - Assert.That(MultiplexTransport.MultiplexConnectionId(2, 0, transportAmount), Is.EqualTo(6)); - Assert.That(MultiplexTransport.MultiplexConnectionId(3, 0, transportAmount), Is.EqualTo(9)); + // add a few connectionIds from transport #1 + // one large connId to prevent https://github.com/vis2k/Mirror/issues/3280 + int t1_c10 = transport.AddToLookup(10, 1); // should get multiplexId = 4 + int t1_c50 = transport.AddToLookup(50, 1); // should get multiplexId = 5 + int t1_cmax = transport.AddToLookup(int.MaxValue, 1); // should get multiplexId = 6 - // transport 1 will produce connection ids [1, 4, 7, 10, ...] - Assert.That(MultiplexTransport.MultiplexConnectionId(0, 1, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.MultiplexConnectionId(1, 1, transportAmount), Is.EqualTo(4)); - Assert.That(MultiplexTransport.MultiplexConnectionId(2, 1, transportAmount), Is.EqualTo(7)); - Assert.That(MultiplexTransport.MultiplexConnectionId(3, 1, transportAmount), Is.EqualTo(10)); + // MultiplexId -> (OriginalId, TransportIndex) for transport #0 + transport.OriginalId(t0_c10, out int originalId, out int transportIndex); + Assert.That(transportIndex, Is.EqualTo(0)); + Assert.That(originalId, Is.EqualTo(10)); - // transport 2 will produce connection ids [2, 5, 8, 11, ...] - Assert.That(MultiplexTransport.MultiplexConnectionId(0, 2, transportAmount), Is.EqualTo(2)); - Assert.That(MultiplexTransport.MultiplexConnectionId(1, 2, transportAmount), Is.EqualTo(5)); - Assert.That(MultiplexTransport.MultiplexConnectionId(2, 2, transportAmount), Is.EqualTo(8)); - Assert.That(MultiplexTransport.MultiplexConnectionId(3, 2, transportAmount), Is.EqualTo(11)); - } + transport.OriginalId(t0_c20, out originalId, out transportIndex); + Assert.That(transportIndex, Is.EqualTo(0)); + Assert.That(originalId, Is.EqualTo(20)); - [Test] - public void OriginalConnectionId() - { - const int transportAmount = 3; + transport.OriginalId(t0_cmax, out originalId, out transportIndex); + Assert.That(transportIndex, Is.EqualTo(0)); + Assert.That(originalId, Is.EqualTo(int.MaxValue)); - Assert.That(MultiplexTransport.OriginalConnectionId(0, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.OriginalConnectionId(1, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.OriginalConnectionId(2, transportAmount), Is.EqualTo(0)); + // MultiplexId -> (OriginalId, TransportIndex) for transport #1 + transport.OriginalId(t1_c10, out originalId, out transportIndex); + Assert.That(transportIndex, Is.EqualTo(1)); + Assert.That(originalId, Is.EqualTo(10)); - Assert.That(MultiplexTransport.OriginalConnectionId(3, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.OriginalConnectionId(4, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.OriginalConnectionId(5, transportAmount), Is.EqualTo(1)); + transport.OriginalId(t1_c50, out originalId, out transportIndex); + Assert.That(transportIndex, Is.EqualTo(1)); + Assert.That(originalId, Is.EqualTo(50)); - Assert.That(MultiplexTransport.OriginalConnectionId(6, transportAmount), Is.EqualTo(2)); - Assert.That(MultiplexTransport.OriginalConnectionId(7, transportAmount), Is.EqualTo(2)); - Assert.That(MultiplexTransport.OriginalConnectionId(8, transportAmount), Is.EqualTo(2)); + transport.OriginalId(t1_cmax, out originalId, out transportIndex); + Assert.That(transportIndex, Is.EqualTo(1)); + Assert.That(originalId, Is.EqualTo(int.MaxValue)); - Assert.That(MultiplexTransport.OriginalConnectionId(9, transportAmount), Is.EqualTo(3)); - } + // (OriginalId, TransportIndex) -> MultiplexId for transport #1 + Assert.That(transport.MultiplexId(10, 0), Is.EqualTo(t0_c10)); + Assert.That(transport.MultiplexId(20, 0), Is.EqualTo(t0_c20)); + Assert.That(transport.MultiplexId(int.MaxValue, 0), Is.EqualTo(t0_cmax)); - [Test] - public void OriginalTransportId() - { - const int transportAmount = 3; - - Assert.That(MultiplexTransport.OriginalTransportId(0, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.OriginalTransportId(1, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.OriginalTransportId(2, transportAmount), Is.EqualTo(2)); - - Assert.That(MultiplexTransport.OriginalTransportId(3, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.OriginalTransportId(4, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.OriginalTransportId(5, transportAmount), Is.EqualTo(2)); - - Assert.That(MultiplexTransport.OriginalTransportId(6, transportAmount), Is.EqualTo(0)); - Assert.That(MultiplexTransport.OriginalTransportId(7, transportAmount), Is.EqualTo(1)); - Assert.That(MultiplexTransport.OriginalTransportId(8, transportAmount), Is.EqualTo(2)); - - Assert.That(MultiplexTransport.OriginalTransportId(9, transportAmount), Is.EqualTo(0)); - } - - // test to reproduce https://github.com/vis2k/Mirror/issues/3280 - [Test] - public void LargeConnectionId() - { - const int transportAmount = 3; - - // let's say transport #2 gives us a very large connectionId. - // for example, KCP may use GetHashCode() as connectionId. - // 2147483647 - 10 = 2147483637 - const int largeId = int.MaxValue - 10; - const int transportId = 2; - - // connectionId * transportAmount + transportId - // = 2147483637 * 3 + 2 - // = 6442450913 - // which does not fit into int.max - int multiplexedId = MultiplexTransport.MultiplexConnectionId(largeId, transportId, transportAmount); - // Assert.That(multiplexedId, Is.EqualTo(6442450913)); not equal! - - // convert it back. multiplexed isn't correct, so neither will this be - int originalId = MultiplexTransport.OriginalConnectionId(multiplexedId, transportAmount); - Assert.That(originalId, Is.EqualTo(largeId)); + // (OriginalId, TransportIndex) -> MultiplexId for transport #2 + Assert.That(transport.MultiplexId(10, 1), Is.EqualTo(t1_c10)); + Assert.That(transport.MultiplexId(50, 1), Is.EqualTo(t1_c50)); + Assert.That(transport.MultiplexId(int.MaxValue, 1), Is.EqualTo(t1_cmax)); } [Test] @@ -281,30 +244,31 @@ void SendMessage(int connectionId) public void TestServerSend() { transport1.Available().Returns(true); + transport2.Available().Returns(true); + transport.ServerStart(); transport.ClientConnect("some.server.com"); + transport.OnServerConnected = _ => {}; + transport.OnServerDisconnected = _ => {}; + + // connect two connectionIds. + // one of them very large to prevent + // https://github.com/vis2k/Mirror/issues/3280 + transport1.OnServerConnected(10); + transport2.OnServerConnected(int.MaxValue); + byte[] data = { 1, 2, 3 }; ArraySegment segment = new ArraySegment(data); // call ServerSend on multiplex transport. - // multiplexed connId = 0 corresponds to connId = 0 with transport #1 - transport.ServerSend(0, data, 0); - transport1.Received().ServerSend(0, segment, 0); - - // call ServerSend on multiplex transport. - // multiplexed connId = 1 corresponds to connId = 0 with transport #2 + // multiplexed connId = 1 represents transport #1 connId = 10 transport.ServerSend(1, data, 0); - transport2.Received().ServerSend(0, segment, 0); + transport1.Received().ServerSend(10, segment, 0); // call ServerSend on multiplex transport. - // multiplexed connId = 2 corresponds to connId = 1 with transport #1 + // multiplexed connId = 2 represents transport #2 connId = int.max transport.ServerSend(2, data, 0); - transport1.Received().ServerSend(1, segment, 0); - - // call ServerSend on multiplex transport. - // multiplexed connId = 3 corresponds to connId = 1 with transport #2 - transport.ServerSend(3, data, 0); - transport2.Received().ServerSend(1, segment, 0); + transport2.Received().ServerSend(int.MaxValue, segment, 0); } } } diff --git a/Assets/Mirror/Transports/Multiplex/MultiplexTransport.cs b/Assets/Mirror/Transports/Multiplex/MultiplexTransport.cs index 52b65dcc3..4f31df490 100644 --- a/Assets/Mirror/Transports/Multiplex/MultiplexTransport.cs +++ b/Assets/Mirror/Transports/Multiplex/MultiplexTransport.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Text; using UnityEngine; @@ -12,23 +13,72 @@ public class MultiplexTransport : Transport Transport available; - // connection ids get mapped to base transports + // underlying transport connectionId to multiplexed connectionId lookup. + // + // originally we used a formula to map the connectionId: + // connectionId * transportAmount + transportId + // // if we have 3 transports, then - // transport 0 will produce connection ids [0, 3, 6, 9, ...] - // transport 1 will produce connection ids [1, 4, 7, 10, ...] - // transport 2 will produce connection ids [2, 5, 8, 11, ...] + // transport 0 will produce connection ids [0, 3, 6, 9, ...] + // transport 1 will produce connection ids [1, 4, 7, 10, ...] + // transport 2 will produce connection ids [2, 5, 8, 11, ...] + // + // however, some transports like kcp may give very large connectionIds. + // if they are near int.max, then "* transprotAmount + transportIndex" + // will overflow, resulting in connIds which can't be projected back. + // https://github.com/vis2k/Mirror/issues/3280 + // + // instead, use a simple lookup with 0-indexed ids. + // with initial capacity to avoid runtime allocations. - // convert original transport connId to multiplexed connId - public static int MultiplexConnectionId(int connectionId, int transportId, int transportAmount) => - connectionId * transportAmount + transportId; + // (original connectionId, transport#) to multiplexed connectionId + readonly Dictionary, int> originalToMultiplexedId = + new Dictionary, int>(100); - // convert multiplexed connectionId back to original transport connId - public static int OriginalConnectionId(int multiplexConnectionId, int transportAmount) => - multiplexConnectionId / transportAmount; + // multiplexed connectionId to (original connectionId, transport#) + readonly Dictionary> multiplexedToOriginalId = + new Dictionary>(100); - // convert multiplexed connectionId back to original transportId - public static int OriginalTransportId(int multiplexConnectionId, int transportAmount) => - multiplexConnectionId % transportAmount; + // next multiplexed id counter. start at 1 because 0 is reserved for host. + int nextMultiplexedId = 1; + + // add to bidirection lookup. returns the multiplexed connectionId. + public int AddToLookup(int originalConnectionId, int transportIndex) + { + // add to both + KeyValuePair pair = new KeyValuePair(originalConnectionId, transportIndex); + int multiplexedId = nextMultiplexedId++; + + originalToMultiplexedId[pair] = multiplexedId; + multiplexedToOriginalId[multiplexedId] = pair; + + return multiplexedId; + } + + public void RemoveFromLookup(int originalConnectionId, int transportIndex) + { + // remove from both + KeyValuePair pair = new KeyValuePair(originalConnectionId, transportIndex); + int multiplexedId = originalToMultiplexedId[pair]; + + originalToMultiplexedId.Remove(pair); + multiplexedToOriginalId.Remove(multiplexedId); + } + + public void OriginalId(int multiplexId, out int originalConnectionId, out int transportIndex) + { + KeyValuePair pair = multiplexedToOriginalId[multiplexId]; + originalConnectionId = pair.Key; + transportIndex = pair.Value; + } + + public int MultiplexId(int originalConnectionId, int transportIndex) + { + KeyValuePair pair = new KeyValuePair(originalConnectionId, transportIndex); + return originalToMultiplexedId[pair]; + } + + //////////////////////////////////////////////////////////////////////// public void Awake() { @@ -155,30 +205,36 @@ void AddServerCallbacks() { // this is required for the handlers, if I use i directly // then all the handlers will use the last i - int locali = i; + int transportIndex = i; Transport transport = transports[i]; - transport.OnServerConnected = (baseConnectionId => + transport.OnServerConnected = (originalConnectionId => { // invoke Multiplex event with multiplexed connectionId - OnServerConnected.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length)); + int multiplexedId = AddToLookup(originalConnectionId, transportIndex); + OnServerConnected.Invoke(multiplexedId); }); - transport.OnServerDataReceived = (baseConnectionId, data, channel) => + transport.OnServerDataReceived = (originalConnectionId, data, channel) => { // invoke Multiplex event with multiplexed connectionId - OnServerDataReceived.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length), data, channel); + int multiplexedId = MultiplexId(originalConnectionId, transportIndex); + OnServerDataReceived.Invoke(multiplexedId, data, channel); }; - transport.OnServerError = (baseConnectionId, error, reason) => + transport.OnServerError = (originalConnectionId, error, reason) => { // invoke Multiplex event with multiplexed connectionId - OnServerError.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length), error, reason); + int multiplexedId = MultiplexId(originalConnectionId, transportIndex); + OnServerError.Invoke(multiplexedId, error, reason); }; - transport.OnServerDisconnected = baseConnectionId => + + transport.OnServerDisconnected = originalConnectionId => { // invoke Multiplex event with multiplexed connectionId - OnServerDisconnected.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length)); + int multiplexedId = MultiplexId(originalConnectionId, transportIndex); + OnServerDisconnected.Invoke(multiplexedId); + RemoveFromLookup(originalConnectionId, transportIndex); }; } } @@ -200,26 +256,23 @@ public override bool ServerActive() public override string ServerGetClientAddress(int connectionId) { - // convert multiplexed connectionId to original transport + connId - int baseConnectionId = OriginalConnectionId(connectionId, transports.Length); - int transportId = OriginalTransportId(connectionId, transports.Length); - return transports[transportId].ServerGetClientAddress(baseConnectionId); + // convert multiplexed connectionId to original id & transport index + OriginalId(connectionId, out int originalConnectionId, out int transportIndex); + return transports[transportIndex].ServerGetClientAddress(originalConnectionId); } public override void ServerDisconnect(int connectionId) { - // convert multiplexed connectionId to original transport + connId - int baseConnectionId = OriginalConnectionId(connectionId, transports.Length); - int transportId = OriginalTransportId(connectionId, transports.Length); - transports[transportId].ServerDisconnect(baseConnectionId); + // convert multiplexed connectionId to original id & transport index + OriginalId(connectionId, out int originalConnectionId, out int transportIndex); + transports[transportIndex].ServerDisconnect(originalConnectionId); } public override void ServerSend(int connectionId, ArraySegment segment, int channelId) { // convert multiplexed connectionId to original transport + connId - int baseConnectionId = OriginalConnectionId(connectionId, transports.Length); - int transportId = OriginalTransportId(connectionId, transports.Length); - transports[transportId].ServerSend(baseConnectionId, segment, channelId); + OriginalId(connectionId, out int originalConnectionId, out int transportIndex); + transports[transportIndex].ServerSend(originalConnectionId, segment, channelId); } public override void ServerStart()