//
// Created by Nathan Touroux on 29/06/2017.
//

#include "Network.h"

//*********************************************** NET DATA (TCP & UDP) ***********************************************

NetData::NetData(std::string str) : data(nullptr), size(str.size()) {
    // Deep-copy the string's bytes; no terminating NUL is stored.
    const char *bytes = str.c_str();
    copy(const_cast<char *>(bytes), size);
}
NetData::NetData(const char *str) : data(nullptr), size(std::strlen(str)) {
    // Deep-copy the C string's bytes, excluding its terminating NUL.
    copy(const_cast<char *>(str), size);
}
NetData::NetData(void *data, size_data size, bool copyData) : data(data), size(size) {
    // Without copyData the buffer is adopted as-is and will later be released with free()
    // by the destructor, so it must come from malloc. With copyData a private copy is made.
    if (!copyData) return;
    copy(data, size);
}
NetData::NetData(const NetData &toCopy) : data(nullptr), size(toCopy.size) {
    // Deep copy: every NetData owns its own buffer.
    copy(toCopy.data, size);
}
/**
 * Deep-copy assignment: releases the current buffer and copies `toCopy`'s bytes.
 * Self-assignment is a no-op. Without that guard, free(data) would release the very
 * buffer copy() is about to read (use-after-free).
 */
NetData& NetData::operator=(const NetData &toCopy){
    if(this == &toCopy) return *this;
    free(data);
    data = nullptr;
    size = toCopy.size;
    copy(toCopy.data, size);
    return *this;
}
NetData::~NetData(){
    // Release the owned malloc'd buffer, if there is one.
    if(data == nullptr) return;
    free(data);
}

/**
 * Point this->data at a freshly malloc'd copy of `size` bytes read from `data`.
 * The previous buffer is NOT freed; callers handle that (getData() relies on it).
 * this->data is left null when there is nothing to copy (size 0 or null source) or
 * when allocation fails. valid() then reports false instead of memcpy touching a null pointer.
 */
void NetData::copy(void* data, size_data size){
    this->data = nullptr;
    if(size == 0 || data == nullptr) return;
    void *buffer = malloc(size);
    if(buffer == nullptr) return;
    memcpy(buffer, data, size);
    this->data = buffer;
}

bool NetData::valid(){
    // Usable only when a buffer exists and it holds at least one byte.
    return !(data == nullptr || size == 0);
}

/**
 * View the payload as text. Returns "" when the data is not valid().
 * As before, the result stops at the first NUL byte, if the payload contains one.
 * The original variable-length stack array is gone: it is non-standard C++, and because
 * `size` comes off the network it could overflow the stack.
 */
NetData::operator std::string() {
    if(!valid()) return "";
    const char *bytes = static_cast<const char*>(data);
    const void *nul = std::memchr(bytes, '\0', size);
    const size_t length = nul ? static_cast<size_t>(static_cast<const char*>(nul) - bytes)
                              : static_cast<size_t>(size);
    return std::string(bytes, length);
}

void* NetData::getData(){
    // Hand the current buffer over to the caller, who becomes responsible for free()ing it.
    // A fresh private copy is kept so this object still owns valid memory afterwards.
    void *handedOut = data;
    copy(handedOut, size);
    return handedOut;
}

size_data NetData::getSize(){
    // Payload byte count (excludes the length prefix added by toPacket()).
    return this->size;
}

/**
 * Serialize into a length-prefixed packet: [size_data length][payload bytes].
 * The length is copied in host byte order. The returned buffer comes from malloc and the
 * caller must free() it. Returns {nullptr, 0} when there is nothing to send or when
 * allocation fails (previously a failed malloc was written through).
 */
NetPacket NetData::toPacket(){
    if(!valid()) return {nullptr, 0};

    auto sizePacket = sizeof(size_data)+size;
    auto packet = static_cast<char*>(malloc(sizePacket));
    if(packet == nullptr) return {nullptr, 0};
    memcpy(packet, &size, sizeof(size_data));
    memcpy(packet+sizeof(size_data), data, size);
    return {packet, sizePacket};
}

//**************************************************** TCP Client ****************************************************
TCPClient::TCPClient(bool asynchronous)
        : sock(), server(nullptr), first(nullptr), nbData(0), asyncThread(nullptr),
          asynchronous(asynchronous), connected(false) {
    // Starts disconnected with an empty receive queue; no receiving thread yet.
}
// A copy only inherits the asynchronous flag. It never shares the socket, the queued
// data or the receiving thread of the original.
TCPClient::TCPClient(const TCPClient &client) : TCPClient(client.asynchronous.load()) {
}

/**
 * Disconnect, then release every message still waiting in the receive queue.
 * Each queued node owns a malloc'd buffer (see push()); previously they all leaked.
 * NOTE(review): disconnect() only detaches the receiving thread, which still holds `this`.
 * Destroying a client whose receiver is still running remains unsafe. TODO confirm callers.
 */
TCPClient::~TCPClient(){
    disconnect();
    std::lock_guard<std::mutex> guard(mutex);
    while(first != nullptr){
        Data *node = first;
        first = node->next;
        free(node->data);
        delete node;
    }
    last = nullptr;
    nbData = 0;
}

bool TCPClient::invalid(){
    return !connected;
}

/**
 * Accept the next pending connection on `server`'s listening socket into this client.
 * When receiveName is true, the peer's first message becomes the client's name.
 * Starts the receiving thread when asynchronous mode is enabled.
 * Returns false if already connected, if accept() fails, or if the expected name could
 * not be received. In that last case the accepted socket is closed rather than registered
 * as a nameless connection.
 */
bool TCPClient::accept(TCPServer *server, bool receiveName){
    if(connected) return false;//if already connected stop there

    // The address length is in/out for accept(): it must hold the buffer size on entry.
    sock.recsize = sizeof(sock.sin);
    sock.sock = ::accept(server->getHostSock().sock, (SOCKADDR*)&sock.sin, &sock.recsize);

    if(sock.sock == INVALID_SOCKET) return false;

    sock.IP = inet_ntoa(sock.sin.sin_addr);

    if(receiveName){
        NetData name = unsafeReceive();
        if(!name.valid()){
            // Peer closed or failed before sending its name.
            closesocket(sock.sock);
            sock.sock = SOCKET_ERROR;
            sock.IP = "";
            return false;
        }
        sock.name = static_cast<std::string>(name);
    }

    this->server = server;

    connected = true;
    if(asynchronous) asyncThread = new std::thread(asyncReceive, this);

    return sock.sock != SOCKET_ERROR;
}

bool TCPClient::connect(std::string IP, int port, std::string name){
    if(connected) return false;//if already connected stop there

    sock.IP = IP;
    sock.port = port;

    sock.sock = socket(AF_INET,SOCK_STREAM, 0);

    if(sock.sock == INVALID_SOCKET) return false;

    sock.sin.sin_addr.s_addr = inet_addr(IP.c_str());
    sock.sin.sin_family = AF_INET;
    sock.sin.sin_port = htons(port);

    if(::connect(sock.sock, (SOCKADDR*)&sock.sin, sock.recsize) == SOCKET_ERROR) return false;

    if(!name.empty()){
        unsafeSend(name);
        unsafeReceive();//to confirm if name is accepted
    }

    connected = true;
    if(asynchronous) asyncThread = new std::thread(asyncReceive, this);

    return sock.sock != SOCKET_ERROR;
}

/**
 * Close the connection and reset the socket state.
 * When the client belongs to a server and removeFromServer is true, the server is asked
 * to unregister it. eraseClient() then calls back disconnect(false) to do the actual close.
 * The leftover debug print and the dead commented-out lock were removed.
 */
void TCPClient::disconnect(bool removeFromServer) {
    if(!connected) return; //nothing to do
    if(server && removeFromServer){
        server->eraseClient(*this);
        return;
    }
    connected = false;
    closesocket(sock.sock);
    sock.sock = SOCKET_ERROR;
    sock.name = "";
    sock.IP = "";
    if(asyncThread){
        // Detach rather than join: this function is also reached from the receiving thread
        // itself (unsafeReceive -> disconnect), where join() would deadlock.
        // NOTE(review): the detached thread still uses `this` until it observes
        // connected == false; the object must outlive it. TODO confirm.
        if(asyncThread->joinable())
            asyncThread->detach();//TODO sometimes too slow
        delete asyncThread;
        asyncThread = nullptr;
    }
}

bool TCPClient::isConnected(){
    // Snapshot of the connection flag.
    const bool state = connected;
    return state;
}

void TCPClient::enableAsynchronous(){
    if(!asynchronous){
        asynchronous = true;
        // Spawn the receiver only for a live connection; otherwise accept()/connect() start it.
        if(connected) asyncThread = new std::thread(asyncReceive, this);
    }
}

bool TCPClient::asynchronousEnabled() {
    // Whether incoming messages are collected by the background receiver.
    const bool enabled = asynchronous;
    return enabled;
}

void TCPClient::asyncReceive(TCPClient *client) {
    // Receiver thread body: keep queueing valid messages until the client disconnects
    // or asynchronous mode is switched off.
    for(;;){
        if(!client->isConnected() || !client->asynchronousEnabled()) break;
        NetData message = client->unsafeReceive();
        if(!message.valid()) continue;
        // getData() hands over a buffer that the queue now owns; `message` keeps its own copy.
        client->push(message.getData(), message.getSize());
    }
}
void TCPClient::push(void* data, size_data size){
    std::lock_guard<std::mutex> guard(mutex);
    Data *newdata = new Data{data, size, nullptr};
    if(first == nullptr){
        first = newdata;
    }else{
        last->next = newdata;
    }
    last = newdata;
    nbData++;
}
/**
 * Remove and return the oldest queued message, or an invalid NetData when the queue is empty.
 * The returned NetData takes ownership of the queued buffer.
 * Fixes: nbData is now decremented to match push(). Previously it only grew, so receive()
 * kept choosing pop() and returned empty data forever. `last` is also cleared when the
 * queue empties.
 */
NetData TCPClient::pop(){
    std::lock_guard<std::mutex> guard(mutex);
    if(first == nullptr) return NetData(nullptr, 0);
    Data *data = first;
    first = first->next;
    if(first == nullptr) last = nullptr;
    nbData--;
    NetData netdata(data->data, data->size, false);
    delete data;
    return netdata;
}

long TCPClient::unsafeSend(NetData data){
    auto packet = data.toPacket();
    long size = ::send(sock.sock, packet.data, packet.size, 0);
    free(packet.data);
    if(size == 0) if(errno == ECONNRESET) disconnect();//when sending check ECONNRESET
    return size;
}
NetData TCPClient::unsafeReceive(){
    size_data sizeByte;
    long size = ::recv(sock.sock, &sizeByte, sizeof(size_data), 0);
    if(size == 0) {
        disconnect();
        return  NetData(nullptr, 0);
    }//when receiving just check for size == 0
    if(sizeByte == 0) return NetData(nullptr, 0);//check that the length of data is not 0

    void* data = malloc(sizeByte);
    size = ::recv(sock.sock, data, sizeByte, 0);
    if(size == 0) {
        disconnect();
        free(data);
        return NetData(nullptr, 0);
    }

    return NetData(data, sizeByte, false);//dont copy the data just the address so it will be freed automatically after
}

/**
 * Send raw bytes (no length prefix) with a single ::send call.
 * Refused with SOCKET_ERROR when not connected or in asynchronous mode.
 * send() reports failure as -1 with errno set; the old `size == 0` test could never see a
 * reset peer. A reset or broken connection now disconnects the client.
 */
long TCPClient::rawSend(void *data, size_data sizeByte) {
    if(invalid() || asynchronous) return SOCKET_ERROR;//cant send raw packet when asynchronous
    long size = ::send(sock.sock, data, sizeByte, 0);
    if(size < 0 && (errno == ECONNRESET || errno == EPIPE)) disconnect();
    return size;
}

long TCPClient::rawReceive(void *data, size_data sizeByte) {
    // Raw reads are refused in asynchronous mode, where the background receiver reads the socket.
    if(!invalid() && !asynchronous){
        long received = ::recv(sock.sock, data, sizeByte, 0);
        if(received == 0) disconnect(); // orderly shutdown by the peer
        return received;
    }
    return SOCKET_ERROR;
}

long TCPClient::send(NetData data){
    // Only write on a connected socket.
    return invalid() ? SOCKET_ERROR : unsafeSend(data);
}
NetData TCPClient::receive(){
    if(invalid()) return NetData(nullptr, 0);
    // Serve from the queue in asynchronous mode, or while messages queued earlier remain;
    // otherwise do a blocking read on the socket.
    const bool useQueue = asynchronous || nbData > 0;
    if(useQueue) return pop();
    return unsafeReceive();
}

std::string TCPClient::getIP() const{
    return sock.IP;
}

std::string TCPClient::getName() const{
    return sock.name;
}

int TCPClient::getPort() const{
    return sock.port;
}

int TCPClient::getNbData() const{
    return nbData;
}

//**************************************************** TCP Server ****************************************************

TCPServer::TCPServer(int port, unsigned int nbConnections, bool useName, bool asynchronous):
        clients(nbConnections, asynchronous), thread(nullptr), port(port), error(0), nbConnections(nbConnections),
        nbConnected(0), useName(useName), hosting(true), autoReconnect(false), maxConnectionTry(0){
    // Winsock needs initialising before any socket call on Windows. Elsewhere `error`
    // simply keeps its initial value of 0.
#ifdef WIN32
    error = WSAStartup(MAKEWORD(2,2), &m_WSAData);
#endif // WIN32
}

/**
 * Stop hosting and release every connection.
 * The background accept loop is stopped and joined BEFORE `clients` is cleared.
 * Previously clients.clear() ran first, while acceptHost() on the other thread could still
 * index clients[i] in the now-empty vector.
 */
TCPServer::~TCPServer(){
    hosting = false;          // the accept loop re-checks this after every attempt
    closesocket(hsock.sock);  // intended to make a pending ::accept() fail
    // NOTE(review): on Linux, close() may not wake a thread blocked in accept();
    // shutdown() on the listening socket may be needed. TODO verify.
    if(thread){
        thread->join();
        delete thread;
        thread = nullptr;
    }
    clients.clear();

#ifdef WIN32
    WSACleanup();
#endif // WIN32
}


bool TCPServer::host(bool waitConnections, bool autoReconnect, unsigned int maxConnectionTry) {
    // A non-zero `error` (from WSAStartup in the constructor) means sockets are unusable.
    if(error) return false;

    hsock.sock = socket(AF_INET, SOCK_STREAM, 0);
    if(hsock.sock == INVALID_SOCKET) return false;

    // Listen on every local interface.
    hsock.sin.sin_addr.s_addr = htonl(INADDR_ANY);
    hsock.sin.sin_family = AF_INET;
    hsock.sin.sin_port = htons(port);

    const bool bound = bind(hsock.sock, (SOCKADDR*)&hsock.sin, sizeof(hsock.sin)) != SOCKET_ERROR;
    if(!bound || listen(hsock.sock, 1) == SOCKET_ERROR) return false;

    this->autoReconnect = autoReconnect;
    this->maxConnectionTry = maxConnectionTry;

    if(!waitConnections){
        // Run the accept loop in the background; host() returns immediately.
        thread = new std::thread(waitHost, this);
        return true;
    }
    return acceptHost(); // blocking mode: run the accept loop on the caller's thread
}
/**
 * Accept loop: fills each disconnected slot of `clients` with an incoming connection.
 * Without autoReconnect it walks the slots once; with autoReconnect it wraps around and
 * keeps refilling slots whose client disconnected. It stops as soon as `hosting` is false.
 * A client with a non-empty name that is already registered is refused (its socket is
 * closed) when the registered client is still connected, or always when autoReconnect is
 * off. Accepted named clients receive an "accepted" message.
 * Returns false after maxConnectionTry consecutive accept failures (0 = never give up);
 * returns true otherwise.
 * Fix: the retry test was inverted (`connectionTry <= maxConnectionTry`) and gave up on the
 * very first failure whenever maxConnectionTry >= 1.
 */
bool TCPServer::acceptHost(){
    unsigned int connectionTry = 0;
    for(int i = 0;(i<nbConnections || autoReconnect) && hosting;i++){
        if(i == nbConnections) i = 0;//in case of autoReconnect
        // NOTE(review): with autoReconnect and every slot connected, this loop spins without
        // blocking; consider waiting between passes.
        if(clients[i].isConnected()) continue;

        if(!clients[i].accept(this, useName)){
            connectionTry++;
            if(maxConnectionTry != 0 && connectionTry >= maxConnectionTry)
                return false;
            i--;//retry the same slot
            continue;
        }
        connectionTry = 0;//connection made so reset connectionTry to 0

        std::string name = clients[i].getName();
        if(!name.empty()){
            auto client = clientsByName.find(name);
            if(client != clientsByName.end()){//if already exist disconnect
                if(client->second->isConnected() || !autoReconnect){//if !autoReconnect a client can't connect again
                    clients[i].disconnect(false);
                    i--;
                    continue;
                }else{//else send acceptation message
                    clients[i].send("accepted");
                }
            } else{//else send acceptation message
                clients[i].send("accepted");
            }
            clientsByName[name] = &clients[i];
        }
        clientsByIP[clients[i].getIP()] = &clients[i];

        nbConnected++;
    }

    return true;
}

void TCPServer::waitHost(TCPServer *tcp){
    // Thread entry used by host(): run the accept loop; its result is not reported anywhere.
    const bool finished = tcp->acceptHost();
    (void)finished;
}

TCPClient& TCPServer::operator[](int index){
    // Out-of-range or disconnected slots map to the shared placeholder client.
    return invalid(index) ? empty : clients[index];
}
TCPClient& TCPServer::operator[](std::string IP){
    auto found = clientsByIP.find(IP);
    if(found == clientsByIP.end()) return empty;
    return *found->second;
}
TCPClient& TCPServer::operator()(std::string name){
    auto found = clientsByName.find(name);
    return found == clientsByName.end() ? empty : *found->second;
}
TCPClient& TCPServer::unavailable(){
    // The placeholder handed out by lookups that find nothing.
    return empty;
}

bool TCPServer::invalid(int client){
    // Indices outside the slot table are invalid; in-range slots are invalid while disconnected.
    const bool inRange = client >= 0 && client < clients.size();
    return inRange ? !clients[client].isConnected() : true;
}

void TCPServer::eraseClient(TCPClient &client){
    if(!client.isConnected()) return; // nothing to unregister
    // Capture the lookup keys first: disconnect(false) clears the client's IP and name.
    const std::string IP = client.getIP();
    const std::string name = client.getName();
    client.disconnect(false);
    nbConnected--;
    if(!IP.empty()) clientsByIP.erase(IP);
    if(!name.empty()) clientsByName.erase(name);
}

std::vector<std::string> TCPServer::IPList(){
    // Keys of the IP index, in map iteration order.
    std::vector<std::string> result;
    result.reserve(clientsByIP.size());
    for(auto it = clientsByIP.begin(); it != clientsByIP.end(); ++it)
        result.push_back(it->first);
    return result;
}
std::vector<std::string> TCPServer::namesList(){
    // Keys of the name index, in map iteration order.
    std::vector<std::string> result;
    result.reserve(clientsByName.size());
    for(auto it = clientsByName.begin(); it != clientsByName.end(); ++it)
        result.push_back(it->first);
    return result;
}

Sock TCPServer::getHostSock()  const{
    return hsock;
}

int TCPServer::getNbConnected()  const{
    return nbConnected;
}

//**************************************************** UDP ****************************************************

UDP::UDP(std::string IP, int port)
        : IP(IP), port(port), sock(), from({0}), fromaddrsize(sizeof(from)) {
    // No socket is created here; init() opens and binds it.
}
UDP::~UDP(){
    // Close the socket held in sock.sock (opened by init()).
    closesocket(sock.sock);
}
bool UDP::init(){
    sock.sock = socket(PF_INET, SOCK_DGRAM, 0);
    if(sock.sock == INVALID_SOCKET) return false;

    // Bind to the configured address, or to every interface when no IP was given.
    sock.sin = { 0 };
    sock.recsize = sizeof sock.sin;
    sock.sin.sin_addr.s_addr = IP.empty() ? htonl(INADDR_ANY) : inet_addr(IP.c_str());
    sock.sin.sin_port = htons(port);
    sock.sin.sin_family = AF_INET;

    return bind(sock.sock, (SOCKADDR *)&sock.sin, sock.recsize) != SOCKET_ERROR;
}


long UDP::send(NetData data){
    // One datagram per message: length prefix + payload, addressed to sock.sin.
    const auto packet = data.toPacket();
    const long sent = sendto(sock.sock, packet.data, packet.size, 0, (SOCKADDR *)&sock.sin, sock.recsize);
    free(packet.data);
    return sent;
}
/**
 * Receive one length-prefixed message sent as a single datagram (as UDP::send / sendTo do).
 * The sender's address is stored in `from` (used by getIP()). Returns an invalid NetData
 * on error, on a datagram too short for the prefix, on a zero/inconsistent length, or when
 * allocation fails.
 * Fix: UDP keeps message boundaries, and a recvfrom() shorter than the datagram discards the
 * rest. Reading the prefix and the payload in two calls therefore consumed two different
 * datagrams. The whole datagram is now read at once. A -1 result also no longer leaves the
 * length uninitialized.
 */
NetData UDP::receive(){
    const size_t maxDatagram = 65507; // largest IPv4 UDP payload
    std::vector<char> buffer(maxDatagram);
    fromaddrsize = sizeof(from); // in/out parameter: reset to the buffer size before each call
    long size = recvfrom(sock.sock, buffer.data(), buffer.size(), 0, &from, &fromaddrsize);
    if(size < static_cast<long>(sizeof(size_data))) return NetData(nullptr, 0);

    size_data sizeByte;
    memcpy(&sizeByte, buffer.data(), sizeof(size_data));
    const size_t available = static_cast<size_t>(size) - sizeof(size_data);
    if(sizeByte == 0 || static_cast<size_t>(sizeByte) > available) return NetData(nullptr, 0);

    void* data = malloc(sizeByte);
    if(data == nullptr) return NetData(nullptr, 0);
    memcpy(data, buffer.data() + sizeof(size_data), sizeByte);

    return NetData(data, sizeByte, false);//dont copy the data just the address so it will be freed automatically after
}

long UDP::rawSend(void *data, size_data sizeByte) {
    return sendto(sock.sock, data, sizeByte, 0, (SOCKADDR *)&sock.sin, sock.recsize);
}

long UDP::rawReceive(void *data, size_data sizeByte) {
    return recvfrom(sock.sock, data, sizeByte, 0, &from, &fromaddrsize);
}

int UDP::getPort(){
    return port;
}
std::string UDP::getIP(){
    if(IP != "")
        return IP;
    else{
        char host[NI_MAXHOST];
        if (getnameinfo((sockaddr*)&from, fromaddrsize, host, NI_MAXHOST, NULL, 0, NI_NUMERICHOST) != 0) {
            return "";
        } else {
            return std::string(host);
        }
    }
}

    //************************************ STATIC ************************************

long UDP::rawSendTo(std::string IP, int port, void *data, size_data sizeByte){
    SOCKET sock = socket(PF_INET, SOCK_DGRAM, 0);

    if(sock == INVALID_SOCKET) return false;

    SOCKADDR_IN to = { 0 };
    socklen_t tosize = sizeof to;
    to.sin_addr.s_addr = inet_addr(IP.c_str());
    to.sin_port = htons(port);
    to.sin_family = AF_INET;

    long size = sendto(sock, data, sizeByte, 0, (SOCKADDR *)&to, tosize);
    closesocket(sock);

    return  size;
}

long UDP::sendTo(std::string IP, int port, NetData data){
    SOCKET sock = socket(PF_INET, SOCK_DGRAM, 0);

    if(sock == INVALID_SOCKET) return false;

    SOCKADDR_IN to = { 0 };
    socklen_t tosize = sizeof to;
    to.sin_addr.s_addr = inet_addr(IP.c_str());
    to.sin_port = htons(port);
    to.sin_family = AF_INET;

    auto packet = data.toPacket();
    long size = sendto(sock, packet.data, packet.size, 0, (SOCKADDR *)&to, tosize);
    free(packet.data);
    closesocket(sock);

    return  size;
}

/**
 * One-shot raw receive: bind a temporary socket to IP:port (all interfaces when IP is empty),
 * read one datagram into `data` (at most sizeByte bytes), then close the socket.
 * Returns the recvfrom() result, or SOCKET_ERROR when the socket cannot be created or bound.
 * Previously those failures returned `false` (0), and a failed bind leaked the socket.
 */
long UDP::rawReceiveFrom(std::string IP, int port, void *data, size_data sizeByte){
    SOCKET sock = socket(PF_INET, SOCK_DGRAM, 0);

    if(sock == INVALID_SOCKET) return SOCKET_ERROR;

    SOCKADDR_IN to = { 0 };
    socklen_t tosize = sizeof to;
    if(IP.empty())
        to.sin_addr.s_addr = htonl(INADDR_ANY);
    else
        to.sin_addr.s_addr = inet_addr(IP.c_str());
    to.sin_port = htons(port);
    to.sin_family = AF_INET;

    if(bind(sock, (SOCKADDR *)&to, tosize) == SOCKET_ERROR){
        closesocket(sock);
        return SOCKET_ERROR;
    }

    sockaddr from = { 0 };
    socklen_t addrsize = sizeof from;

    long size = recvfrom(sock, data, sizeByte, 0, &from, &addrsize);
    closesocket(sock);

    return  size;
}

long UDP::rawReceiveFrom(int port, void *data, size_data sizeByte){
    return rawReceiveFrom("", port, data, sizeByte);
}//TODO add possibility to know from wich ip it comes from

/**
 * One-shot receive of a length-prefixed message. Binds a temporary socket to IP:port
 * (all interfaces when IP is empty), reads one datagram, and closes the socket.
 * Returns an invalid NetData on socket/bind/recv failure, on a datagram too short for the
 * prefix, on a zero/inconsistent length, or when allocation fails.
 * Fixes:
 *  - the socket is now closed on every path (it was never closed before);
 *  - the whole datagram is read in one recvfrom(). UDP discards the unread rest of a
 *    datagram, so the old two-call read took the payload from the next datagram;
 *  - a -1 from recvfrom() no longer leaves the length uninitialized.
 */
NetData UDP::receiveFrom(std::string IP, int port){
    SOCKET sock = socket(PF_INET, SOCK_DGRAM, 0);

    if(sock == INVALID_SOCKET) return NetData(nullptr, 0);

    SOCKADDR_IN to = { 0 };
    socklen_t tosize = sizeof to;
    if(IP.empty())
        to.sin_addr.s_addr = htonl(INADDR_ANY);
    else
        to.sin_addr.s_addr = inet_addr(IP.c_str());
    to.sin_port = htons(port);
    to.sin_family = AF_INET;

    if(bind(sock, (SOCKADDR *)&to, tosize) == SOCKET_ERROR){
        closesocket(sock);
        return NetData(nullptr, 0);
    }

    sockaddr from = { 0 };
    socklen_t addrsize = sizeof from;

    const size_t maxDatagram = 65507; // largest IPv4 UDP payload
    std::vector<char> buffer(maxDatagram);
    long size = recvfrom(sock, buffer.data(), buffer.size(), 0, &from, &addrsize);
    closesocket(sock);

    if(size < static_cast<long>(sizeof(size_data))) return NetData(nullptr, 0);

    size_data sizeByte;
    memcpy(&sizeByte, buffer.data(), sizeof(size_data));
    const size_t available = static_cast<size_t>(size) - sizeof(size_data);
    if(sizeByte == 0 || static_cast<size_t>(sizeByte) > available) return NetData(nullptr, 0);

    void* data = malloc(sizeByte);
    if(data == nullptr) return NetData(nullptr, 0);
    memcpy(data, buffer.data() + sizeof(size_data), sizeByte);

    return NetData(data, sizeByte, false);//dont copy the data just the address so it will be freed automatically after
}

NetData UDP::receiveFrom(int port){
    // Receive one length-prefixed message on `port` from any interface (empty IP = INADDR_ANY).
    return receiveFrom(std::string(), port);
}