emsApplication/sdk/include/hv/TcpServer.h

323 lines
9.7 KiB
C
Raw Permalink Normal View History

2024-09-10 16:12:38 +08:00
#ifndef HV_TCP_SERVER_HPP_
#define HV_TCP_SERVER_HPP_
#include "hsocket.h"
#include "hssl.h"
#include "hlog.h"
#include "EventLoopThreadPool.h"
#include "Channel.h"
namespace hv {
template<class TSocketChannel = SocketChannel>
class TcpServerEventLoopTmpl {
public:
typedef std::shared_ptr<TSocketChannel> TSocketChannelPtr;
TcpServerEventLoopTmpl(EventLoopPtr loop = NULL) {
acceptor_loop = loop ? loop : std::make_shared<EventLoop>();
port = 0;
listenfd = -1;
tls = false;
tls_setting = NULL;
unpack_setting = NULL;
max_connections = 0xFFFFFFFF;
load_balance = LB_RoundRobin;
}
virtual ~TcpServerEventLoopTmpl() {
HV_FREE(tls_setting);
HV_FREE(unpack_setting);
}
EventLoopPtr loop(int idx = -1) {
EventLoopPtr worker_loop = worker_threads.loop(idx);
if (worker_loop == NULL) {
worker_loop = acceptor_loop;
}
return worker_loop;
}
//@retval >=0 listenfd, <0 error
int createsocket(int port, const char* host = "0.0.0.0") {
listenfd = Listen(port, host);
if (listenfd < 0) return listenfd;
this->host = host;
this->port = port;
return listenfd;
}
// closesocket thread-safe
void closesocket() {
if (listenfd >= 0) {
hloop_t* loop = acceptor_loop->loop();
if (loop) {
hio_t* listenio = hio_get(loop, listenfd);
assert(listenio != NULL);
hio_close_async(listenio);
}
listenfd = -1;
}
}
void setMaxConnectionNum(uint32_t num) {
max_connections = num;
}
void setLoadBalance(load_balance_e lb) {
load_balance = lb;
}
// NOTE: totalThreadNum = 1 acceptor_thread + N worker_threads (N can be 0)
void setThreadNum(int num) {
worker_threads.setThreadNum(num);
}
int startAccept() {
if (listenfd < 0) {
listenfd = createsocket(port, host.c_str());
if (listenfd < 0) {
hloge("createsocket %s:%d return %d!\n", host.c_str(), port, listenfd);
return listenfd;
}
}
hloop_t* loop = acceptor_loop->loop();
if (loop == NULL) return -2;
hio_t* listenio = haccept(loop, listenfd, onAccept);
assert(listenio != NULL);
hevent_set_userdata(listenio, this);
if (tls) {
hio_enable_ssl(listenio);
if (tls_setting) {
int ret = hio_new_ssl_ctx(listenio, tls_setting);
if (ret != 0) {
hloge("new SSL_CTX failed: %d", ret);
closesocket();
return ret;
}
}
}
return 0;
}
int stopAccept() {
if (listenfd < 0) return -1;
hloop_t* loop = acceptor_loop->loop();
if (loop == NULL) return -2;
hio_t* listenio = hio_get(loop, listenfd);
assert(listenio != NULL);
return hio_del(listenio, HV_READ);
}
// start thread-safe
void start(bool wait_threads_started = true) {
if (worker_threads.threadNum() > 0) {
worker_threads.start(wait_threads_started);
}
acceptor_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::startAccept, this));
}
// stop thread-safe
void stop(bool wait_threads_stopped = true) {
closesocket();
if (worker_threads.threadNum() > 0) {
worker_threads.stop(wait_threads_stopped);
}
}
int withTLS(hssl_ctx_opt_t* opt = NULL) {
tls = true;
if (opt) {
if (tls_setting == NULL) {
HV_ALLOC_SIZEOF(tls_setting);
}
opt->endpoint = HSSL_SERVER;
*tls_setting = *opt;
}
return 0;
}
void setUnpack(unpack_setting_t* setting) {
if (setting == NULL) {
HV_FREE(unpack_setting);
return;
}
if (unpack_setting == NULL) {
HV_ALLOC_SIZEOF(unpack_setting);
}
*unpack_setting = *setting;
}
// channel
const TSocketChannelPtr& addChannel(hio_t* io) {
uint32_t id = hio_id(io);
auto channel = std::make_shared<TSocketChannel>(io);
std::lock_guard<std::mutex> locker(mutex_);
channels[id] = channel;
return channels[id];
}
TSocketChannelPtr getChannelById(uint32_t id) {
std::lock_guard<std::mutex> locker(mutex_);
auto iter = channels.find(id);
return iter != channels.end() ? iter->second : NULL;
}
void removeChannel(const TSocketChannelPtr& channel) {
uint32_t id = channel->id();
std::lock_guard<std::mutex> locker(mutex_);
channels.erase(id);
}
size_t connectionNum() {
std::lock_guard<std::mutex> locker(mutex_);
return channels.size();
}
int foreachChannel(std::function<void(const TSocketChannelPtr& channel)> fn) {
std::lock_guard<std::mutex> locker(mutex_);
for (auto& pair : channels) {
fn(pair.second);
}
return channels.size();
}
// broadcast thread-safe
int broadcast(const void* data, int size) {
return foreachChannel([data, size](const TSocketChannelPtr& channel) {
channel->write(data, size);
});
}
int broadcast(const std::string& str) {
return broadcast(str.data(), str.size());
}
private:
static void newConnEvent(hio_t* connio) {
TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
if (server->connectionNum() >= server->max_connections) {
hlogw("over max_connections");
hio_close(connio);
return;
}
// NOTE: attach to worker loop
EventLoop* worker_loop = currentThreadEventLoop;
assert(worker_loop != NULL);
hio_attach(worker_loop->loop(), connio);
const TSocketChannelPtr& channel = server->addChannel(connio);
channel->status = SocketChannel::CONNECTED;
channel->onread = [server, &channel](Buffer* buf) {
if (server->onMessage) {
server->onMessage(channel, buf);
}
};
channel->onwrite = [server, &channel](Buffer* buf) {
if (server->onWriteComplete) {
server->onWriteComplete(channel, buf);
}
};
channel->onclose = [server, &channel]() {
EventLoop* worker_loop = currentThreadEventLoop;
assert(worker_loop != NULL);
--worker_loop->connectionNum;
channel->status = SocketChannel::CLOSED;
if (server->onConnection) {
server->onConnection(channel);
}
server->removeChannel(channel);
// NOTE: After removeChannel, channel may be destroyed,
// so in this lambda function, no code should be added below.
};
if (server->unpack_setting) {
channel->setUnpack(server->unpack_setting);
}
channel->startRead();
if (server->onConnection) {
server->onConnection(channel);
}
}
static void onAccept(hio_t* connio) {
TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
// NOTE: detach from acceptor loop
hio_detach(connio);
EventLoopPtr worker_loop = server->worker_threads.nextLoop(server->load_balance);
if (worker_loop == NULL) {
worker_loop = server->acceptor_loop;
}
++worker_loop->connectionNum;
worker_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::newConnEvent, connio));
}
public:
std::string host;
int port;
int listenfd;
bool tls;
hssl_ctx_opt_t* tls_setting;
unpack_setting_t* unpack_setting;
// Callback
std::function<void(const TSocketChannelPtr&)> onConnection;
std::function<void(const TSocketChannelPtr&, Buffer*)> onMessage;
// NOTE: Use Channel::isWriteComplete in onWriteComplete callback to determine whether all data has been written.
std::function<void(const TSocketChannelPtr&, Buffer*)> onWriteComplete;
uint32_t max_connections;
load_balance_e load_balance;
private:
// id => TSocketChannelPtr
std::map<uint32_t, TSocketChannelPtr> channels; // GUAREDE_BY(mutex_)
std::mutex mutex_;
EventLoopPtr acceptor_loop;
EventLoopThreadPool worker_threads;
};
// TcpServerTmpl: a TcpServerEventLoopTmpl whose acceptor loop is driven by its
// own (privately inherited) EventLoopThread. When no loop is passed in, this
// object is the loop owner and stop() also stops that thread.
template<class TSocketChannel = SocketChannel>
class TcpServerTmpl : private EventLoopThread, public TcpServerEventLoopTmpl<TSocketChannel> {
    typedef TcpServerEventLoopTmpl<TSocketChannel> ServerBase;

public:
    // Base classes are initialized in declaration order, so the
    // EventLoopThread exists before its loop() is handed to ServerBase.
    TcpServerTmpl(EventLoopPtr loop = NULL)
        : EventLoopThread(loop)
        , ServerBase(EventLoopThread::loop())
        , is_loop_owner(loop == NULL)
    {}

    virtual ~TcpServerTmpl() {
        stop(true);
    }

    // Re-exposes the server's loop() (EventLoopThread::loop is private here).
    EventLoopPtr loop(int idx = -1) {
        return ServerBase::loop(idx);
    }

    // start thread-safe: schedule accepting, then run the acceptor thread if
    // it is not already running.
    void start(bool wait_threads_started = true) {
        ServerBase::start(wait_threads_started);
        if (isRunning()) {
            return;
        }
        EventLoopThread::start(wait_threads_started);
    }

    // stop thread-safe: the acceptor thread is stopped only when owned here.
    void stop(bool wait_threads_stopped = true) {
        if (is_loop_owner) EventLoopThread::stop(wait_threads_stopped);
        ServerBase::stop(wait_threads_stopped);
    }

private:
    // true when constructed without an external loop.
    bool is_loop_owner;
};
// Default server type using the plain SocketChannel.
using TcpServer = TcpServerTmpl<SocketChannel>;
}
#endif // HV_TCP_SERVER_HPP_