#ifndef HV_TCP_SERVER_HPP_
#define HV_TCP_SERVER_HPP_

#include "hsocket.h"
#include "hssl.h"
#include "hlog.h"

#include "EventLoopThreadPool.h"
#include "Channel.h"

namespace hv {

template<class TSocketChannel = SocketChannel>
class TcpServerEventLoopTmpl {
public:
    typedef std::shared_ptr<TSocketChannel> TSocketChannelPtr;

    TcpServerEventLoopTmpl(EventLoopPtr loop = NULL) {
        acceptor_loop = loop ? loop : std::make_shared<EventLoop>();
        port = 0;
        listenfd = -1;
        tls = false;
        tls_setting = NULL;
        unpack_setting = NULL;
        max_connections = 0xFFFFFFFF;
        load_balance = LB_RoundRobin;
    }

    virtual ~TcpServerEventLoopTmpl() {
        HV_FREE(tls_setting);
        HV_FREE(unpack_setting);
    }

    EventLoopPtr loop(int idx = -1) {
        return worker_threads.loop(idx);
    }

    //@retval >=0 listenfd, <0 error
    int createsocket(int port, const char* host = "0.0.0.0") {
        listenfd = Listen(port, host);
        if (listenfd < 0) return listenfd;
        this->host = host;
        this->port = port;
        return listenfd;
    }
    // closesocket thread-safe
    void closesocket() {
        if (listenfd >= 0) {
            hloop_t* loop = acceptor_loop->loop();
            if (loop) {
                hio_t* listenio = hio_get(loop, listenfd);
                assert(listenio != NULL);
                hio_close_async(listenio);
            }
            listenfd = -1;
        }
    }

    void setMaxConnectionNum(uint32_t num) {
        max_connections = num;
    }

    void setLoadBalance(load_balance_e lb) {
        load_balance = lb;
    }

    // NOTE: totalThreadNum = 1 acceptor_thread + N worker_threads (N can be 0)
    void setThreadNum(int num) {
        worker_threads.setThreadNum(num);
    }

    int startAccept() {
        if (listenfd < 0) {
            listenfd = createsocket(port, host.c_str());
            if (listenfd < 0) {
                hloge("createsocket %s:%d return %d!\n", host.c_str(), port, listenfd);
                return listenfd;
            }
        }
        hloop_t* loop = acceptor_loop->loop();
        if (loop == NULL) return -2;
        hio_t* listenio = haccept(loop, listenfd, onAccept);
        assert(listenio != NULL);
        hevent_set_userdata(listenio, this);
        if (tls) {
            hio_enable_ssl(listenio);
            if (tls_setting) {
                int ret = hio_new_ssl_ctx(listenio, tls_setting);
                if (ret != 0) {
                    hloge("new SSL_CTX failed: %d", ret);
                    closesocket();
                    return ret;
                }
            }
        }
        return 0;
    }

    int stopAccept() {
        if (listenfd < 0) return -1;
        hloop_t* loop = acceptor_loop->loop();
        if (loop == NULL) return -2;
        hio_t* listenio = hio_get(loop, listenfd);
        assert(listenio != NULL);
        return hio_del(listenio, HV_READ);
    }

    // start thread-safe
    void start(bool wait_threads_started = true) {
        if (worker_threads.threadNum() > 0) {
            worker_threads.start(wait_threads_started);
        }
        acceptor_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::startAccept, this));
    }
    // stop thread-safe
    void stop(bool wait_threads_stopped = true) {
        closesocket();
        if (worker_threads.threadNum() > 0) {
            worker_threads.stop(wait_threads_stopped);
        }
    }

    int withTLS(hssl_ctx_opt_t* opt = NULL) {
        tls = true;
        if (opt) {
            if (tls_setting == NULL) {
                HV_ALLOC_SIZEOF(tls_setting);
            }
            opt->endpoint = HSSL_SERVER;
            *tls_setting = *opt;
        }
        return 0;
    }

    void setUnpack(unpack_setting_t* setting) {
        if (setting == NULL) {
            HV_FREE(unpack_setting);
            return;
        }
        if (unpack_setting == NULL) {
            HV_ALLOC_SIZEOF(unpack_setting);
        }
        *unpack_setting = *setting;
    }

    // channel
    const TSocketChannelPtr& addChannel(hio_t* io) {
        uint32_t id = hio_id(io);
        auto channel = std::make_shared<TSocketChannel>(io);
        std::lock_guard<std::mutex> locker(mutex_);
        channels[id] = channel;
        return channels[id];
    }

    TSocketChannelPtr getChannelById(uint32_t id) {
        std::lock_guard<std::mutex> locker(mutex_);
        auto iter = channels.find(id);
        return iter != channels.end() ? iter->second : NULL;
    }

    void removeChannel(const TSocketChannelPtr& channel) {
        uint32_t id = channel->id();
        std::lock_guard<std::mutex> locker(mutex_);
        channels.erase(id);
    }

    size_t connectionNum() {
        std::lock_guard<std::mutex> locker(mutex_);
        return channels.size();
    }

    int foreachChannel(std::function<void(const TSocketChannelPtr& channel)> fn) {
        std::lock_guard<std::mutex> locker(mutex_);
        for (auto& pair : channels) {
            fn(pair.second);
        }
        return channels.size();
    }

    // broadcast thread-safe
    int broadcast(const void* data, int size) {
        return foreachChannel([data, size](const TSocketChannelPtr& channel) {
            channel->write(data, size);
        });
    }

    int broadcast(const std::string& str) {
        return broadcast(str.data(), str.size());
    }

private:
    static void newConnEvent(hio_t* connio) {
        TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
        if (server->connectionNum() >= server->max_connections) {
            hlogw("over max_connections");
            hio_close(connio);
            return;
        }

        // NOTE: attach to worker loop
        EventLoop* worker_loop = currentThreadEventLoop;
        assert(worker_loop != NULL);
        hio_attach(worker_loop->loop(), connio);

        const TSocketChannelPtr& channel = server->addChannel(connio);
        channel->status = SocketChannel::CONNECTED;

        channel->onread = [server, &channel](Buffer* buf) {
            if (server->onMessage) {
                server->onMessage(channel, buf);
            }
        };
        channel->onwrite = [server, &channel](Buffer* buf) {
            if (server->onWriteComplete) {
                server->onWriteComplete(channel, buf);
            }
        };
        channel->onclose = [server, &channel]() {
            EventLoop* worker_loop = currentThreadEventLoop;
            assert(worker_loop != NULL);
            --worker_loop->connectionNum;

            channel->status = SocketChannel::CLOSED;
            if (server->onConnection) {
                server->onConnection(channel);
            }
            server->removeChannel(channel);
            // NOTE: After removeChannel, channel may be destroyed,
            // so in this lambda function, no code should be added below.
        };

        if (server->unpack_setting) {
            channel->setUnpack(server->unpack_setting);
        }
        channel->startRead();
        if (server->onConnection) {
            server->onConnection(channel);
        }
    }

    static void onAccept(hio_t* connio) {
        TcpServerEventLoopTmpl* server = (TcpServerEventLoopTmpl*)hevent_userdata(connio);
        // NOTE: detach from acceptor loop
        hio_detach(connio);
        EventLoopPtr worker_loop = server->worker_threads.nextLoop(server->load_balance);
        if (worker_loop == NULL) {
            worker_loop = server->acceptor_loop;
        }
        ++worker_loop->connectionNum;
        worker_loop->runInLoop(std::bind(&TcpServerEventLoopTmpl::newConnEvent, connio));
    }

public:
    std::string             host;
    int                     port;
    int                     listenfd;
    bool                    tls;
    hssl_ctx_opt_t*         tls_setting;
    unpack_setting_t*       unpack_setting;
    // Callback
    std::function<void(const TSocketChannelPtr&)>           onConnection;
    std::function<void(const TSocketChannelPtr&, Buffer*)>  onMessage;
    // NOTE: Use Channel::isWriteComplete in onWriteComplete callback to determine whether all data has been written.
    std::function<void(const TSocketChannelPtr&, Buffer*)>  onWriteComplete;

    uint32_t                max_connections;
    load_balance_e          load_balance;

private:
    // id => TSocketChannelPtr
    std::map<uint32_t, TSocketChannelPtr>   channels; // GUAREDE_BY(mutex_)
    std::mutex                              mutex_;

    EventLoopPtr            acceptor_loop;
    EventLoopThreadPool     worker_threads;
};

template<class TSocketChannel = SocketChannel>
class TcpServerTmpl : private EventLoopThread, public TcpServerEventLoopTmpl<TSocketChannel> {
    // Shorthand for the base class holding the actual server logic.
    typedef TcpServerEventLoopTmpl<TSocketChannel> ServerBase;

public:
    // Accepts on the loop of an EventLoopThread constructed from @loop.
    // If @loop is NULL, the thread's loop is considered ours and stop()
    // will also stop that thread.
    TcpServerTmpl(EventLoopPtr loop = NULL)
        : EventLoopThread(loop)
        , ServerBase(EventLoopThread::loop())
        , is_loop_owner(!loop)
    {}

    // Blocks until the acceptor thread (when owned) and the workers are stopped.
    virtual ~TcpServerTmpl() {
        stop(true);
    }

    // Worker-loop accessor; hides EventLoopThread::loop() for users.
    EventLoopPtr loop(int idx = -1) {
        return ServerBase::loop(idx);
    }

    // Starts workers and schedules accepting first, then launches the
    // acceptor thread that will run the scheduled work.
    // start thread-safe
    void start(bool wait_threads_started = true) {
        ServerBase::start(wait_threads_started);
        EventLoopThread::start(wait_threads_started);
    }

    // Stops the acceptor thread (only if we own its loop), then closes the
    // listening socket and stops the workers.
    // stop thread-safe
    void stop(bool wait_threads_stopped = true) {
        if (is_loop_owner) EventLoopThread::stop(wait_threads_stopped);
        ServerBase::stop(wait_threads_stopped);
    }

private:
    bool is_loop_owner;  // true when constructed with a NULL loop
};

using TcpServer = TcpServerTmpl<SocketChannel>;

}

#endif // HV_TCP_SERVER_HPP_