feat: Websockets now give client address, fix #1121 (#1125)

This commit is contained in:
Paul Pacheco 2019-09-28 18:27:46 -05:00 committed by GitHub
parent 2cd36c8b58
commit c9f317ddee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 7 deletions

View File

@ -1,4 +1,5 @@
using System.IO; using System.IO;
using System.Net.Sockets;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -16,7 +17,7 @@ public interface IWebSocketServerFactory
/// <param name="stream">The network stream</param> /// <param name="stream">The network stream</param>
/// <param name="token">The optional cancellation token</param> /// <param name="token">The optional cancellation token</param>
/// <returns>Http data read from the stream</returns> /// <returns>Http data read from the stream</returns>
Task<WebSocketHttpContext> ReadHttpHeaderFromStreamAsync(Stream stream, CancellationToken token = default(CancellationToken)); Task<WebSocketHttpContext> ReadHttpHeaderFromStreamAsync(TcpClient client, Stream stream, CancellationToken token = default(CancellationToken));
/// <summary> /// <summary>
/// Accept web socket with default options /// Accept web socket with default options

View File

@ -65,6 +65,8 @@ internal class WebSocketImplementation : WebSocket
Queue<ArraySegment<byte>> _messageQueue = new Queue<ArraySegment<byte>>(); Queue<ArraySegment<byte>> _messageQueue = new Queue<ArraySegment<byte>>();
SemaphoreSlim _sendSemaphore = new SemaphoreSlim(1, 1); SemaphoreSlim _sendSemaphore = new SemaphoreSlim(1, 1);
public WebSocketHttpContext Context { get; set; }
internal WebSocketImplementation(Guid guid, Func<MemoryStream> recycledStreamFactory, Stream stream, TimeSpan keepAliveInterval, string secWebSocketExtensions, bool includeExceptionInCloseResponse, bool isClient, string subProtocol) internal WebSocketImplementation(Guid guid, Func<MemoryStream> recycledStreamFactory, Stream stream, TimeSpan keepAliveInterval, string secWebSocketExtensions, bool includeExceptionInCloseResponse, bool isClient, string subProtocol)
{ {

View File

@ -1,5 +1,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net.Sockets;
namespace Ninja.WebSockets namespace Ninja.WebSockets
{ {
@ -30,6 +31,11 @@ public class WebSocketHttpContext
/// </summary> /// </summary>
public Stream Stream { get; private set; } public Stream Stream { get; private set; }
/// <summary>
/// The tcp connection we are using
/// </summary>
public TcpClient Client { get; private set; }
/// <summary> /// <summary>
/// Initialises a new instance of the WebSocketHttpContext class /// Initialises a new instance of the WebSocketHttpContext class
/// </summary> /// </summary>
@ -37,12 +43,13 @@ public class WebSocketHttpContext
/// <param name="httpHeader">The raw http header extracted from the stream</param> /// <param name="httpHeader">The raw http header extracted from the stream</param>
/// <param name="path">The Path extracted from the http header</param> /// <param name="path">The Path extracted from the http header</param>
/// <param name="stream">The stream AFTER the header has already been read</param> /// <param name="stream">The stream AFTER the header has already been read</param>
public WebSocketHttpContext(bool isWebSocketRequest, IList<string> webSocketRequestedProtocols, string httpHeader, string path, Stream stream) public WebSocketHttpContext(bool isWebSocketRequest, IList<string> webSocketRequestedProtocols, string httpHeader, string path, TcpClient client, Stream stream)
{ {
IsWebSocketRequest = isWebSocketRequest; IsWebSocketRequest = isWebSocketRequest;
WebSocketRequestedProtocols = webSocketRequestedProtocols; WebSocketRequestedProtocols = webSocketRequestedProtocols;
HttpHeader = httpHeader; HttpHeader = httpHeader;
Path = path; Path = path;
Client = client;
Stream = stream; Stream = stream;
} }
} }

View File

@ -23,6 +23,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net.Sockets;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using System.Threading; using System.Threading;
@ -64,13 +65,13 @@ public WebSocketServerFactory(Func<MemoryStream> bufferFactory)
/// <param name="stream">The network stream</param> /// <param name="stream">The network stream</param>
/// <param name="token">The optional cancellation token</param> /// <param name="token">The optional cancellation token</param>
/// <returns>Http data read from the stream</returns> /// <returns>Http data read from the stream</returns>
public async Task<WebSocketHttpContext> ReadHttpHeaderFromStreamAsync(Stream stream, CancellationToken token = default(CancellationToken)) public async Task<WebSocketHttpContext> ReadHttpHeaderFromStreamAsync(TcpClient client, Stream stream, CancellationToken token = default(CancellationToken))
{ {
string header = await HttpHelper.ReadHttpHeaderAsync(stream, token); string header = await HttpHelper.ReadHttpHeaderAsync(stream, token);
string path = HttpHelper.GetPathFromHeader(header); string path = HttpHelper.GetPathFromHeader(header);
bool isWebSocketRequest = HttpHelper.IsWebSocketUpgradeRequest(header); bool isWebSocketRequest = HttpHelper.IsWebSocketUpgradeRequest(header);
IList<string> subProtocols = HttpHelper.GetSubProtocols(header); IList<string> subProtocols = HttpHelper.GetSubProtocols(header);
return new WebSocketHttpContext(isWebSocketRequest, subProtocols, header, path, stream); return new WebSocketHttpContext(isWebSocketRequest, subProtocols, header, path, client, stream);
} }
/// <summary> /// <summary>
@ -100,7 +101,10 @@ public WebSocketServerFactory(Func<MemoryStream> bufferFactory)
await PerformHandshakeAsync(guid, context.HttpHeader, options.SubProtocol, context.Stream, token); await PerformHandshakeAsync(guid, context.HttpHeader, options.SubProtocol, context.Stream, token);
Events.Log.ServerHandshakeSuccess(guid); Events.Log.ServerHandshakeSuccess(guid);
string secWebSocketExtensions = null; string secWebSocketExtensions = null;
return new WebSocketImplementation(guid, _bufferFactory, context.Stream, options.KeepAliveInterval, secWebSocketExtensions, options.IncludeExceptionInCloseResponse, false, options.SubProtocol); return new WebSocketImplementation(guid, _bufferFactory, context.Stream, options.KeepAliveInterval, secWebSocketExtensions, options.IncludeExceptionInCloseResponse, false, options.SubProtocol)
{
Context = context
};
} }
static void CheckWebSocketVersion(string httpHeader) static void CheckWebSocketVersion(string httpHeader)

View File

@ -9,6 +9,7 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Ninja.WebSockets; using Ninja.WebSockets;
using Ninja.WebSockets.Internal;
using UnityEngine; using UnityEngine;
namespace Mirror.Websocket namespace Mirror.Websocket
@ -128,7 +129,7 @@ async Task ProcessTcpClient(TcpClient tcpClient, CancellationToken token)
sslStream.AuthenticateAsServer(_sslConfig.Certificate, _sslConfig.ClientCertificateRequired, _sslConfig.EnabledSslProtocols, _sslConfig.CheckCertificateRevocation); sslStream.AuthenticateAsServer(_sslConfig.Certificate, _sslConfig.ClientCertificateRequired, _sslConfig.EnabledSslProtocols, _sslConfig.CheckCertificateRevocation);
stream = sslStream; stream = sslStream;
} }
WebSocketHttpContext context = await webSocketServerFactory.ReadHttpHeaderFromStreamAsync(stream, token); WebSocketHttpContext context = await webSocketServerFactory.ReadHttpHeaderFromStreamAsync(tcpClient, stream, token);
if (context.IsWebSocketRequest) if (context.IsWebSocketRequest)
{ {
WebSocketServerOptions options = new WebSocketServerOptions() { KeepAliveInterval = TimeSpan.FromSeconds(30), SubProtocol = "binary" }; WebSocketServerOptions options = new WebSocketServerOptions() { KeepAliveInterval = TimeSpan.FromSeconds(30), SubProtocol = "binary" };
@ -312,7 +313,9 @@ public string GetClientAddress(int connectionId)
// find the connection // find the connection
if (clients.TryGetValue(connectionId, out WebSocket client)) if (clients.TryGetValue(connectionId, out WebSocket client))
{ {
return ""; WebSocketImplementation wsClient = client as WebSocketImplementation;
return wsClient.Context.Client.Client.RemoteEndPoint.ToString();
} }
return null; return null;
} }