添加连接管理的互斥锁,增强多线程环境下的连接安全性;更新TCPClient以支持断开回调

This commit is contained in:
Ayndpa
2025-11-18 19:20:36 +08:00
parent b308a644ff
commit 676d39d6a2
5 changed files with 44 additions and 13 deletions

View File

@@ -48,6 +48,7 @@ int g_currentVirtualPort = 0;
std::vector<HSteamNetConnection> connections; std::vector<HSteamNetConnection> connections;
std::map<HSteamNetConnection, TCPClient*> clientMap; std::map<HSteamNetConnection, TCPClient*> clientMap;
std::mutex clientMutex; std::mutex clientMutex;
std::mutex connectionsMutex; // Add mutex for connections
int localPort = 0; int localPort = 0;
bool g_isHost = false; bool g_isHost = false;
bool g_isClient = false; bool g_isClient = false;
@@ -73,7 +74,10 @@ void OnSteamNetConnectionStatusChanged(SteamNetConnectionStatusChangedCallback_t
{ {
// Incoming connection, accept it // Incoming connection, accept it
SteamNetworkingSockets()->AcceptConnection(pInfo->m_hConn); SteamNetworkingSockets()->AcceptConnection(pInfo->m_hConn);
connections.push_back(pInfo->m_hConn); {
std::lock_guard<std::mutex> lockConn(connectionsMutex);
connections.push_back(pInfo->m_hConn);
}
g_hConnection = pInfo->m_hConn; // Keep for backward compatibility if needed g_hConnection = pInfo->m_hConn; // Keep for backward compatibility if needed
g_isConnected = true; g_isConnected = true;
std::cout << "Accepted incoming connection from " << pInfo->m_info.m_identityRemote.GetSteamID().ConvertToUint64() << std::endl; std::cout << "Accepted incoming connection from " << pInfo->m_info.m_identityRemote.GetSteamID().ConvertToUint64() << std::endl;
@@ -115,12 +119,15 @@ void OnSteamNetConnectionStatusChanged(SteamNetConnectionStatusChangedCallback_t
g_isConnected = false; g_isConnected = false;
g_hConnection = k_HSteamNetConnection_Invalid; g_hConnection = k_HSteamNetConnection_Invalid;
// Remove from connections // Remove from connections
auto it = connections.begin(); {
while (it != connections.end()) { std::lock_guard<std::mutex> lockConn(connectionsMutex);
if (*it == pInfo->m_hConn) { auto it = connections.begin();
it = connections.erase(it); while (it != connections.end()) {
} else { if (*it == pInfo->m_hConn) {
++it; it = connections.erase(it);
} else {
++it;
}
} }
} }
// Remove from userMap // Remove from userMap
@@ -198,7 +205,7 @@ int main() {
boost::asio::io_context io_context; boost::asio::io_context io_context;
// Create Steam Message Handler // Create Steam Message Handler
SteamMessageHandler messageHandler(io_context, m_pInterface, connections, clientMap, clientMutex, server, g_isHost, localPort); SteamMessageHandler messageHandler(io_context, m_pInterface, connections, clientMap, clientMutex, connectionsMutex, server, g_isHost, localPort);
// Initialize GLFW // Initialize GLFW
if (!glfwInit()) { if (!glfwInit()) {
@@ -344,6 +351,7 @@ int main() {
} }
{ {
std::lock_guard<std::mutex> lock(clientMutex); std::lock_guard<std::mutex> lock(clientMutex);
std::lock_guard<std::mutex> lockConn(connectionsMutex);
ImGui::Text("连接的好友: %d", (int)connections.size()); ImGui::Text("连接的好友: %d", (int)connections.size());
ImGui::Text("活跃的TCP客户端: %d", (int)clientMap.size()); ImGui::Text("活跃的TCP客户端: %d", (int)clientMap.size());
} }

View File

@@ -10,8 +10,8 @@
const char* CONTROL_PREFIX = "CONTROL:"; const char* CONTROL_PREFIX = "CONTROL:";
const size_t CONTROL_PREFIX_LEN = 8; const size_t CONTROL_PREFIX_LEN = 8;
SteamMessageHandler::SteamMessageHandler(boost::asio::io_context& io_context, ISteamNetworkingSockets* interface, std::vector<HSteamNetConnection>& connections, std::map<HSteamNetConnection, TCPClient*>& clientMap, std::mutex& clientMutex, std::unique_ptr<TCPServer>& server, bool& g_isHost, int& localPort) SteamMessageHandler::SteamMessageHandler(boost::asio::io_context& io_context, ISteamNetworkingSockets* interface, std::vector<HSteamNetConnection>& connections, std::map<HSteamNetConnection, TCPClient*>& clientMap, std::mutex& clientMutex, std::mutex& connectionsMutex, std::unique_ptr<TCPServer>& server, bool& g_isHost, int& localPort)
: io_context_(io_context), m_pInterface_(interface), connections_(connections), clientMap_(clientMap), clientMutex_(clientMutex), server_(server), g_isHost_(g_isHost), localPort_(localPort), running_(false) {} : io_context_(io_context), m_pInterface_(interface), connections_(connections), clientMap_(clientMap), clientMutex_(clientMutex), connectionsMutex_(connectionsMutex), server_(server), g_isHost_(g_isHost), localPort_(localPort), running_(false) {}
SteamMessageHandler::~SteamMessageHandler() { SteamMessageHandler::~SteamMessageHandler() {
stop(); stop();
@@ -48,8 +48,13 @@ void SteamMessageHandler::run() {
} }
void SteamMessageHandler::pollMessages() { void SteamMessageHandler::pollMessages() {
std::vector<HSteamNetConnection> currentConnections;
{
std::lock_guard<std::mutex> lockConn(connectionsMutex_);
currentConnections = connections_;
}
std::lock_guard<std::mutex> lock(clientMutex_); std::lock_guard<std::mutex> lock(clientMutex_);
for (auto conn : connections_) { for (auto conn : currentConnections) {
ISteamNetworkingMessage* pIncomingMsgs[10]; ISteamNetworkingMessage* pIncomingMsgs[10];
int numMsgs = m_pInterface_->ReceiveMessagesOnConnection(conn, pIncomingMsgs, 10); int numMsgs = m_pInterface_->ReceiveMessagesOnConnection(conn, pIncomingMsgs, 10);
for (int i = 0; i < numMsgs; ++i) { for (int i = 0; i < numMsgs; ++i) {
@@ -72,6 +77,11 @@ void SteamMessageHandler::pollMessages() {
std::lock_guard<std::mutex> lock(clientMutex_); std::lock_guard<std::mutex> lock(clientMutex_);
m_pInterface_->SendMessageToConnection(conn, data, size, k_nSteamNetworkingSend_Reliable, nullptr); m_pInterface_->SendMessageToConnection(conn, data, size, k_nSteamNetworkingSend_Reliable, nullptr);
}); });
client->setDisconnectCallback([conn, this]() {
std::lock_guard<std::mutex> lock(clientMutex_);
m_pInterface_->CloseConnection(conn, 0, nullptr, false);
std::cout << "Closed Steam connection due to TCP client disconnect" << std::endl;
});
clientMap_[conn] = client; clientMap_[conn] = client;
std::cout << "Created TCP Client for connection on first message" << std::endl; std::cout << "Created TCP Client for connection on first message" << std::endl;
} else { } else {

View File

@@ -13,7 +13,7 @@
class SteamMessageHandler { class SteamMessageHandler {
public: public:
SteamMessageHandler(boost::asio::io_context& io_context, ISteamNetworkingSockets* interface, std::vector<HSteamNetConnection>& connections, std::map<HSteamNetConnection, TCPClient*>& clientMap, std::mutex& clientMutex, std::unique_ptr<TCPServer>& server, bool& g_isHost, int& localPort); SteamMessageHandler(boost::asio::io_context& io_context, ISteamNetworkingSockets* interface, std::vector<HSteamNetConnection>& connections, std::map<HSteamNetConnection, TCPClient*>& clientMap, std::mutex& clientMutex, std::mutex& connectionsMutex, std::unique_ptr<TCPServer>& server, bool& g_isHost, int& localPort);
~SteamMessageHandler(); ~SteamMessageHandler();
void start(); void start();
@@ -28,6 +28,7 @@ private:
std::vector<HSteamNetConnection>& connections_; std::vector<HSteamNetConnection>& connections_;
std::map<HSteamNetConnection, TCPClient*>& clientMap_; std::map<HSteamNetConnection, TCPClient*>& clientMap_;
std::mutex& clientMutex_; std::mutex& clientMutex_;
std::mutex& connectionsMutex_;
std::unique_ptr<TCPServer>& server_; std::unique_ptr<TCPServer>& server_;
bool& g_isHost_; bool& g_isHost_;
int& localPort_; int& localPort_;

View File

@@ -1,7 +1,7 @@
#include "tcp_client.h" #include "tcp_client.h"
#include <iostream> #include <iostream>
TCPClient::TCPClient(const std::string& host, int port) : host_(host), port_(port), connected_(false), socket_(std::make_shared<tcp::socket>(io_context_)), work_(boost::asio::make_work_guard(io_context_)), buffer_(1024) {} TCPClient::TCPClient(const std::string& host, int port) : host_(host), port_(port), connected_(false), disconnected_(false), socket_(std::make_shared<tcp::socket>(io_context_)), work_(boost::asio::make_work_guard(io_context_)), buffer_(1024) {}
TCPClient::~TCPClient() { disconnect(); } TCPClient::~TCPClient() { disconnect(); }
@@ -26,6 +26,11 @@ bool TCPClient::connect() {
} }
void TCPClient::disconnect() { void TCPClient::disconnect() {
if (disconnected_) return;
disconnected_ = true;
if (disconnectCallback_) {
disconnectCallback_();
}
connected_ = false; connected_ = false;
io_context_.stop(); io_context_.stop();
if (clientThread_.joinable()) { if (clientThread_.joinable()) {
@@ -59,6 +64,10 @@ void TCPClient::setReceiveCallback(std::function<void(const char*, size_t)> call
receiveCallbackBytes_ = callback; receiveCallbackBytes_ = callback;
} }
void TCPClient::setDisconnectCallback(std::function<void()> callback) {
disconnectCallback_ = callback;
}
void TCPClient::start_read() { void TCPClient::start_read() {
socket_->async_read_some(boost::asio::buffer(buffer_), [this](const boost::system::error_code& error, std::size_t bytes_transferred) { socket_->async_read_some(boost::asio::buffer(buffer_), [this](const boost::system::error_code& error, std::size_t bytes_transferred) {
handle_read(error, bytes_transferred); handle_read(error, bytes_transferred);

View File

@@ -21,6 +21,7 @@ public:
void send(const char* data, size_t size); void send(const char* data, size_t size);
void setReceiveCallback(std::function<void(const std::string&)> callback); void setReceiveCallback(std::function<void(const std::string&)> callback);
void setReceiveCallback(std::function<void(const char*, size_t)> callback); void setReceiveCallback(std::function<void(const char*, size_t)> callback);
void setDisconnectCallback(std::function<void()> callback);
private: private:
void start_read(); void start_read();
@@ -29,6 +30,7 @@ private:
std::string host_; std::string host_;
int port_; int port_;
bool connected_; bool connected_;
bool disconnected_;
boost::asio::io_context io_context_; boost::asio::io_context io_context_;
boost::asio::executor_work_guard<boost::asio::io_context::executor_type> work_; boost::asio::executor_work_guard<boost::asio::io_context::executor_type> work_;
std::shared_ptr<tcp::socket> socket_; std::shared_ptr<tcp::socket> socket_;
@@ -36,5 +38,6 @@ private:
std::mutex socketMutex_; std::mutex socketMutex_;
std::function<void(const std::string&)> receiveCallback_; std::function<void(const std::string&)> receiveCallback_;
std::function<void(const char*, size_t)> receiveCallbackBytes_; std::function<void(const char*, size_t)> receiveCallbackBytes_;
std::function<void()> disconnectCallback_;
std::vector<char> buffer_; std::vector<char> buffer_;
}; };