fix: NetworkServer/Client don't use Connect/DisconnectMessage anymore. fixes a bug where external connects could try sending that message causing undefined behaviour (#2577)

This commit is contained in:
vis2k 2021-02-14 00:09:42 +08:00
parent 5aed823cc7
commit be42976a26
11 changed files with 70 additions and 94 deletions

View File

@ -87,6 +87,12 @@ internal class ULocalConnectionToServer : NetworkConnectionToServer
public override string address => "localhost";
// see caller for comments on why we need this
bool connectedEventPending;
bool disconnectedEventPending;
internal void QueueConnectedEvent() => connectedEventPending = true;
internal void QueueDisconnectedEvent() => disconnectedEventPending = true;
internal override void Send(ArraySegment<byte> segment, int channelId = Channels.DefaultReliable)
{
if (segment.Count == 0)
@ -101,6 +107,13 @@ internal override void Send(ArraySegment<byte> segment, int channelId = Channels
internal void Update()
{
// should we still process a connected event?
if (connectedEventPending)
{
connectedEventPending = false;
NetworkClient.OnConnectedEvent?.Invoke(this);
}
// process internal messages so they are applied at the correct time
while (buffer.HasPackets())
{
@ -112,6 +125,13 @@ internal void Update()
}
buffer.ResetBuffer();
// should we still process a disconnected event?
if (disconnectedEventPending)
{
disconnectedEventPending = false;
NetworkClient.OnDisconnectedEvent?.Invoke(this);
}
}
/// <summary>

View File

@ -30,10 +30,6 @@ public struct NotReadyMessage : NetworkMessage { }
public struct AddPlayerMessage : NetworkMessage { }
public struct DisconnectMessage : NetworkMessage { }
public struct ConnectMessage : NetworkMessage { }
public struct SceneMessage : NetworkMessage
{
public string sceneName;

View File

@ -56,6 +56,13 @@ public static class NetworkClient
/// </summary>
public static bool isLocalClient => connection is ULocalConnectionToServer;
// OnConnected / OnDisconnected used to be NetworkMessages that were
// invoked. this introduced a bug where external clients could send
// Connected/Disconnected messages over the network causing undefined
// behaviour.
internal static Action<NetworkConnection> OnConnectedEvent;
internal static Action<NetworkConnection> OnDisconnectedEvent;
/// <summary>
/// Connect client to a NetworkServer instance.
/// </summary>
@ -124,8 +131,18 @@ public static void ConnectHost()
/// </summary>
public static void ConnectLocalServer()
{
// call server OnConnected with server's connection to client
NetworkServer.OnConnected(NetworkServer.localConnection);
NetworkServer.localConnection.Send(new ConnectMessage());
// call client OnConnected with client's connection to server
// => previously we used to send a ConnectMessage to
// NetworkServer.localConnection. this would queue the message
// until NetworkClient.Update processes it.
// => invoking the client's OnConnected event directly here makes
// tests fail. so let's do it exactly the same order as before by
// queueing the event for next Update!
//OnConnectedEvent?.Invoke(connection);
((ULocalConnectionToServer)connection).QueueConnectedEvent();
}
/// <summary>
@ -164,7 +181,7 @@ static void OnDisconnected()
ClientScene.HandleClientDisconnect(connection);
connection?.InvokeHandler(new DisconnectMessage(), -1);
if (connection != null) OnDisconnectedEvent?.Invoke(connection);
}
internal static void OnDataReceived(ArraySegment<byte> data, int channelId)
@ -187,7 +204,7 @@ static void OnConnected()
// thus we should set the connected state before calling the handler
connectState = ConnectState.Connected;
NetworkTime.UpdateClient();
connection.InvokeHandler(new ConnectMessage(), -1);
OnConnectedEvent?.Invoke(connection);
}
else logger.LogError("Skipped Connect message handling because connection is null.");
}
@ -206,7 +223,15 @@ public static void Disconnect()
{
if (isConnected)
{
NetworkServer.localConnection.Send(new DisconnectMessage());
// call client OnDisconnected with connection to server
// => previously we used to send a DisconnectMessage to
// NetworkServer.localConnection. this would queue the
// message until NetworkClient.Update processes it.
// => invoking the client's OnDisconnected event directly
// here makes tests fail. so let's do it exactly the same
// order as before by queueing the event for next Update!
//OnDisconnectedEvent?.Invoke(connection);
((ULocalConnectionToServer)connection).QueueDisconnectedEvent();
}
NetworkServer.RemoveLocalConnection();
}

View File

@ -776,8 +776,8 @@ bool InitializeSingleton()
void RegisterServerMessages()
{
NetworkServer.RegisterHandler<ConnectMessage>(OnServerConnectInternal, false);
NetworkServer.RegisterHandler<DisconnectMessage>(OnServerDisconnectInternal, false);
NetworkServer.OnConnectedEvent = OnServerConnectInternal;
NetworkServer.OnDisconnectedEvent = OnServerDisconnectInternal;
NetworkServer.RegisterHandler<AddPlayerMessage>(OnServerAddPlayerInternal);
NetworkServer.RegisterHandler<ErrorMessage>(OnServerErrorInternal, false);
@ -787,8 +787,8 @@ void RegisterServerMessages()
void RegisterClientMessages()
{
NetworkClient.RegisterHandler<ConnectMessage>(OnClientConnectInternal, false);
NetworkClient.RegisterHandler<DisconnectMessage>(OnClientDisconnectInternal, false);
NetworkClient.OnConnectedEvent = OnClientConnectInternal;
NetworkClient.OnDisconnectedEvent = OnClientDisconnectInternal;
NetworkClient.RegisterHandler<NotReadyMessage>(OnClientNotReadyMessageInternal);
NetworkClient.RegisterHandler<ErrorMessage>(OnClientErrorInternal, false);
NetworkClient.RegisterHandler<SceneMessage>(OnClientSceneInternal, false);
@ -1168,7 +1168,7 @@ public Transform GetStartPosition()
#region Server Internal Message Handlers
void OnServerConnectInternal(NetworkConnection conn, ConnectMessage connectMsg)
void OnServerConnectInternal(NetworkConnection conn)
{
logger.Log("NetworkManager.OnServerConnectInternal");
@ -1202,7 +1202,7 @@ void OnServerAuthenticated(NetworkConnection conn)
OnServerConnect(conn);
}
void OnServerDisconnectInternal(NetworkConnection conn, DisconnectMessage msg)
void OnServerDisconnectInternal(NetworkConnection conn)
{
logger.Log("NetworkManager.OnServerDisconnectInternal");
OnServerDisconnect(conn);
@ -1249,7 +1249,7 @@ void OnServerErrorInternal(NetworkConnection conn, ErrorMessage msg)
#region Client Internal Message Handlers
void OnClientConnectInternal(NetworkConnection conn, ConnectMessage message)
void OnClientConnectInternal(NetworkConnection conn)
{
logger.Log("NetworkManager.OnClientConnectInternal");
@ -1287,7 +1287,7 @@ void OnClientAuthenticated(NetworkConnection conn)
}
}
void OnClientDisconnectInternal(NetworkConnection conn, DisconnectMessage msg)
void OnClientDisconnectInternal(NetworkConnection conn)
{
logger.Log("NetworkManager.OnClientDisconnectInternal");
OnClientDisconnect(conn);

View File

@ -91,6 +91,13 @@ public static class NetworkServer
/// </summary>
public static float disconnectInactiveTimeout = 60f;
// OnConnected / OnDisconnected used to be NetworkMessages that were
// invoked. this introduced a bug where external clients could send
// Connected/Disconnected messages over the network causing undefined
// behaviour.
internal static Action<NetworkConnection> OnConnectedEvent;
internal static Action<NetworkConnection> OnDisconnectedEvent;
/// <summary>
/// This shuts down the server and disconnects all clients.
/// </summary>
@ -567,7 +574,7 @@ internal static void OnConnected(NetworkConnectionToClient conn)
// add connection and invoke connected event
AddConnection(conn);
conn.InvokeHandler(new ConnectMessage(), -1);
OnConnectedEvent?.Invoke(conn);
}
internal static void OnDisconnected(int connectionId)
@ -586,7 +593,7 @@ internal static void OnDisconnected(int connectionId)
static void OnDisconnected(NetworkConnection conn)
{
conn.InvokeHandler(new DisconnectMessage(), -1);
OnDisconnectedEvent?.Invoke(conn);
if (logger.LogEnabled()) logger.Log("Server lost client:" + conn);
}

View File

@ -30,30 +30,6 @@ public void CommandMessage()
Is.EqualTo(message.payload.Array[message.payload.Offset + i]));
}
[Test]
public void ConnectMessage()
{
// try setting value with constructor
ConnectMessage message = new ConnectMessage();
byte[] arr = MessagePackerTest.PackToByteArray(message);
Assert.DoesNotThrow(() =>
{
MessagePackerTest.UnpackFromByteArray<ConnectMessage>(arr);
});
}
[Test]
public void DisconnectMessage()
{
// try setting value with constructor
DisconnectMessage message = new DisconnectMessage();
byte[] arr = MessagePackerTest.PackToByteArray(message);
Assert.DoesNotThrow(() =>
{
MessagePackerTest.UnpackFromByteArray<DisconnectMessage>(arr);
});
}
[Test]
public void ErrorMessage()
{

View File

@ -54,13 +54,13 @@ public void TestPacking()
[Test]
public void UnpackWrongMessage()
{
ConnectMessage message = new ConnectMessage();
SpawnMessage message = new SpawnMessage();
byte[] data = PackToByteArray(message);
Assert.Throws<FormatException>(() =>
{
DisconnectMessage unpacked = UnpackFromByteArray<DisconnectMessage>(data);
UpdateVarsMessage unpacked = UnpackFromByteArray<UpdateVarsMessage>(data);
});
}

View File

@ -300,8 +300,6 @@ public void SendCommandInternal()
// we need to start a server and connect a client in order to be
// able to send commands
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<SpawnMessage>((conn, msg) => { }, false);
NetworkServer.Listen(1);
@ -413,8 +411,6 @@ public void SendRPCInternal()
// we need to start a server and connect a client in order to be
// able to send commands
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<SpawnMessage>((conn, msg) => { }, false);
NetworkServer.Listen(1);
@ -499,8 +495,6 @@ public void SendTargetRPCInternal()
// we need to start a server and connect a client in order to be
// able to send commands
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<SpawnMessage>((conn, msg) => { }, false);
NetworkServer.Listen(1);

View File

@ -17,13 +17,8 @@ public void SetUp()
Transport.activeTransport = transportGO.AddComponent<MemoryTransport>();
// we need a server to connect to
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
NetworkServer.Listen(10);
// setup client handlers too
NetworkClient.RegisterHandler<ConnectMessage>(msg => { }, false);
}
[TearDown]

View File

@ -844,8 +844,6 @@ public void OnStartServerInHostModeSetsIsClientTrue()
{
// call client connect so that internals are set up
// (it won't actually successfully connect)
// -> also set up connectmessage handler to avoid unhandled msg error
NetworkClient.RegisterHandler<ConnectMessage>(msg => { }, false);
NetworkClient.Connect("localhost");
// manually invoke transport.OnConnected so that NetworkClient.active is set to true

View File

@ -119,8 +119,6 @@ public void IsActiveTest()
public void MaxConnectionsTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen with maxconnections=1
@ -141,8 +139,7 @@ public void ConnectMessageHandlerTest()
{
// message handlers
bool connectCalled = false;
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { connectCalled = true; }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.OnConnectedEvent = conn => connectCalled = true;
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -159,8 +156,7 @@ public void DisconnectMessageHandlerTest()
{
// message handlers
bool disconnectCalled = false;
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { disconnectCalled = true; }, false);
NetworkServer.OnDisconnectedEvent = conn => disconnectCalled = true;
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -180,8 +176,6 @@ public void DisconnectMessageHandlerTest()
public void ConnectionsDictTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -216,8 +210,6 @@ public void OnConnectedOnlyAllowsNonZeroConnectionIdsTest()
// <0 is never used
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -240,8 +232,6 @@ public void OnConnectedOnlyAllowsNonZeroConnectionIdsTest()
public void ConnectDuplicateConnectionIdsTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -315,8 +305,6 @@ public void LocalClientActiveTest()
public void AddConnectionTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -356,8 +344,6 @@ public void AddConnectionTest()
public void RemoveConnectionTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -382,8 +368,6 @@ public void RemoveConnectionTest()
public void DisconnectAllConnectionsTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -404,8 +388,6 @@ public void DisconnectAllConnectionsTest()
public void DisconnectAllTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -432,8 +414,6 @@ public void DisconnectAllTest()
public void OnDataReceivedTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// add one custom message handler
@ -478,8 +458,6 @@ public void OnDataReceivedTest()
public void OnDataReceivedInvalidConnectionIdTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// add one custom message handler
@ -722,8 +700,6 @@ public void ActivateHostSceneCallsOnStartClient()
public void SendToAllTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -758,8 +734,6 @@ public void SendToAllTest()
public void RegisterUnregisterClearHandlerTest()
{
// message handlers that are needed for the test
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
@ -806,7 +780,6 @@ public void RegisterUnregisterClearHandlerTest()
// unregister second handler via ClearHandlers to test that one too. send, should fail
NetworkServer.ClearHandlers();
// (only add this one to avoid disconnect error)
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
writer = new NetworkWriter();
MessagePacker.Pack(new TestMessage1(), writer);
// log error messages are expected
@ -821,8 +794,6 @@ public void RegisterUnregisterClearHandlerTest()
public void SendToClientOfPlayer()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -899,8 +870,6 @@ public void GetNetworkIdentityErrorIfNotFound()
public void ShowForConnection()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -951,8 +920,6 @@ public void ShowForConnection()
public void HideForConnection()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen
@ -1089,8 +1056,6 @@ public void UnSpawn()
public void ShutdownCleanupTest()
{
// message handlers
NetworkServer.RegisterHandler<ConnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<DisconnectMessage>((conn, msg) => { }, false);
NetworkServer.RegisterHandler<ErrorMessage>((conn, msg) => { }, false);
// listen