fix: Allowing overrides for virtual commands to call base method (#1944)

* adding override method that calls base

* trying to debug instruction for call to base

* extra tests

* adding tests for client and target Rpc

* adding fix for calls to base class

Since networkbehaviour parents are processed first we can just fix the
method when we see it in SubstituteMethod
This commit is contained in:
James Frowen 2020-06-10 15:11:29 +01:00 committed by GitHub
parent 2ce5880646
commit b92da91d7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 265 additions and 0 deletions

View File

@ -54,7 +54,58 @@ public static MethodDefinition SubstituteMethod(TypeDefinition td, MethodDefinit
(md.DebugInformation.Scope, cmd.DebugInformation.Scope) = (cmd.DebugInformation.Scope, md.DebugInformation.Scope);
td.Methods.Add(cmd);
FixRemoteCallToBaseMethod(td, cmd);
return cmd;
}
/// <summary>
/// Finds and fixes call to base methods within remote calls
/// <para>For example, changes `base.CmdDoSomething` to `base.CallCmdDoSomething` within `this.CallCmdDoSomething`</para>
/// </summary>
/// <param name="type"></param>
/// <param name="method"></param>
public static void FixRemoteCallToBaseMethod(TypeDefinition type, MethodDefinition method)
{
string callName = method.Name;
// all Commands/Rpc start with "Call"
// eg CallCmdDoSomething
if (!callName.StartsWith("Call"))
return;
// eg CmdDoSomething
string baseRemoteCallName = method.Name.Substring(4);
foreach (Instruction instruction in method.Body.Instructions)
{
// if call to base.CmdDoSomething within this.CallCmdDoSomething
if (IsCallToMethod(instruction, out MethodDefinition calledMethod) &&
calledMethod.Name == baseRemoteCallName)
{
TypeDefinition baseType = type.BaseType.Resolve();
MethodDefinition baseMethod = baseType.GetMethod(callName);
instruction.Operand = baseMethod;
Weaver.DLog(type, "Replacing call to '{0}' with '{1}' inside '{2}'", calledMethod.FullName, baseMethod.FullName, method.FullName);
}
}
}
static bool IsCallToMethod(Instruction instruction, out MethodDefinition calledMethod)
{
if (instruction.OpCode == OpCodes.Call &&
instruction.Operand is MethodDefinition method)
{
calledMethod = method;
return true;
}
else
{
calledMethod = null;
return false;
}
}
}
}

View File

@ -14,6 +14,11 @@ public virtual void RpcSendInt(int someInt)
}
}
class VirtualNoOverrideClientRpc : VirtualClientRpc
{
}
class VirtualOverrideClientRpc : VirtualClientRpc
{
public event Action<int> onOverrideSendInt;
@ -25,6 +30,18 @@ public override void RpcSendInt(int someInt)
}
}
class VirtualOverrideClientRpcWithBase : VirtualClientRpc
{
public event Action<int> onOverrideSendInt;
[ClientRpc]
public override void RpcSendInt(int someInt)
{
base.RpcSendInt(someInt);
onOverrideSendInt?.Invoke(someInt);
}
}
public class ClientRpcOverrideTest : RemoteTestBase
{
[Test]
@ -46,6 +63,24 @@ public void VirtualRpcIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void VirtualCommandWithNoOverrideIsCalled()
{
VirtualNoOverrideClientRpc hostBehaviour = CreateHostObject<VirtualNoOverrideClientRpc>(true);
const int someInt = 20;
int virtualCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.RpcSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualRpcIsCalled()
@ -71,5 +106,31 @@ public void OverrideVirtualRpcIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(0));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualWithBaseCallsBothVirtualAndBase()
{
VirtualOverrideClientRpcWithBase hostBehaviour = CreateHostObject<VirtualOverrideClientRpcWithBase>(true);
const int someInt = 20;
int virtualCallCount = 0;
int overrideCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.onOverrideSendInt += incomingInt =>
{
overrideCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.RpcSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
}
}

View File

@ -14,6 +14,11 @@ public virtual void CmdSendInt(int someInt)
}
}
class VirtualNoOverrideCommand : VirtualCommand
{
}
class VirtualOverrideCommand : VirtualCommand
{
public event Action<int> onOverrideSendInt;
@ -25,6 +30,18 @@ public override void CmdSendInt(int someInt)
}
}
class VirtualOverrideCommandWithBase : VirtualCommand
{
public event Action<int> onOverrideSendInt;
[Command]
public override void CmdSendInt(int someInt)
{
base.CmdSendInt(someInt);
onOverrideSendInt?.Invoke(someInt);
}
}
public class CommandOverrideTest : RemoteTestBase
{
[Test]
@ -46,6 +63,24 @@ public void VirtualCommandIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void VirtualCommandWithNoOverrideIsCalled()
{
VirtualNoOverrideCommand hostBehaviour = CreateHostObject<VirtualNoOverrideCommand>(true);
const int someInt = 20;
int virtualCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.CmdSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualCommandIsCalled()
@ -71,5 +106,31 @@ public void OverrideVirtualCommandIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(0));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualWithBaseCallsBothVirtualAndBase()
{
VirtualOverrideCommandWithBase hostBehaviour = CreateHostObject<VirtualOverrideCommandWithBase>(true);
const int someInt = 20;
int virtualCallCount = 0;
int overrideCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.onOverrideSendInt += incomingInt =>
{
overrideCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.CmdSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
}
}

View File

@ -14,6 +14,11 @@ public virtual void TargetSendInt(int someInt)
}
}
class VirtualNoOverrideTargetRpc : VirtualTargetRpc
{
}
class VirtualOverrideTargetRpc : VirtualTargetRpc
{
public event Action<int> onOverrideSendInt;
@ -25,6 +30,18 @@ public override void TargetSendInt(int someInt)
}
}
class VirtualOverrideTargetRpcWithBase : VirtualTargetRpc
{
public event Action<int> onOverrideSendInt;
[TargetRpc]
public override void TargetSendInt(int someInt)
{
base.TargetSendInt(someInt);
onOverrideSendInt?.Invoke(someInt);
}
}
public class TargetRpcOverrideTest : RemoteTestBase
{
[Test]
@ -46,6 +63,24 @@ public void VirtualRpcIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void VirtualCommandWithNoOverrideIsCalled()
{
VirtualNoOverrideTargetRpc hostBehaviour = CreateHostObject<VirtualNoOverrideTargetRpc>(true);
const int someInt = 20;
int virtualCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.TargetSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualRpcIsCalled()
@ -71,5 +106,31 @@ public void OverrideVirtualRpcIsCalled()
Assert.That(virtualCallCount, Is.EqualTo(0));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
[Test]
public void OverrideVirtualWithBaseCallsBothVirtualAndBase()
{
VirtualOverrideTargetRpcWithBase hostBehaviour = CreateHostObject<VirtualOverrideTargetRpcWithBase>(true);
const int someInt = 20;
int virtualCallCount = 0;
int overrideCallCount = 0;
hostBehaviour.onVirtualSendInt += incomingInt =>
{
virtualCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.onOverrideSendInt += incomingInt =>
{
overrideCallCount++;
Assert.That(incomingInt, Is.EqualTo(someInt));
};
hostBehaviour.TargetSendInt(someInt);
ProcessMessages();
Assert.That(virtualCallCount, Is.EqualTo(1));
Assert.That(overrideCallCount, Is.EqualTo(1));
}
}
}

View File

@ -99,6 +99,7 @@
<Compile Include="WeaverCommandTests~\CommandValid.cs" />
<Compile Include="WeaverCommandTests~\CommandWithArguments.cs" />
<Compile Include="WeaverCommandTests~\OverrideAbstractCommand.cs" />
<Compile Include="WeaverCommandTests~\OverrideVirtualCallBaseCommand.cs" />
<Compile Include="WeaverCommandTests~\OverrideVirtualCommand.cs" />
<Compile Include="WeaverCommandTests~\VirtualCommand.cs" />
<Compile Include="WeaverCommandTests~\CommandWithSenderConnectionAndOtherArgs.cs" />

View File

@ -70,6 +70,12 @@ public void OverrideVirtualCommand()
Assert.That(weaverErrors, Is.Empty);
}
[Test]
public void OverrideVirtualCallBaseCommand()
{
Assert.That(weaverErrors, Is.Empty);
}
[Test]
public void AbstractCommand()
{

View File

@ -0,0 +1,24 @@
using Mirror;
namespace WeaverCommandTests.OverrideVirtualCallBaseCommand
{
class OverrideVirtualCallBaseCommand : baseBehaviour
{
[Command]
protected override void CmdDoSomething()
{
// do somethin
base.CmdDoSomething();
}
}
class baseBehaviour : NetworkBehaviour
{
[Command]
protected virtual void CmdDoSomething()
{
// do more stuff
}
}
}