feat(Transport): Added OnServerConnectedWithAddress (#3855)

* feat(Transport): Added OnServerConnectedWithAddress
Transports can now pass the remote client address directly to NetworkServer Action
- Facliltates ThreadedTransport passing the client Address
- KCP Transport Updated accordingly
- Original OnServerConnected passes ServerGetClientAddress result to NetworkConnectionToClient consrtuctor
- Saves round trips back to the transport for client address whenever it's needed

* formatting

* Simplified

* Cleanup
This commit is contained in:
MrGadget 2024-07-05 05:58:43 -04:00 committed by GitHub
parent b0aa8a412a
commit 2d2e270868
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 52 additions and 33 deletions

View File

@ -14,8 +14,6 @@ public class LocalConnectionToClient : NetworkConnectionToClient
public LocalConnectionToClient() : base(LocalConnectionId) {} public LocalConnectionToClient() : base(LocalConnectionId) {}
public override string address => "localhost";
internal override void Send(ArraySegment<byte> segment, int channelId = Channels.Reliable) internal override void Send(ArraySegment<byte> segment, int channelId = Channels.Reliable)
{ {
// instead of invoking it directly, we enqueue and process next update. // instead of invoking it directly, we enqueue and process next update.

View File

@ -14,7 +14,7 @@ public class NetworkConnectionToClient : NetworkConnection
readonly NetworkWriter reliableRpcs = new NetworkWriter(); readonly NetworkWriter reliableRpcs = new NetworkWriter();
readonly NetworkWriter unreliableRpcs = new NetworkWriter(); readonly NetworkWriter unreliableRpcs = new NetworkWriter();
public virtual string address => Transport.active.ServerGetClientAddress(connectionId); public virtual string address { get; private set; }
/// <summary>NetworkIdentities that this connection can see</summary> /// <summary>NetworkIdentities that this connection can see</summary>
// TODO move to server's NetworkConnectionToClient? // TODO move to server's NetworkConnectionToClient?
@ -50,9 +50,11 @@ public class NetworkConnectionToClient : NetworkConnection
/// <summary>Round trip time (in seconds) that it takes a message to go server->client->server.</summary> /// <summary>Round trip time (in seconds) that it takes a message to go server->client->server.</summary>
public double rtt => _rtt.Value; public double rtt => _rtt.Value;
public NetworkConnectionToClient(int networkConnectionId) public NetworkConnectionToClient(int networkConnectionId, string clientAddress = "localhost")
: base(networkConnectionId) : base(networkConnectionId)
{ {
address = clientAddress;
// initialize EMA with 'emaDuration' seconds worth of history. // initialize EMA with 'emaDuration' seconds worth of history.
// 1 second holds 'sendRate' worth of values. // 1 second holds 'sendRate' worth of values.
// multiplied by emaDuration gives n-seconds. // multiplied by emaDuration gives n-seconds.

View File

@ -192,6 +192,7 @@ static void AddTransportHandlers()
{ {
// += so that other systems can also hook into it (i.e. statistics) // += so that other systems can also hook into it (i.e. statistics)
Transport.active.OnServerConnected += OnTransportConnected; Transport.active.OnServerConnected += OnTransportConnected;
Transport.active.OnServerConnectedWithAddress += OnTransportConnectedWithAddress;
Transport.active.OnServerDataReceived += OnTransportData; Transport.active.OnServerDataReceived += OnTransportData;
Transport.active.OnServerDisconnected += OnTransportDisconnected; Transport.active.OnServerDisconnected += OnTransportDisconnected;
Transport.active.OnServerError += OnTransportError; Transport.active.OnServerError += OnTransportError;
@ -636,25 +637,39 @@ public static void SendToReadyObservers<T>(NetworkIdentity identity, T message,
// transport events //////////////////////////////////////////////////// // transport events ////////////////////////////////////////////////////
// called by transport // called by transport
static void OnTransportConnected(int connectionId) static void OnTransportConnected(int connectionId)
{ => OnTransportConnectedWithAddress(connectionId, Transport.active.ServerGetClientAddress(connectionId));
// Debug.Log($"Server accepted client:{connectionId}");
static void OnTransportConnectedWithAddress(int connectionId, string clientAddress)
{
if (IsConnectionAllowed(connectionId))
{
// create a connection
NetworkConnectionToClient conn = new NetworkConnectionToClient(connectionId, clientAddress);
OnConnected(conn);
}
else
{
// kick the client immediately
Transport.active.ServerDisconnect(connectionId);
}
}
static bool IsConnectionAllowed(int connectionId)
{
// connectionId needs to be != 0 because 0 is reserved for local player // connectionId needs to be != 0 because 0 is reserved for local player
// note that some transports like kcp generate connectionId by // note that some transports like kcp generate connectionId by
// hashing which can be < 0 as well, so we need to allow < 0! // hashing which can be < 0 as well, so we need to allow < 0!
if (connectionId == 0) if (connectionId == 0)
{ {
Debug.LogError($"Server.HandleConnect: invalid connectionId: {connectionId} . Needs to be != 0, because 0 is reserved for local player."); Debug.LogError($"Server.HandleConnect: invalid connectionId: {connectionId} . Needs to be != 0, because 0 is reserved for local player.");
Transport.active.ServerDisconnect(connectionId); return false;
return;
} }
// connectionId not in use yet? // connectionId not in use yet?
if (connections.ContainsKey(connectionId)) if (connections.ContainsKey(connectionId))
{ {
Transport.active.ServerDisconnect(connectionId); Debug.LogError($"Server connectionId {connectionId} already in use...client will be kicked");
// Debug.Log($"Server connectionId {connectionId} already in use...kicked client"); return false;
return;
} }
// are more connections allowed? if not, kick // are more connections allowed? if not, kick
@ -662,18 +677,13 @@ static void OnTransportConnected(int connectionId)
// less code and third party transport might not do that anyway) // less code and third party transport might not do that anyway)
// (this way we could also send a custom 'tooFull' message later, // (this way we could also send a custom 'tooFull' message later,
// Transport can't do that) // Transport can't do that)
if (connections.Count < maxConnections) if (connections.Count >= maxConnections)
{ {
// add connection Debug.LogError($"Server full, client {connectionId} will be kicked");
NetworkConnectionToClient conn = new NetworkConnectionToClient(connectionId); return false;
OnConnected(conn);
}
else
{
// kick
Transport.active.ServerDisconnect(connectionId);
// Debug.Log($"Server full, kicked client {connectionId}");
} }
return true;
} }
internal static void OnConnected(NetworkConnectionToClient conn) internal static void OnConnected(NetworkConnectionToClient conn)

View File

@ -68,6 +68,7 @@ public abstract class Transport : MonoBehaviour
// server ////////////////////////////////////////////////////////////// // server //////////////////////////////////////////////////////////////
/// <summary>Called by Transport when a new client connected to the server.</summary> /// <summary>Called by Transport when a new client connected to the server.</summary>
public Action<int> OnServerConnected; public Action<int> OnServerConnected;
public Action<int, string> OnServerConnectedWithAddress;
/// <summary>Called by Transport when the server received a message from a client.</summary> /// <summary>Called by Transport when the server received a message from a client.</summary>
public Action<int, ArraySegment<byte>, int> OnServerDataReceived; public Action<int, ArraySegment<byte>, int> OnServerDataReceived;

View File

@ -71,7 +71,9 @@ public void MaxConnections()
Assert.That(NetworkServer.connections.Count, Is.EqualTo(1)); Assert.That(NetworkServer.connections.Count, Is.EqualTo(1));
// connect second: should fail // connect second: should fail
LogAssert.ignoreFailingMessages = true;
transport.OnServerConnected.Invoke(43); transport.OnServerConnected.Invoke(43);
LogAssert.ignoreFailingMessages = false;
Assert.That(NetworkServer.connections.Count, Is.EqualTo(1)); Assert.That(NetworkServer.connections.Count, Is.EqualTo(1));
} }
@ -161,7 +163,9 @@ public void ConnectDuplicateConnectionIds()
NetworkConnectionToClient original = NetworkServer.connections[42]; NetworkConnectionToClient original = NetworkServer.connections[42];
// connect duplicate - shouldn't overwrite first one // connect duplicate - shouldn't overwrite first one
LogAssert.ignoreFailingMessages = true;
transport.OnServerConnected.Invoke(42); transport.OnServerConnected.Invoke(42);
LogAssert.ignoreFailingMessages = false;
Assert.That(NetworkServer.connections.Count, Is.EqualTo(1)); Assert.That(NetworkServer.connections.Count, Is.EqualTo(1));
Assert.That(NetworkServer.connections[42], Is.EqualTo(original)); Assert.That(NetworkServer.connections[42], Is.EqualTo(original));
} }

View File

@ -28,7 +28,7 @@ public class EdgegapKcpServer : KcpServer
bool relayActive; bool relayActive;
public EdgegapKcpServer( public EdgegapKcpServer(
Action<int> OnConnected, Action<int, IPEndPoint> OnConnected,
Action<int, ArraySegment<byte>, KcpChannel> OnData, Action<int, ArraySegment<byte>, KcpChannel> OnData,
Action<int> OnDisconnected, Action<int> OnDisconnected,
Action<int, ErrorCode, string> OnError, Action<int, ErrorCode, string> OnError,

View File

@ -60,7 +60,7 @@ protected override void Awake()
// server // server
server = new EdgegapKcpServer( server = new EdgegapKcpServer(
(connectionId) => OnServerConnected.Invoke(connectionId), (connectionId, endPoint) => OnServerConnectedWithAddress.Invoke(connectionId, endPoint.PrettyAddress()),
(connectionId, message, channel) => OnServerDataReceived.Invoke(connectionId, message, FromKcpChannel(channel)), (connectionId, message, channel) => OnServerDataReceived.Invoke(connectionId, message, FromKcpChannel(channel)),
(connectionId) => OnServerDisconnected.Invoke(connectionId), (connectionId) => OnServerDisconnected.Invoke(connectionId),
(connectionId, error, reason) => OnServerError.Invoke(connectionId, ToTransportError(error), reason), (connectionId, error, reason) => OnServerError.Invoke(connectionId, ToTransportError(error), reason),

View File

@ -122,7 +122,7 @@ protected virtual void Awake()
// server // server
server = new KcpServer( server = new KcpServer(
(connectionId) => OnServerConnected.Invoke(connectionId), (connectionId, endPoint) => OnServerConnectedWithAddress.Invoke(connectionId, endPoint.PrettyAddress()),
(connectionId, message, channel) => OnServerDataReceived.Invoke(connectionId, message, FromKcpChannel(channel)), (connectionId, message, channel) => OnServerDataReceived.Invoke(connectionId, message, FromKcpChannel(channel)),
(connectionId) => OnServerDisconnected.Invoke(connectionId), (connectionId) => OnServerDisconnected.Invoke(connectionId),
(connectionId, error, reason) => OnServerError.Invoke(connectionId, ToTransportError(error), reason), (connectionId, error, reason) => OnServerError.Invoke(connectionId, ToTransportError(error), reason),

View File

@ -18,7 +18,7 @@ public class KcpServer
// events are readonly, set in constructor. // events are readonly, set in constructor.
// this ensures they are always initialized when used. // this ensures they are always initialized when used.
// fixes https://github.com/MirrorNetworking/Mirror/issues/3337 and more // fixes https://github.com/MirrorNetworking/Mirror/issues/3337 and more
protected readonly Action<int> OnConnected; protected readonly Action<int, IPEndPoint> OnConnected; // connectionId, address
protected readonly Action<int, ArraySegment<byte>, KcpChannel> OnData; protected readonly Action<int, ArraySegment<byte>, KcpChannel> OnData;
protected readonly Action<int> OnDisconnected; protected readonly Action<int> OnDisconnected;
protected readonly Action<int, ErrorCode, string> OnError; protected readonly Action<int, ErrorCode, string> OnError;
@ -43,7 +43,7 @@ public class KcpServer
public Dictionary<int, KcpServerConnection> connections = public Dictionary<int, KcpServerConnection> connections =
new Dictionary<int, KcpServerConnection>(); new Dictionary<int, KcpServerConnection>();
public KcpServer(Action<int> OnConnected, public KcpServer(Action<int, IPEndPoint> OnConnected,
Action<int, ArraySegment<byte>, KcpChannel> OnData, Action<int, ArraySegment<byte>, KcpChannel> OnData,
Action<int> OnDisconnected, Action<int> OnDisconnected,
Action<int, ErrorCode, string> OnError, Action<int, ErrorCode, string> OnError,
@ -285,7 +285,8 @@ void OnConnectedCallback(KcpServerConnection conn)
// finally, call mirror OnConnected event // finally, call mirror OnConnected event
Log.Info($"[KCP] Server: OnConnected({connectionId})"); Log.Info($"[KCP] Server: OnConnected({connectionId})");
OnConnected(connectionId); IPEndPoint endPoint = conn.remoteEndPoint as IPEndPoint;
OnConnected(connectionId, endPoint);
} }
void OnDisconnectedCallback() void OnDisconnectedCallback()

View File

@ -6,6 +6,7 @@
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Net;
using System.Threading; using System.Threading;
using UnityEngine; using UnityEngine;
@ -317,9 +318,11 @@ protected void OnThreadedClientDisconnected()
EnqueueClientMain(ClientMainEventType.OnClientDisconnected, null, null, null); EnqueueClientMain(ClientMainEventType.OnClientDisconnected, null, null, null);
} }
protected void OnThreadedServerConnected(int connectionId) protected void OnThreadedServerConnected(int connectionId, IPEndPoint endPoint)
{ {
EnqueueServerMain(ServerMainEventType.OnServerConnected, null, connectionId, null, null); // create string copy of address immediately before sending to another thread
string address = endPoint.PrettyAddress();
EnqueueServerMain(ServerMainEventType.OnServerConnected, address, connectionId, null, null);
} }
protected void OnThreadedServerSend(int connectionId, ArraySegment<byte> message, int channelId) protected void OnThreadedServerSend(int connectionId, ArraySegment<byte> message, int channelId)
@ -515,9 +518,9 @@ public override void ServerEarlyUpdate()
// SERVER EVENTS /////////////////////////////////////////// // SERVER EVENTS ///////////////////////////////////////////
case ServerMainEventType.OnServerConnected: case ServerMainEventType.OnServerConnected:
{ {
// call original transport event // call original transport event with connectionId, address
// TODO pass client address in OnConnect here later string address = (string)elem.param;
OnServerConnected?.Invoke(elem.connectionId.Value);//, (string)elem.param); OnServerConnectedWithAddress?.Invoke(elem.connectionId.Value, address);
break; break;
} }
case ServerMainEventType.OnServerSent: case ServerMainEventType.OnServerSent:
@ -612,7 +615,7 @@ public override void ServerDisconnect(int connectionId)
// querying this at runtime won't work for threaded transports. // querying this at runtime won't work for threaded transports.
public override string ServerGetClientAddress(int connectionId) public override string ServerGetClientAddress(int connectionId)
{ {
throw new NotImplementedException(); throw new NotImplementedException("ThreadedTransport passes each connection's address in OnServerConnectedThreaded. Don't use ServerGetClientAddress.");
} }
public override void ServerStop() public override void ServerStop()