#ifndef HV_TCP_SERVER_HPP_
#define HV_TCP_SERVER_HPP_

#include "hsocket.h"
#include "hssl.h"
#include "hlog.h"

#include "EventLoopThreadPool.h"
#include "Channel.h"

namespace hv {

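// TcpServerEventLoopTmpl: TCP server template parameterized by the channel type.
// A single acceptor loop accepts new connections and dispatches them to a pool
// of worker loops; the pool may be empty, in which case the acceptor loop
// services the connections itself.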
template<class TSocketChannel = SocketChannel>
class TcpServerEventLoopTmpl {
public:
    typedef std::shared_ptr<TSocketChannel> TSocketChannelPtr;

    TcpServerEventLoopTmpl(EventLoopPtr loop = NULL) {
        acceptor_loop = loop ? loop : std::make_shared<EventLoop>();
        port = 0;
        listenfd = -1;
        tls = false;
        tls_setting = NULL;
        unpack_setting = NULL;
        max_connections = 0xFFFFFFFF;
        load_balance = LB_RoundRobin;
    }

    virtual ~TcpServerEventLoopTmpl() {
        HV_FREE(tls_setting);
        HV_FREE(unpack_setting);
    }
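
    // Returns a loop from the worker pool;
    // falls back to the acceptor loop when no worker loop is available.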
    EventLoopPtr loop(int idx = -1) {
        EventLoopPtr worker_loop = worker_threads.loop(idx);
        if (worker_loop == NULL) {
            worker_loop = acceptor_loop;
        }
        return worker_loop;
    }

    //@retval >=0 listenfd, <0 error
    int createsocket(int port, const char* host = "0.0.0.0") {
        listenfd = Listen(port, host);
        if (listenfd < 0) return listenfd;
        this->host = host;
        this->port = port;
        return listenfd;
    }

    // closesocket thread-safe
    void closesocket() {
        if (listenfd >= 0) {
            hloop_t* loop = acceptor_loop->loop();
            if (loop) {
                hio_t* listenio = hio_get(loop, listenfd);
                assert(listenio != NULL);
                hio_close_async(listenio);
            }
            listenfd = -1;
        }
    }

    void setMaxConnectionNum(uint32_t num) {
        max_connections = num;
    }

    void setLoadBalance(load_balance_e lb) {
        load_balance = lb;
    }

    // NOTE: totalThreadNum = 1 acceptor_thread + N worker_threads (N can be 0)
    void setThreadNum(int num) {
        worker_threads.setThreadNum(num);
    }
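
    // Registers the accept callback on the acceptor loop, creating the listen
    // socket first if needed; with TLS enabled, also creates the server-side
    // SSL_CTX on the listen io.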
    int startAccept() {
        if (listenfd < 0) {
            listenfd = createsocket(port, host.c_str());
            if (listenfd < 0) {
                hloge("createsocket %s:%d returned %d!\n", host.c_str(), port, listenfd);
                return listenfd;
            }
        }
        hloop_t* loop = acceptor_loop->loop();
        if (loop == NULL) return -2;
        hio_t* listenio = haccept(loop, listenfd, onAccept);
        assert(listenio != NULL);
        hevent_set_userdata(listenio, this);
        if (tls) {
            hio_enable_ssl(listenio);
            if (tls_setting) {
                int ret = hio_new_ssl_ctx(listenio, tls_setting);
                if (ret != 0) {
                    hloge("new SSL_CTX failed: %d", ret);
                    closesocket();
                    return ret;
                }
            }
        }
        return 0;
    }
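
    // Removes the HV_READ event from the listen io so no new connections are
    // accepted; the listen socket itself stays open.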
    int stopAccept() {
        if (listenfd < 0) return -1;
        hloop_t* loop = acceptor_loop->loop();
        if (loop == NULL) return -2;
        hio_t* listenio = hio_get(loop, listenfd);
        assert(listenio != NULL);
        return hio_del(listenio, HV_READ);
    }

    // start thread-safe
    void start(bool wait_threads_started = true) {
        if (worker_threads.threadNum() > 0) {
            worker_threads.start(wait_threads_started);
        }
        acceptor_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::startAccept, this));
    }

    // stop thread-safe
    void stop(bool wait_threads_stopped = true) {
        closesocket();
        if (worker_threads.threadNum() > 0) {
            worker_threads.stop(wait_threads_stopped);
        }
    }
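
    // Enables TLS for accepted connections. Call before start()/startAccept():
    // the SSL_CTX is created on the listen io when accepting starts, and
    // opt->endpoint is forced to HSSL_SERVER.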
    int withTLS(hssl_ctx_opt_t* opt = NULL) {
        tls = true;
        if (opt) {
            if (tls_setting == NULL) {
                HV_ALLOC_SIZEOF(tls_setting);
            }
            opt->endpoint = HSSL_SERVER;
            *tls_setting = *opt;
        }
        return 0;
    }
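
    // Copies *setting into the server's own unpack settings, which are applied
    // to every new channel; pass NULL to clear them.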
    void setUnpack(unpack_setting_t* setting) {
        if (setting == NULL) {
            HV_FREE(unpack_setting);
            return;
        }
        if (unpack_setting == NULL) {
            HV_ALLOC_SIZEOF(unpack_setting);
        }
        *unpack_setting = *setting;
    }

    // channel
    const TSocketChannelPtr& addChannel(hio_t* io) {
        uint32_t id = hio_id(io);
        auto channel = std::make_shared<TSocketChannel>(io);
        std::lock_guard<std::mutex> locker(mutex_);
        channels[id] = channel;
        return channels[id];
    }

    TSocketChannelPtr getChannelById(uint32_t id) {
        std::lock_guard<std::mutex> locker(mutex_);
        auto iter = channels.find(id);
        return iter != channels.end() ? iter->second : NULL;
    }

    void removeChannel(const TSocketChannelPtr& channel) {
        uint32_t id = channel->id();
        std::lock_guard<std::mutex> locker(mutex_);
        channels.erase(id);
    }

    size_t connectionNum() {
        std::lock_guard<std::mutex> locker(mutex_);
        return channels.size();
    }
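
    // NOTE: mutex_ stays locked while fn runs, so fn must not call back into
    // methods that take mutex_ (addChannel/getChannelById/removeChannel/
    // connectionNum), or it will deadlock.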
    int foreachChannel(std::function<void(const TSocketChannelPtr& channel)> fn) {
        std::lock_guard<std::mutex> locker(mutex_);
        for (auto& pair : channels) {
            fn(pair.second);
        }
        return channels.size();
    }

    // broadcast thread-safe
    int broadcast(const void* data, int size) {
        return foreachChannel([data, size](const TSocketChannelPtr& channel) {
            channel->write(data, size);
        });
    }

    int broadcast(const std::string& str) {
        return broadcast(str.data(), str.size());
    }

private:
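    // Runs in the worker loop (posted from onAccept via runInLoop): attaches
    // connio to that loop, creates the channel, wires the onread/onwrite/onclose
    // callbacks, applies unpack settings, and starts reading.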
    static void newConnEvent(hio_t* connio) {
        TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
        if (server->connectionNum() >= server->max_connections) {
            hlogw("over max_connections");
            hio_close(connio);
            return;
        }

        // NOTE: attach to worker loop
        EventLoop* worker_loop = currentThreadEventLoop;
        assert(worker_loop != NULL);
        hio_attach(worker_loop->loop(), connio);

        const TSocketChannelPtr& channel = server->addChannel(connio);
        channel->status = SocketChannel::CONNECTED;

        channel->onread = [server, &channel](Buffer* buf) {
            if (server->onMessage) {
                server->onMessage(channel, buf);
            }
        };
        channel->onwrite = [server, &channel](Buffer* buf) {
            if (server->onWriteComplete) {
                server->onWriteComplete(channel, buf);
            }
        };
        channel->onclose = [server, &channel]() {
            EventLoop* worker_loop = currentThreadEventLoop;
            assert(worker_loop != NULL);
            --worker_loop->connectionNum;

            channel->status = SocketChannel::CLOSED;
            if (server->onConnection) {
                server->onConnection(channel);
            }
            server->removeChannel(channel);
            // NOTE: After removeChannel, channel may be destroyed,
            // so in this lambda function, no code should be added below.
        };

        if (server->unpack_setting) {
            channel->setUnpack(server->unpack_setting);
        }
        channel->startRead();
        if (server->onConnection) {
            server->onConnection(channel);
        }
    }
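
    // Runs in the acceptor loop: detaches the accepted io from it, picks a
    // worker loop according to load_balance (falling back to the acceptor
    // loop), and posts newConnEvent to that loop.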
    static void onAccept(hio_t* connio) {
        TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
        // NOTE: detach from acceptor loop
        hio_detach(connio);
        EventLoopPtr worker_loop = server->worker_threads.nextLoop(server->load_balance);
        if (worker_loop == NULL) {
            worker_loop = server->acceptor_loop;
        }
        ++worker_loop->connectionNum;
        worker_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::newConnEvent, connio));
    }

public:
    std::string host;
    int port;
    int listenfd;
    bool tls;
    hssl_ctx_opt_t* tls_setting;
    unpack_setting_t* unpack_setting;
    // Callbacks
    std::function<void(const TSocketChannelPtr&)> onConnection;
    std::function<void(const TSocketChannelPtr&, Buffer*)> onMessage;
    // NOTE: Use Channel::isWriteComplete in the onWriteComplete callback to determine whether all data has been written.
    std::function<void(const TSocketChannelPtr&, Buffer*)> onWriteComplete;

    uint32_t max_connections;
    load_balance_e load_balance;

private:
    // id => TSocketChannelPtr
    std::map<uint32_t, TSocketChannelPtr> channels; // GUARDED_BY(mutex_)
    std::mutex mutex_;

    EventLoopPtr acceptor_loop;
    EventLoopThreadPool worker_threads;
};
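
// TcpServerTmpl: TcpServerEventLoopTmpl plus an owned EventLoopThread for the
// acceptor. start() also starts that thread; stop() stops it only when the
// loop was created internally (is_loop_owner).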
template<class TSocketChannel = SocketChannel>
class TcpServerTmpl : private EventLoopThread, public TcpServerEventLoopTmpl<TSocketChannel> {
public:
    TcpServerTmpl(EventLoopPtr loop = NULL)
        : EventLoopThread(loop)
        , TcpServerEventLoopTmpl<TSocketChannel>(EventLoopThread::loop())
        , is_loop_owner(loop == NULL)
    {}
    virtual ~TcpServerTmpl() {
        stop(true);
    }

    EventLoopPtr loop(int idx = -1) {
        return TcpServerEventLoopTmpl<TSocketChannel>::loop(idx);
    }

    // start thread-safe
    void start(bool wait_threads_started = true) {
        TcpServerEventLoopTmpl<TSocketChannel>::start(wait_threads_started);
        if (!isRunning()) {
            EventLoopThread::start(wait_threads_started);
        }
    }

    // stop thread-safe
    void stop(bool wait_threads_stopped = true) {
        if (is_loop_owner) {
            EventLoopThread::stop(wait_threads_stopped);
        }
        TcpServerEventLoopTmpl<TSocketChannel>::stop(wait_threads_stopped);
    }

private:
    bool is_loop_owner;
};

typedef TcpServerTmpl<SocketChannel> TcpServer;
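
/* Usage sketch (illustrative, not part of this header): a minimal echo server
 * built on the TcpServer typedef above. The port, thread count, and log format
 * are arbitrary choices for the example; SocketChannelPtr and Buffer come from
 * Channel.h, and hv_sleep from libhv's platform header.
 *
 *   #include "hv/TcpServer.h"
 *   using namespace hv;
 *
 *   int main() {
 *       TcpServer srv;
 *       if (srv.createsocket(1234) < 0) return -1;
 *       // onConnection fires on both connect and close; check the state
 *       srv.onConnection = [](const SocketChannelPtr& channel) {
 *           printf("%s %s\n", channel->peeraddr().c_str(),
 *                  channel->isConnected() ? "connected" : "disconnected");
 *       };
 *       srv.onMessage = [](const SocketChannelPtr& channel, Buffer* buf) {
 *           channel->write(buf); // echo the received bytes back
 *       };
 *       srv.setThreadNum(4); // 1 acceptor thread + 4 worker threads
 *       srv.start();
 *       while (1) hv_sleep(1); // keep the main thread alive
 *       return 0;
 *   }
 */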

} // end namespace hv

#endif // HV_TCP_SERVER_HPP_