diff --git a/Assets/NetworkLobbyClient/package.json b/Assets/NetworkLobbyClient/package.json index 3cb11a7..0dfb399 100644 --- a/Assets/NetworkLobbyClient/package.json +++ b/Assets/NetworkLobbyClient/package.json @@ -1,6 +1,6 @@ { "name": "com.incobyte.lobbyclient", - "version": "1.0.6", + "version": "1.0.7", "displayName": "Game Lobby Client", "description": "Provides a client for the game lobby server to list and join lobbies", "unity": "2022.3", diff --git a/LobbyClient/LobbyClient.cs b/LobbyClient/LobbyClient.cs index 0a11fe4..5b6c407 100644 --- a/LobbyClient/LobbyClient.cs +++ b/LobbyClient/LobbyClient.cs @@ -21,7 +21,7 @@ namespace Lobbies UdpEchoServer udpEchoServer = new UdpEchoServer(); - private Dictionary lobbyInformation = new Dictionary(); + private ConcurrentDictionary lobbyInformation = new ConcurrentDictionary(); private string? host; private int port; private int connectionId; @@ -29,6 +29,8 @@ namespace Lobbies public string? externalIp; public int externalPort; + private CancellationTokenRegistration connectionCancellationRegistration; + public void Connect(string host, int port, CancellationToken cancellationToken) { this.host = host; @@ -37,9 +39,10 @@ namespace Lobbies try { cancellationToken.ThrowIfCancellationRequested(); - using var cts = cancellationToken.Register(() => + connectionCancellationRegistration.Dispose(); + connectionCancellationRegistration = cancellationToken.Register(() => { - tcpClient.Stop(); + tcpClient.Stop(); }); tcpClient.DataReceived -= TcpClient_DataReceived; @@ -56,7 +59,7 @@ namespace Lobbies } } - public bool TryReadEvent(out LobbyClientEvent result) + public bool TryReadEvent(out LobbyClientEvent? result) { return events.TryDequeue(out result); } @@ -87,8 +90,16 @@ namespace Lobbies }; byte[] messageData = bufferRental.Rent(); - var len = lobbyCreate.Serialize(messageData); - await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); + try + { + + var len = lobbyCreate.Serialize(messageData); + await tcpClient.Send(messageData, 0, len); + } + finally + { + bufferRental.Return(messageData); + } }); } @@ -96,8 +107,11 @@ namespace Lobbies { _ = Task.Run(async () => { byte[]? passwordHash = null; - if (!string.IsNullOrEmpty(password) && lobbyInformation.ContainsKey(lobbyId) && lobbyInformation[lobbyId].PasswordSalt != null) - passwordHash = PasswordHash.Hash(password, lobbyInformation[lobbyId].PasswordSalt!); + + if (!string.IsNullOrEmpty(password) && + lobbyInformation.TryGetValue(lobbyId, out var lobby) && + lobby.PasswordSalt != null) + passwordHash = PasswordHash.Hash(password, lobby.PasswordSalt!); var lobbyCreate = new LobbyRequestHostInfo() { @@ -105,9 +119,16 @@ namespace Lobbies PasswordHash = passwordHash, }; - byte[] messageData = bufferRental.Rent(); - var len = lobbyCreate.Serialize(messageData); - await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); + var messageData = bufferRental.Rent(); + try + { + var len = lobbyCreate.Serialize(messageData); + await tcpClient.Send(messageData, 0, len); + } + finally + { + bufferRental.Return(messageData); + } }); } @@ -134,21 +155,48 @@ namespace Lobbies Task.Run(() => { byte[]? passwordHash = null; - if (!string.IsNullOrEmpty(password) && lobbyInformation.ContainsKey(lobbyId) && lobbyInformation[lobbyId].PasswordSalt != null) - passwordHash = PasswordHash.Hash(password, lobbyInformation[lobbyId].PasswordSalt!); + if (!string.IsNullOrEmpty(password) && + lobbyInformation.TryGetValue(lobbyId, out var lobby) && + lobby.PasswordSalt != null) + { + passwordHash = PasswordHash.Hash(password, lobby.PasswordSalt); + } + + if (!QueryExternalIpAndPort(sendUdpToGetExternalPortMappingCallback, out var ip, out var mappedPort)) + { + events.Enqueue(new LobbyClientEvent + { + EventType = LobbyClientEventTypes.Failed, + EventData = new LobbyClientDisconnectReason + { + WasError = true, + ErrorMessage = "Could not determine external IP/port." + } + }); + return; + } - QueryExternalIpAndPort(sendUdpToGetExternalPortMappingCallback); var lobbyRequestNatPunch = new LobbyRequestNatPunch() { LobbyId = lobbyId, PasswordHash = passwordHash, - ClientIp = externalIp, - ClientPort = externalPort + ClientIp = ip, + ClientPort = mappedPort }; byte[] messageData = bufferRental.Rent(); var len = lobbyRequestNatPunch.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => + { + try + { + await tcpClient.Send(messageData, 0, len); + } + finally + { + bufferRental.Return(messageData); + } + }); }); } @@ -164,31 +212,65 @@ namespace Lobbies /// after that the udp client will be disposed. public int RequestLobbyNatPunch(Guid lobbyId, string? password, int port = 0) { - if(port < 0 && port > 65535) + if(port < 0 || port > 65535) throw new ArgumentOutOfRangeException(nameof(port)); - var udpEchoServer = new UdpEchoServer(); - udpEchoServer.Start(port); - Task.Run(() => { - byte[]? passwordHash = null; - if (!string.IsNullOrEmpty(password) && lobbyInformation.ContainsKey(lobbyId) && lobbyInformation[lobbyId].PasswordSalt != null) - passwordHash = PasswordHash.Hash(password, lobbyInformation[lobbyId].PasswordSalt!); + var udpEchoServer = new UdpEchoServer(); - QueryExternalIpAndPort(udpEchoServer.Send); - udpEchoServer.Dispose(); - var lobbyRequestNatPunch = new LobbyRequestNatPunch() + try + { + udpEchoServer.Start(port); + + byte[]? passwordHash = null; + if (!string.IsNullOrEmpty(password) && + lobbyInformation.TryGetValue(lobbyId, out var lobby) && + lobby.PasswordSalt != null) + { + passwordHash = PasswordHash.Hash(password, lobby.PasswordSalt); + } + + if (!QueryExternalIpAndPort(udpEchoServer.Send, out var ip, out var mappedPort)) + { + events.Enqueue(new LobbyClientEvent + { + EventType = LobbyClientEventTypes.Failed, + EventData = new LobbyClientDisconnectReason + { + WasError = true, + ErrorMessage = "Could not determine external IP/port." + } + }); + return; + } + + var lobbyRequestNatPunch = new LobbyRequestNatPunch() + { + LobbyId = lobbyId, + PasswordHash = passwordHash, + ClientIp = ip, + ClientPort = mappedPort + }; + + byte[] messageData = bufferRental.Rent(); + var len = lobbyRequestNatPunch.Serialize(messageData); + _ = Task.Run(async () => + { + try + { + await tcpClient.Send(messageData, 0, len); + } + finally + { + bufferRental.Return(messageData); + } + }); + } + finally { - LobbyId = lobbyId, - PasswordHash = passwordHash, - ClientIp = externalIp, - ClientPort = externalPort - }; - - byte[] messageData = bufferRental.Rent(); - var len = lobbyRequestNatPunch.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + udpEchoServer.Dispose(); + } }); return udpEchoServer.Port; @@ -216,9 +298,13 @@ namespace Lobbies HostTryPort = udpEchoServer.Port }; - byte[] messageData = bufferRental.Rent(); + byte[] messageData = bufferRental.Rent(); var len = lobbyUpdate.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => { + try { await tcpClient.Send(messageData, 0, len); } + finally { bufferRental.Return(messageData); } + }); + } public void CloseLobby() @@ -232,7 +318,9 @@ namespace Lobbies byte[] messageData = bufferRental.Rent(); var len = lobbyDelete.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => { + try { await tcpClient.Send(messageData, 0, len); } finally { bufferRental.Return(messageData); } + }); } public void ObserveLobbies(Guid gameId) @@ -241,7 +329,7 @@ namespace Lobbies byte[] messageData = bufferRental.Rent(); var len = lobbiesObserve.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => { try { await tcpClient.Send(messageData, 0, len); } finally { bufferRental.Return(messageData); } }); } public void StopObservingLobbies() @@ -250,7 +338,7 @@ namespace Lobbies byte[] messageData = bufferRental.Rent(); var len = lobbiesStopObserve.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => { try { await tcpClient.Send(messageData, 0, len); } finally { bufferRental.Return(messageData); } }); } public void NotifyLobbyNatPunchDone(int natPunchId, string externalIp, int externalPort) @@ -259,7 +347,7 @@ namespace Lobbies byte[] messageData = bufferRental.Rent(); var len = lobbyNatPunchDone.Serialize(messageData); - _ = Task.Run(async () => { await tcpClient.Send(messageData, 0, len); bufferRental.Return(messageData); }); + _ = Task.Run(async () => { try { await tcpClient.Send(messageData, 0, len); } finally { bufferRental.Return(messageData); } }); } public Task TryDirectConnection(IPAddress[] ipAddressesToTry, int tryPort) @@ -333,28 +421,60 @@ namespace Lobbies return new IPAddress[0]; } - public void QueryExternalIpAndPort(SendUdpMessageCallback sendUdpCallback) + public bool QueryExternalIpAndPort(SendUdpMessageCallback sendUdpCallback, out string? ip, out int port) { + ip = null; + port = 0; + byte[] messageData = bufferRental.Rent(); try { - waitForExternalIp.Reset(); - var queryExternalPortAndIp = new QueryExternalPortAndIp() { LobbyClientId = connectionId }; - var len = queryExternalPortAndIp.Serialize(messageData); - var ip = GetIPsByName(host!, true, false).First(); - var tries = 0; - do - { - sendUdpCallback(new IPEndPoint(ip, port), messageData, len); - } - while (!waitForExternalIp.WaitOne(100) && tries++ < 100); - } - catch - { + if (sendUdpCallback == null) + throw new ArgumentNullException(nameof(sendUdpCallback)); + if (string.IsNullOrWhiteSpace(host)) + return false; + + if (connectionId <= 0) + return false; + + var ips = GetIPsByName(host, true, false); + if (ips.Length == 0) + return false; + + externalIp = null; + externalPort = 0; + waitForExternalIp.Reset(); + + var queryExternalPortAndIp = new QueryExternalPortAndIp + { + LobbyClientId = connectionId + }; + + int len = queryExternalPortAndIp.Serialize(messageData); + var remoteEndpoint = new IPEndPoint(ips[0], this.port); + + for (int tries = 0; tries < 100; tries++) + { + sendUdpCallback(remoteEndpoint, messageData, len); + + if (waitForExternalIp.WaitOne(100)) + { + if (!string.IsNullOrWhiteSpace(externalIp) && externalPort > 0) + { + ip = externalIp; + port = externalPort; + return true; + } + + return false; + } + } + + return false; } - finally + finally { bufferRental.Return(messageData); } @@ -417,7 +537,7 @@ namespace Lobbies var lobbyDelete = LobbyDelete.Deserialize(data.Span); if (lobbyDelete != null) { - lobbyInformation.Remove(lobbyDelete.Id); + lobbyInformation.TryRemove(lobbyDelete.Id, out var _); events.Enqueue(new LobbyClientEvent { EventType = LobbyClientEventTypes.LobbyDelete, EventData = lobbyDelete }); } } @@ -471,19 +591,47 @@ namespace Lobbies public IEnumerable GatherLocalIpAddresses() { - foreach (NetworkInterface netInterface in NetworkInterface.GetAllNetworkInterfaces()) + foreach (var netInterface in NetworkInterface.GetAllNetworkInterfaces()) { - IPInterfaceProperties ipProps = netInterface.GetIPProperties(); + if (netInterface.OperationalStatus != OperationalStatus.Up) + continue; - foreach (UnicastIPAddressInformation addr in ipProps.UnicastAddresses) + if (netInterface.NetworkInterfaceType == NetworkInterfaceType.Loopback) + continue; + + // Optional: virtuelle Interfaces rausfiltern (je nach Bedarf) + if (netInterface.Description.Contains("Virtual", StringComparison.OrdinalIgnoreCase) || + netInterface.Description.Contains("VMware", StringComparison.OrdinalIgnoreCase) || + netInterface.Description.Contains("Docker", StringComparison.OrdinalIgnoreCase)) + continue; + + var ipProps = netInterface.GetIPProperties(); + + foreach (var addr in ipProps.UnicastAddresses) { - yield return addr.Address; - } + var ip = addr.Address; + + // Nur IPv4 + if (ip.AddressFamily != AddressFamily.InterNetwork) + continue; + + // Loopback nochmal sicherheitshalber + if (IPAddress.IsLoopback(ip)) + continue; + + // Link-local (APIPA) raus (169.254.x.x) + if (ip.GetAddressBytes()[0] == 169 && ip.GetAddressBytes()[1] == 254) + continue; + + yield return ip; + } } } public void Dispose() { + Stop(); + connectionCancellationRegistration.Dispose(); waitForExternalIp.Dispose(); tcpClient.Dispose(); udpEchoServer.Dispose(); diff --git a/LobbyClient/TcpClient.cs b/LobbyClient/TcpClient.cs index 5968dff..edad351 100644 --- a/LobbyClient/TcpClient.cs +++ b/LobbyClient/TcpClient.cs @@ -1,4 +1,5 @@ using System; +using System.IO; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; @@ -13,127 +14,185 @@ namespace Lobbies internal delegate void DisconnectedEventArgs(bool clean, string error); internal event DisconnectedEventArgs? Disconnected; - internal delegate void ConnectedEventArgs(); internal event ConnectedEventArgs? Connected; - TcpClient? tcpClient; - NetworkStream? networkStream; - CancellationTokenSource? cancellationTokenSource = new CancellationTokenSource(); - bool running = false; + private const int HeaderSize = 4; + private const int MaxMessageSize = 4096; + private static readonly TimeSpan PingInterval = TimeSpan.FromSeconds(10); + + private TcpClient? tcpClient; + private NetworkStream? networkStream; + private CancellationTokenSource? cancellationTokenSource = new CancellationTokenSource(); + private bool running = false; + private readonly SemaphoreSlim sendLock = new SemaphoreSlim(1, 1); internal async Task Connect(string host, int port) { - bool wasError = false; + bool cleanDisconnect = true; string error = string.Empty; + Task? pingTask = null; try { - cancellationTokenSource!.Token.ThrowIfCancellationRequested(); + if (cancellationTokenSource == null || cancellationTokenSource.IsCancellationRequested) + cancellationTokenSource = new CancellationTokenSource(); - running = true; + var token = cancellationTokenSource.Token; + + token.ThrowIfCancellationRequested(); + + running = true; tcpClient = new TcpClient(); - using (cancellationTokenSource!.Token.Register(() => { tcpClient.Close(); })) + + using (token.Register(() => + { + try { tcpClient.Close(); } catch { } + })) { await tcpClient.ConnectAsync(host, port); } + networkStream = tcpClient.GetStream(); - Memory buffer = new byte[4096]; - Memory target = new byte[4096]; - - int currentOffset = 0; - int currentMessageRemainingLength = 0; - int currentMessageLength = 0; - bool validMessage = true; - int currentReadOffset = 0; - bool offsetSizeInt = false; Connected?.Invoke(); - while (running) + pingTask = Task.Run(() => PingLoop(token), token); + + Memory buffer = new byte[MaxMessageSize]; + Memory target = new byte[MaxMessageSize]; + + int bufferedBytes = 0; + int currentMessageLength = -1; + int currentMessageOffset = 0; + + while (running && !token.IsCancellationRequested) { - int copyOffset = 0; - int receivedBytes = currentReadOffset; + if (bufferedBytes == buffer.Length) + throw new InvalidDataException("Receive buffer overflow."); - if (currentReadOffset < 4) + int bytesRead = await networkStream.ReadAsync(buffer.Slice(bufferedBytes), token); + if (bytesRead == 0) + break; + + bufferedBytes += bytesRead; + + while (true) { - receivedBytes += await networkStream.ReadAsync(buffer.Slice(currentReadOffset), cancellationTokenSource.Token) + currentReadOffset; - } - - if (receivedBytes == 0 && running && !cancellationTokenSource.Token.IsCancellationRequested) - { - throw new Exception("Connection lost!"); - } - - if (receivedBytes > 3 || (currentMessageRemainingLength > 0 && receivedBytes > currentMessageRemainingLength)) - { - currentReadOffset = 0; - if (currentMessageLength == 0) + if (currentMessageLength < 0) { - currentMessageRemainingLength = BitConverter.ToInt32(buffer.Span); - currentMessageLength = currentMessageRemainingLength; - receivedBytes -= 4; - copyOffset += 4; - offsetSizeInt = true; - } - else - offsetSizeInt = false; + if (bufferedBytes < HeaderSize) + break; - var receivedCount = Math.Min(receivedBytes, currentMessageRemainingLength); - receivedBytes -= receivedCount; - copyOffset += receivedCount; + currentMessageLength = BitConverter.ToInt32(buffer.Span.Slice(0, HeaderSize)); - if (validMessage && currentOffset + receivedCount > 0) - { - if (currentOffset + receivedCount < target.Length) - buffer.Slice(offsetSizeInt ? 4 : 0, receivedCount).CopyTo(target.Slice(currentOffset)); - else - validMessage = false; - } + if (currentMessageLength < 0) + throw new InvalidDataException("Negative message length received."); - currentOffset += receivedCount; - currentMessageRemainingLength -= receivedCount; + if (currentMessageLength > MaxMessageSize) + throw new InvalidDataException($"Message too large: {currentMessageLength} > {MaxMessageSize}."); - if (currentMessageRemainingLength <= 0) - { - if (validMessage) - DataReceived?.Invoke(currentMessageLength, target); + if (bufferedBytes > HeaderSize) + buffer.Slice(HeaderSize, bufferedBytes - HeaderSize).CopyTo(buffer); - if (receivedBytes > 0) + bufferedBytes -= HeaderSize; + currentMessageOffset = 0; + + if (currentMessageLength == 0) { - buffer.Slice(copyOffset, receivedBytes).CopyTo(buffer); - currentReadOffset += receivedBytes; + currentMessageLength = -1; + continue; } - - currentOffset = 0; - currentMessageLength = 0; - currentMessageRemainingLength = 0; - validMessage = true; } - } - else if (receivedBytes > 0) - { - currentReadOffset += receivedBytes; + + int remainingMessageBytes = currentMessageLength - currentMessageOffset; + if (remainingMessageBytes <= 0) + { + DataReceived?.Invoke(currentMessageLength, target.Slice(0, currentMessageLength)); + currentMessageLength = -1; + currentMessageOffset = 0; + continue; + } + + if (bufferedBytes == 0) + break; + + int chunkSize = Math.Min(bufferedBytes, remainingMessageBytes); + + buffer.Slice(0, chunkSize).CopyTo(target.Slice(currentMessageOffset)); + currentMessageOffset += chunkSize; + + if (bufferedBytes > chunkSize) + buffer.Slice(chunkSize, bufferedBytes - chunkSize).CopyTo(buffer); + + bufferedBytes -= chunkSize; + + if (currentMessageOffset == currentMessageLength) + { + DataReceived?.Invoke(currentMessageLength, target.Slice(0, currentMessageLength)); + currentMessageLength = -1; + currentMessageOffset = 0; + } } } } - catch(Exception e) + catch (OperationCanceledException) { + cleanDisconnect = true; + } + catch (Exception e) + { + cleanDisconnect = false; error = e.Message; } - finally + finally { - wasError = running; - running = false; - - networkStream?.Dispose(); - tcpClient?.Dispose(); - tcpClient = null; + try { cancellationTokenSource?.Cancel(); } catch { } + + if (pingTask != null) + { + try { await pingTask; } catch { } + } + + try { networkStream?.Dispose(); } catch { } + try { tcpClient?.Close(); } catch { } + try { tcpClient?.Dispose(); } catch { } + networkStream = null; + tcpClient = null; - Disconnected?.Invoke(!wasError, error); + Disconnected?.Invoke(cleanDisconnect, error); + } + } + + private async Task PingLoop(CancellationToken token) + { + while (running && !token.IsCancellationRequested) + { + await Task.Delay(PingInterval, token); + + if (!running || token.IsCancellationRequested) + break; + + await SendPing(token); + } + } + + private async Task SendPing(CancellationToken token) + { + if (!running || networkStream == null) + return; + + await sendLock.WaitAsync(token); + try + { + await networkStream.WriteAsync(BitConverter.GetBytes(0), 0, 4, token); + } + finally + { + sendLock.Release(); } } @@ -141,10 +200,18 @@ namespace Lobbies { try { - if (running && networkStream != null) + if (!running || networkStream == null || cancellationTokenSource == null) + return; + + await sendLock.WaitAsync(cancellationTokenSource.Token); + try { - await networkStream.WriteAsync(BitConverter.GetBytes(count - offset), 0, 4, cancellationTokenSource!.Token); - await networkStream.WriteAsync(buffer, offset, count, cancellationTokenSource!.Token); + await networkStream.WriteAsync(BitConverter.GetBytes(count), 0, 4, cancellationTokenSource.Token); + await networkStream.WriteAsync(buffer, offset, count, cancellationTokenSource.Token); + } + finally + { + sendLock.Release(); } } catch { } @@ -158,9 +225,17 @@ namespace Lobbies public void Dispose() { - cancellationTokenSource?.Cancel(); - tcpClient?.Dispose(); - cancellationTokenSource?.Dispose(); + running = false; + + try { cancellationTokenSource?.Cancel(); } catch { } + try { networkStream?.Dispose(); } catch { } + try { tcpClient?.Close(); } catch { } + try { tcpClient?.Dispose(); } catch { } + try { cancellationTokenSource?.Dispose(); } catch { } + + networkStream = null; + tcpClient = null; + cancellationTokenSource = null; } } -} +} \ No newline at end of file diff --git a/LobbyClientTest/Program.cs b/LobbyClientTest/Program.cs index 35ac64d..83c51d5 100644 --- a/LobbyClientTest/Program.cs +++ b/LobbyClientTest/Program.cs @@ -25,8 +25,10 @@ _ = Task.Run(() => { while (running) { - foreach (var lobbyEvent in lobbyClient.ReadEvents(20)) + while (lobbyClient.TryReadEvent(out var lobbyEvent)) { + if (lobbyEvent == null) continue; + switch (lobbyEvent.EventType) { case LobbyClientEventTypes.LobbyJoinFailed: @@ -169,7 +171,7 @@ _ = Task.Run(() => lobbyClient.QueryExternalIpAndPort((remoteEndpoint, messageData, messageLength) => { fakeGameHost.Send(remoteEndpoint, messageData, messageLength); - }); + }, out var ip, out var port); var ep = new IPEndPoint(IPAddress.Parse(lobbyRequestNatPunch.ClientIp!), lobbyRequestNatPunch.ClientPort); for (int z = 0; z < 32; z++) diff --git a/LobbyServer/TcpLobbyServer.cs b/LobbyServer/TcpLobbyServer.cs index 55e4977..4a407af 100644 --- a/LobbyServer/TcpLobbyServer.cs +++ b/LobbyServer/TcpLobbyServer.cs @@ -18,11 +18,18 @@ namespace LobbyServer internal CancellationTokenSource? cancellationToken = null; internal NetworkStream? stream; internal TcpClient? client; - + internal DateTime lastSeenUtc = DateTime.UtcNow; + public void Dispose() { - cancellationToken?.Dispose(); + try { stream?.Dispose(); } catch { } + try { client?.Close(); } catch { } + try { client?.Dispose(); } catch { } + try { cancellationToken?.Dispose(); } catch { } + cancellationToken = null; + stream = null; + client = null; } } @@ -38,7 +45,8 @@ namespace LobbyServer cancellationToken.TryReset(); running = true; clientIdCounter = 0; - _ = Task.Run(() => Listener(port)); + _ = Task.Run(() => MonitorClients()); + _ = Task.Run(() => Listener(port)); } public void Stop() @@ -55,7 +63,6 @@ namespace LobbyServer catch { } } - activeClients.Clear (); serverClosed.WaitOne(); } @@ -97,7 +104,7 @@ namespace LobbyServer { if (activeClients.TryGetValue(clientId, out var lobbyClient) && lobbyClient.stream != null && lobbyClient.cancellationToken != null) { - await lobbyClient.stream.WriteAsync(BitConverter.GetBytes(count - offset), 0, 4, lobbyClient.cancellationToken.Token); + await lobbyClient.stream.WriteAsync(BitConverter.GetBytes(count), 0, 4, lobbyClient.cancellationToken.Token); await lobbyClient.stream.WriteAsync(buffer, offset, count, lobbyClient.cancellationToken.Token); } } @@ -107,113 +114,148 @@ namespace LobbyServer private async Task ClientThread(TcpClient client) { int myId = Interlocked.Increment(ref clientIdCounter); + Client? lobbyClient = null; try - { - await using NetworkStream stream = client.GetStream(); + { + var stream = client.GetStream(); Memory buffer = new byte[4096]; Memory target = new byte[4096]; - using var lobbyClient = new Client + lobbyClient = new Client { cancellationToken = new CancellationTokenSource(), stream = stream, - client = client + client = client, + lastSeenUtc = DateTime.UtcNow }; activeClients.TryAdd(myId, lobbyClient); - - int currentOffset = 0; - int currentMessageRemainingLength = 0; - int currentMessageLength = 0; + + int bufferedBytes = 0; + int currentMessageLength = -1; + int currentMessageOffset = 0; bool validMessage = true; - bool offsetSizeInt = false; - int currentReadOffset = 0; var lobbyClientConnectionInfo = new LobbyClientConnectionInfo { Id = myId }; byte[] sendBuffer = new byte[128]; int sendLen = lobbyClientConnectionInfo.Serialize(sendBuffer); await Send(myId, sendBuffer, 0, sendLen); - while (running) + while (running && !lobbyClient.cancellationToken.Token.IsCancellationRequested) { - int copyOffset = 0; - int receivedBytes = currentReadOffset; + int bytesRead = await stream.ReadAsync(buffer.Slice(bufferedBytes), lobbyClient.cancellationToken.Token); + if (bytesRead == 0) + break; - if (currentReadOffset < 4) + bufferedBytes += bytesRead; + lobbyClient.lastSeenUtc = DateTime.UtcNow; + + while (true) { - receivedBytes += await stream.ReadAsync(buffer.Slice(currentReadOffset), lobbyClient.cancellationToken.Token) + currentReadOffset; - if (receivedBytes == 0) - throw new Exception("Connection lost!"); - } - - if (receivedBytes > 3 || (currentMessageRemainingLength > 0 && receivedBytes >= currentMessageRemainingLength)) - { - currentReadOffset = 0; - if (currentMessageLength == 0) + if (currentMessageLength < 0) { - currentMessageRemainingLength = BitConverter.ToInt32(buffer.Span); - currentMessageLength = currentMessageRemainingLength; - receivedBytes -= 4; - copyOffset += 4; - offsetSizeInt = true; - } - else - offsetSizeInt = false; + if (bufferedBytes < 4) + break; - var receivedCount = Math.Min(receivedBytes, currentMessageRemainingLength); + currentMessageLength = BitConverter.ToInt32(buffer.Span.Slice(0, 4)); - receivedBytes -= receivedCount; - copyOffset += receivedCount; + if (currentMessageLength < 0) + throw new InvalidDataException("Negative message length received."); - if (validMessage && currentOffset + receivedCount > 0) - { - if (currentOffset + receivedCount < target.Length) - buffer.Slice(offsetSizeInt ? 4 : 0, receivedCount).CopyTo(target.Slice(currentOffset)); - else - validMessage = false; - } + buffer.Slice(4, bufferedBytes - 4).CopyTo(buffer); + bufferedBytes -= 4; + currentMessageOffset = 0; + validMessage = currentMessageLength <= target.Length; - currentOffset += receivedCount; - currentMessageRemainingLength -= receivedCount; - - if (currentMessageRemainingLength <= 0) - { - if(validMessage) - DataReceived?.Invoke(myId, currentMessageLength, target); - - if (receivedBytes > 0) + if (currentMessageLength == 0) { - buffer.Slice(copyOffset, receivedBytes).CopyTo(buffer); - currentReadOffset += receivedBytes; + currentMessageLength = -1; + continue; + } + } + + if (bufferedBytes == 0) + break; + + int chunkSize = Math.Min(bufferedBytes, currentMessageLength - currentMessageOffset); + + if (chunkSize > 0) + { + if (validMessage) + { + buffer.Slice(0, chunkSize).CopyTo(target.Slice(currentMessageOffset)); } - currentOffset = 0; - currentMessageLength = 0; - currentMessageRemainingLength = 0; - validMessage = true; + currentMessageOffset += chunkSize; + + buffer.Slice(chunkSize, bufferedBytes - chunkSize).CopyTo(buffer); + bufferedBytes -= chunkSize; } + + if (currentMessageOffset < currentMessageLength) + break; + + if (validMessage) + { + DataReceived?.Invoke(myId, currentMessageLength, target.Slice(0, currentMessageLength)); + } + + currentMessageLength = -1; + currentMessageOffset = 0; + validMessage = true; } - else if(receivedBytes > 0) - { - currentReadOffset += receivedBytes; - } + + if (bufferedBytes == buffer.Length && currentMessageLength < 0) + throw new InvalidDataException("Receive buffer overflow while waiting for message header."); } } finally { + activeClients.TryRemove(myId, out _); + lobbyClient?.Dispose(); ClientDisconnected?.Invoke(myId); - activeClients.TryRemove(myId, out var _); - client?.Dispose(); } } - internal string? GetClientIp(int cliendId) + private async Task MonitorClients() { try { - if (activeClients.TryGetValue(cliendId, out var client)) + while (running) + { + var now = DateTime.UtcNow; + + foreach (var kv in activeClients) + { + var id = kv.Key; + var client = kv.Value; + + if ((now - client.lastSeenUtc) > TimeSpan.FromSeconds(30)) + { + try + { + client.cancellationToken?.Cancel(); + client.client?.Close(); + } + catch { } + } + } + + await Task.Delay(1000, cancellationToken.Token); + } + } + catch (OperationCanceledException) + { + } + } + + internal string? GetClientIp(int cliendIp) + { + try + { + if (activeClients.TryGetValue(cliendIp, out var client)) { return (client.client!.Client.RemoteEndPoint as IPEndPoint)!.Address.ToString(); }