fix: #3280 #3083 #3217 MultiplexTransport connectionId multiplexing out of int.max range (#3291)

* lookup wip

* so far

* transport lookup

* wip

* okokok

* all done

* better

* x

* tests
This commit is contained in:
mischa 2022-12-04 13:57:51 +01:00 committed by GitHub
parent 5991a9a53a
commit 6d60471868
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 123 deletions

View File

@ -28,92 +28,55 @@ public void Setup()
public override void TearDown() => base.TearDown();
[Test]
public void MultiplexConnectionId()
public void ConnectionIdMapping()
{
// if we have 3 transports, then
// transport 0 will produce connection ids [0, 3, 6, 9, ...]
const int transportAmount = 3;
// add a few connectionIds from transport #0
// one large connId to prevent https://github.com/vis2k/Mirror/issues/3280
int t0_c10 = transport.AddToLookup(10, 0); // should get multiplexId = 1
int t0_c20 = transport.AddToLookup(20, 0); // should get multiplexId = 2
int t0_cmax = transport.AddToLookup(int.MaxValue, 0); // should get multiplexId = 3
Assert.That(MultiplexTransport.MultiplexConnectionId(0, 0, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.MultiplexConnectionId(1, 0, transportAmount), Is.EqualTo(3));
Assert.That(MultiplexTransport.MultiplexConnectionId(2, 0, transportAmount), Is.EqualTo(6));
Assert.That(MultiplexTransport.MultiplexConnectionId(3, 0, transportAmount), Is.EqualTo(9));
// add a few connectionIds from transport #1
// one large connId to prevent https://github.com/vis2k/Mirror/issues/3280
int t1_c10 = transport.AddToLookup(10, 1); // should get multiplexId = 4
int t1_c50 = transport.AddToLookup(50, 1); // should get multiplexId = 5
int t1_cmax = transport.AddToLookup(int.MaxValue, 1); // should get multiplexId = 6
// transport 1 will produce connection ids [1, 4, 7, 10, ...]
Assert.That(MultiplexTransport.MultiplexConnectionId(0, 1, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.MultiplexConnectionId(1, 1, transportAmount), Is.EqualTo(4));
Assert.That(MultiplexTransport.MultiplexConnectionId(2, 1, transportAmount), Is.EqualTo(7));
Assert.That(MultiplexTransport.MultiplexConnectionId(3, 1, transportAmount), Is.EqualTo(10));
// MultiplexId -> (OriginalId, TransportIndex) for transport #0
transport.OriginalId(t0_c10, out int originalId, out int transportIndex);
Assert.That(transportIndex, Is.EqualTo(0));
Assert.That(originalId, Is.EqualTo(10));
// transport 2 will produce connection ids [2, 5, 8, 11, ...]
Assert.That(MultiplexTransport.MultiplexConnectionId(0, 2, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.MultiplexConnectionId(1, 2, transportAmount), Is.EqualTo(5));
Assert.That(MultiplexTransport.MultiplexConnectionId(2, 2, transportAmount), Is.EqualTo(8));
Assert.That(MultiplexTransport.MultiplexConnectionId(3, 2, transportAmount), Is.EqualTo(11));
}
transport.OriginalId(t0_c20, out originalId, out transportIndex);
Assert.That(transportIndex, Is.EqualTo(0));
Assert.That(originalId, Is.EqualTo(20));
[Test]
public void OriginalConnectionId()
{
const int transportAmount = 3;
transport.OriginalId(t0_cmax, out originalId, out transportIndex);
Assert.That(transportIndex, Is.EqualTo(0));
Assert.That(originalId, Is.EqualTo(int.MaxValue));
Assert.That(MultiplexTransport.OriginalConnectionId(0, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.OriginalConnectionId(1, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.OriginalConnectionId(2, transportAmount), Is.EqualTo(0));
// MultiplexId -> (OriginalId, TransportIndex) for transport #1
transport.OriginalId(t1_c10, out originalId, out transportIndex);
Assert.That(transportIndex, Is.EqualTo(1));
Assert.That(originalId, Is.EqualTo(10));
Assert.That(MultiplexTransport.OriginalConnectionId(3, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.OriginalConnectionId(4, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.OriginalConnectionId(5, transportAmount), Is.EqualTo(1));
transport.OriginalId(t1_c50, out originalId, out transportIndex);
Assert.That(transportIndex, Is.EqualTo(1));
Assert.That(originalId, Is.EqualTo(50));
Assert.That(MultiplexTransport.OriginalConnectionId(6, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.OriginalConnectionId(7, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.OriginalConnectionId(8, transportAmount), Is.EqualTo(2));
transport.OriginalId(t1_cmax, out originalId, out transportIndex);
Assert.That(transportIndex, Is.EqualTo(1));
Assert.That(originalId, Is.EqualTo(int.MaxValue));
Assert.That(MultiplexTransport.OriginalConnectionId(9, transportAmount), Is.EqualTo(3));
}
// (OriginalId, TransportIndex) -> MultiplexId for transport #1
Assert.That(transport.MultiplexId(10, 0), Is.EqualTo(t0_c10));
Assert.That(transport.MultiplexId(20, 0), Is.EqualTo(t0_c20));
Assert.That(transport.MultiplexId(int.MaxValue, 0), Is.EqualTo(t0_cmax));
[Test]
public void OriginalTransportId()
{
const int transportAmount = 3;
Assert.That(MultiplexTransport.OriginalTransportId(0, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.OriginalTransportId(1, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.OriginalTransportId(2, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.OriginalTransportId(3, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.OriginalTransportId(4, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.OriginalTransportId(5, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.OriginalTransportId(6, transportAmount), Is.EqualTo(0));
Assert.That(MultiplexTransport.OriginalTransportId(7, transportAmount), Is.EqualTo(1));
Assert.That(MultiplexTransport.OriginalTransportId(8, transportAmount), Is.EqualTo(2));
Assert.That(MultiplexTransport.OriginalTransportId(9, transportAmount), Is.EqualTo(0));
}
// test to reproduce https://github.com/vis2k/Mirror/issues/3280
[Test]
public void LargeConnectionId()
{
const int transportAmount = 3;
// let's say transport #2 gives us a very large connectionId.
// for example, KCP may use GetHashCode() as connectionId.
// 2147483647 - 10 = 2147483637
const int largeId = int.MaxValue - 10;
const int transportId = 2;
// connectionId * transportAmount + transportId
// = 2147483637 * 3 + 2
// = 6442450913
// which does not fit into int.max
int multiplexedId = MultiplexTransport.MultiplexConnectionId(largeId, transportId, transportAmount);
// Assert.That(multiplexedId, Is.EqualTo(6442450913)); not equal!
// convert it back. multiplexed isn't correct, so neither will this be
int originalId = MultiplexTransport.OriginalConnectionId(multiplexedId, transportAmount);
Assert.That(originalId, Is.EqualTo(largeId));
// (OriginalId, TransportIndex) -> MultiplexId for transport #2
Assert.That(transport.MultiplexId(10, 1), Is.EqualTo(t1_c10));
Assert.That(transport.MultiplexId(50, 1), Is.EqualTo(t1_c50));
Assert.That(transport.MultiplexId(int.MaxValue, 1), Is.EqualTo(t1_cmax));
}
[Test]
@ -281,30 +244,31 @@ void SendMessage(int connectionId)
public void TestServerSend()
{
transport1.Available().Returns(true);
transport2.Available().Returns(true);
transport.ServerStart();
transport.ClientConnect("some.server.com");
transport.OnServerConnected = _ => {};
transport.OnServerDisconnected = _ => {};
// connect two connectionIds.
// one of them very large to prevent
// https://github.com/vis2k/Mirror/issues/3280
transport1.OnServerConnected(10);
transport2.OnServerConnected(int.MaxValue);
byte[] data = { 1, 2, 3 };
ArraySegment<byte> segment = new ArraySegment<byte>(data);
// call ServerSend on multiplex transport.
// multiplexed connId = 0 corresponds to connId = 0 with transport #1
transport.ServerSend(0, data, 0);
transport1.Received().ServerSend(0, segment, 0);
// call ServerSend on multiplex transport.
// multiplexed connId = 1 corresponds to connId = 0 with transport #2
// multiplexed connId = 1 represents transport #1 connId = 10
transport.ServerSend(1, data, 0);
transport2.Received().ServerSend(0, segment, 0);
transport1.Received().ServerSend(10, segment, 0);
// call ServerSend on multiplex transport.
// multiplexed connId = 2 corresponds to connId = 1 with transport #1
// multiplexed connId = 2 represents transport #2 connId = int.max
transport.ServerSend(2, data, 0);
transport1.Received().ServerSend(1, segment, 0);
// call ServerSend on multiplex transport.
// multiplexed connId = 3 corresponds to connId = 1 with transport #2
transport.ServerSend(3, data, 0);
transport2.Received().ServerSend(1, segment, 0);
transport2.Received().ServerSend(int.MaxValue, segment, 0);
}
}
}

View File

@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Text;
using UnityEngine;
@ -12,23 +13,72 @@ public class MultiplexTransport : Transport
Transport available;
// connection ids get mapped to base transports
// underlying transport connectionId to multiplexed connectionId lookup.
//
// originally we used a formula to map the connectionId:
// connectionId * transportAmount + transportId
//
// if we have 3 transports, then
// transport 0 will produce connection ids [0, 3, 6, 9, ...]
// transport 1 will produce connection ids [1, 4, 7, 10, ...]
// transport 2 will produce connection ids [2, 5, 8, 11, ...]
// transport 0 will produce connection ids [0, 3, 6, 9, ...]
// transport 1 will produce connection ids [1, 4, 7, 10, ...]
// transport 2 will produce connection ids [2, 5, 8, 11, ...]
//
// however, some transports like kcp may give very large connectionIds.
// if they are near int.max, then "* transprotAmount + transportIndex"
// will overflow, resulting in connIds which can't be projected back.
// https://github.com/vis2k/Mirror/issues/3280
//
// instead, use a simple lookup with 0-indexed ids.
// with initial capacity to avoid runtime allocations.
// convert original transport connId to multiplexed connId
public static int MultiplexConnectionId(int connectionId, int transportId, int transportAmount) =>
connectionId * transportAmount + transportId;
// (original connectionId, transport#) to multiplexed connectionId
readonly Dictionary<KeyValuePair<int, int>, int> originalToMultiplexedId =
new Dictionary<KeyValuePair<int, int>, int>(100);
// convert multiplexed connectionId back to original transport connId
public static int OriginalConnectionId(int multiplexConnectionId, int transportAmount) =>
multiplexConnectionId / transportAmount;
// multiplexed connectionId to (original connectionId, transport#)
readonly Dictionary<int, KeyValuePair<int, int>> multiplexedToOriginalId =
new Dictionary<int, KeyValuePair<int, int>>(100);
// convert multiplexed connectionId back to original transportId
public static int OriginalTransportId(int multiplexConnectionId, int transportAmount) =>
multiplexConnectionId % transportAmount;
// next multiplexed id counter. start at 1 because 0 is reserved for host.
int nextMultiplexedId = 1;
// add to bidirection lookup. returns the multiplexed connectionId.
public int AddToLookup(int originalConnectionId, int transportIndex)
{
// add to both
KeyValuePair<int, int> pair = new KeyValuePair<int, int>(originalConnectionId, transportIndex);
int multiplexedId = nextMultiplexedId++;
originalToMultiplexedId[pair] = multiplexedId;
multiplexedToOriginalId[multiplexedId] = pair;
return multiplexedId;
}
public void RemoveFromLookup(int originalConnectionId, int transportIndex)
{
// remove from both
KeyValuePair<int, int> pair = new KeyValuePair<int, int>(originalConnectionId, transportIndex);
int multiplexedId = originalToMultiplexedId[pair];
originalToMultiplexedId.Remove(pair);
multiplexedToOriginalId.Remove(multiplexedId);
}
public void OriginalId(int multiplexId, out int originalConnectionId, out int transportIndex)
{
KeyValuePair<int, int> pair = multiplexedToOriginalId[multiplexId];
originalConnectionId = pair.Key;
transportIndex = pair.Value;
}
public int MultiplexId(int originalConnectionId, int transportIndex)
{
KeyValuePair<int, int> pair = new KeyValuePair<int, int>(originalConnectionId, transportIndex);
return originalToMultiplexedId[pair];
}
////////////////////////////////////////////////////////////////////////
public void Awake()
{
@ -155,30 +205,36 @@ void AddServerCallbacks()
{
// this is required for the handlers, if I use i directly
// then all the handlers will use the last i
int locali = i;
int transportIndex = i;
Transport transport = transports[i];
transport.OnServerConnected = (baseConnectionId =>
transport.OnServerConnected = (originalConnectionId =>
{
// invoke Multiplex event with multiplexed connectionId
OnServerConnected.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length));
int multiplexedId = AddToLookup(originalConnectionId, transportIndex);
OnServerConnected.Invoke(multiplexedId);
});
transport.OnServerDataReceived = (baseConnectionId, data, channel) =>
transport.OnServerDataReceived = (originalConnectionId, data, channel) =>
{
// invoke Multiplex event with multiplexed connectionId
OnServerDataReceived.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length), data, channel);
int multiplexedId = MultiplexId(originalConnectionId, transportIndex);
OnServerDataReceived.Invoke(multiplexedId, data, channel);
};
transport.OnServerError = (baseConnectionId, error, reason) =>
transport.OnServerError = (originalConnectionId, error, reason) =>
{
// invoke Multiplex event with multiplexed connectionId
OnServerError.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length), error, reason);
int multiplexedId = MultiplexId(originalConnectionId, transportIndex);
OnServerError.Invoke(multiplexedId, error, reason);
};
transport.OnServerDisconnected = baseConnectionId =>
transport.OnServerDisconnected = originalConnectionId =>
{
// invoke Multiplex event with multiplexed connectionId
OnServerDisconnected.Invoke(MultiplexConnectionId(baseConnectionId, locali, transports.Length));
int multiplexedId = MultiplexId(originalConnectionId, transportIndex);
OnServerDisconnected.Invoke(multiplexedId);
RemoveFromLookup(originalConnectionId, transportIndex);
};
}
}
@ -200,26 +256,23 @@ public override bool ServerActive()
public override string ServerGetClientAddress(int connectionId)
{
// convert multiplexed connectionId to original transport + connId
int baseConnectionId = OriginalConnectionId(connectionId, transports.Length);
int transportId = OriginalTransportId(connectionId, transports.Length);
return transports[transportId].ServerGetClientAddress(baseConnectionId);
// convert multiplexed connectionId to original id & transport index
OriginalId(connectionId, out int originalConnectionId, out int transportIndex);
return transports[transportIndex].ServerGetClientAddress(originalConnectionId);
}
public override void ServerDisconnect(int connectionId)
{
// convert multiplexed connectionId to original transport + connId
int baseConnectionId = OriginalConnectionId(connectionId, transports.Length);
int transportId = OriginalTransportId(connectionId, transports.Length);
transports[transportId].ServerDisconnect(baseConnectionId);
// convert multiplexed connectionId to original id & transport index
OriginalId(connectionId, out int originalConnectionId, out int transportIndex);
transports[transportIndex].ServerDisconnect(originalConnectionId);
}
public override void ServerSend(int connectionId, ArraySegment<byte> segment, int channelId)
{
// convert multiplexed connectionId to original transport + connId
int baseConnectionId = OriginalConnectionId(connectionId, transports.Length);
int transportId = OriginalTransportId(connectionId, transports.Length);
transports[transportId].ServerSend(baseConnectionId, segment, channelId);
OriginalId(connectionId, out int originalConnectionId, out int transportIndex);
transports[transportIndex].ServerSend(originalConnectionId, segment, channelId);
}
public override void ServerStart()