// Revision: 11/01/01 Dave Pape

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <ctype.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <map>
#include <CAVERN.hxx>
#include "ygDebugFlags.h"
#include "ygMutex.h"
#include "ygUtil.h"
#include "ygNetClient.h"

using namespace std;


// One stored value in the client's path/key table.  Records are created by
// ygNetClient::getKeyStruct() and reached through the maps declared below.
class ygNetClientKey
	{
	public:
	// Value bytes: a malloc/realloc heap block, NULL until a value is stored.
	void * data;
	// Number of valid bytes currently held in `data`.
	size_t size;
	// Capacity of the `data` block; the code that fills it keeps it >= size.
	size_t bufferSize;
	// Fixed when the record is created: true if created by a local put
	// (saveKeyData), false if created by get() or an incoming update.
	bool owned;
	};

// Keys within one path: key name -> value record (records are allocated
// with new in ygNetClient::getKeyStruct()).
typedef map<ygString,ygNetClientKey *> ygNetClientKeymap;
// Whole table: path -> the keys stored under that path.
typedef map<ygString,ygNetClientKeymap> ygNetClientPathmap;


// Private implementation state of ygNetClient, reached through p_.
struct _ygNetClientPrivateData
	{
	// Reflector address given to init(); reused by reset().
	ygString server;
	int port;

	// Transport sockets; NULL until init() (or reset()) creates them.
	CAVERNnet_udp_c * udpSocket;
	CAVERNnet_tcpReflectorClient_c * tcpSocket;

	// Update hook installed by trigger() and invoked from process().
	ygNetClientCallback callback;
	void * callbackData;

	// path -> key -> value table (map structure guarded by mapMutex).
	ygNetClientPathmap paths;

	// Unpacker for incoming packets; downloadViaTCP() also reuses it.
	CAVERNnet_datapack_c packer;
	// Outgoing batch packers over tcpWriteBuffer / udpWriteBuffer.
	CAVERNnet_datapack_c tcpWritePacker;
	CAVERNnet_datapack_c udpWritePacker;

	// Raw byte buffers; tcpReadBuffer holds the last message from the
	// TCP socket's read(), the others are YG_NET_MAX_BUFFER_SIZE mallocs.
	char * tcpReadBuffer;
	char * tcpWriteBuffer;
	char * udpReadBuffer;
	char * udpWriteBuffer;

	// mapMutex guards `paths`; dataMutex guards each record's data/size.
	ygMutex mapMutex;
	ygMutex dataMutex;

	// Per-channel diagnostic switches read in the constructor.
	bool debugDownload;
	bool debugTCPWrite;
	bool debugTCPRead;
	bool debugUDPWrite;
	bool debugUDPRead;
	};



// Build an unconnected client.  Allocate the fixed-size staging buffers,
// point the outgoing packers at them, and read the per-channel debug
// switches.  The sockets are created later by init().
ygNetClient::ygNetClient(void)
	{
	p_ = new struct _ygNetClientPrivateData;
	// Give every scalar member a defined value.  Methods reached before
	// init() (e.g. reset(), which reads port) must never see garbage.
	p_->port = 0;
	p_->udpSocket = NULL;
	p_->tcpSocket = NULL;
	p_->callback = NULL;
	p_->callbackData = NULL;
	p_->tcpReadBuffer = NULL;
	p_->tcpWriteBuffer = (char *)malloc(YG_NET_MAX_BUFFER_SIZE);
	p_->tcpWritePacker.initPack(p_->tcpWriteBuffer, YG_NET_MAX_BUFFER_SIZE);
	p_->udpReadBuffer = (char *)malloc(YG_NET_MAX_BUFFER_SIZE);
	p_->udpWriteBuffer = (char *)malloc(YG_NET_MAX_BUFFER_SIZE);
	p_->udpWritePacker.initPack(p_->udpWriteBuffer, YG_NET_MAX_BUFFER_SIZE);
	p_->debugDownload = ygDebugFlags::checkDebugEnv("net.download");
	p_->debugTCPWrite = ygDebugFlags::checkDebugEnv("net.tcpwrite");
	p_->debugTCPRead = ygDebugFlags::checkDebugEnv("net.tcpread");
	p_->debugUDPWrite = ygDebugFlags::checkDebugEnv("net.udpwrite");
	p_->debugUDPRead = ygDebugFlags::checkDebugEnv("net.udpread");
	}


void ygNetClient::init(const ygString& server,int port)
	{
	p_->server = server;
	p_->port = port;
	p_->udpSocket = new CAVERNnet_udp_c;
	p_->udpSocket->init(0);
	p_->udpSocket->setSendAddress(server.c_str(),port);
	p_->udpSocket->makeNonBlocking();
	p_->tcpSocket = new CAVERNnet_tcpReflectorClient_c;
	if (p_->tcpSocket->connectToServer(server.c_str(),port) ==
			CAVERNnet_tcpReflectorClient_c::FAILED)
		{
		cerr << "ERROR: ygNetClient::init: failed to connect"
			" to TCP server " << server << ":" << port << endl;
		return;
		}
	ygString cacheServer(server);
	int cachePort = port+1;
	if (getenv("YG_NET_CACHE_SERVER"))
		cacheServer = getenv("YG_NET_CACHE_SERVER");
	if (getenv("YG_NET_CACHE_PORT"))
		cachePort = atoi(getenv("YG_NET_CACHE_PORT"));
	downloadViaTCP(cacheServer,cachePort);
//	p_->udpSocket->enableInstrumentation();
	}


// Print CAVERN's socket statistics for the UDP and TCP sockets.
// The constructor leaves both sockets NULL until init() runs, so each one
// is checked before use instead of being dereferenced unconditionally.
void ygNetClient::showStats(void)
	{
	if (p_->udpSocket)
		p_->udpSocket->showStats("UDP","ygNetClient UDP socket");
	if (p_->tcpSocket)
		p_->tcpSocket->showStats("TCP","ygNetClient TCP socket");
	}


void ygNetClient::reset(void)
	{
	if (p_->tcpSocket)
		p_->tcpSocket->close();
	else
		p_->tcpSocket = new CAVERNnet_tcpReflectorClient_c;
	if (p_->tcpSocket->connectToServer(p_->server.c_str(),p_->port) ==
			CAVERNnet_tcpReflectorClient_c::FAILED)
		{
		cerr << "ERROR: ygNetClient::reset: failed to connect"
			" to TCP server " << p_->server << ":" << p_->port << endl;
		return;
		}
	ygString cacheServer(p_->server);
	int cachePort = p_->port+1;
	if (getenv("YG_NET_CACHE_SERVER"))
		cacheServer = getenv("YG_NET_CACHE_SERVER");
	if (getenv("YG_NET_CACHE_PORT"))
		cachePort = atoi(getenv("YG_NET_CACHE_PORT"));
	downloadViaTCP(cacheServer,cachePort);
	uploadViaTCP();
	}


// Install `cb` as the update hook.  process() calls it for each received
// key, passing `data` as the last argument.  A NULL `cb` disables the hook.
void ygNetClient::trigger(ygNetClientCallback cb,void *data)
	{
	p_->callbackData = data;
	p_->callback = cb;
	}


// Return a malloc'd copy of the value stored at path/key and write its
// length to *size.  The buffer is allocated with malloc, so the caller
// releases it with free().
// When nothing is stored, or the copy cannot be allocated, *size is set to
// 0 and NULL is returned.
// Note: getKeyStruct() always inserts, so looking up an unknown key leaves
// an empty, unowned record behind.
void * ygNetClient::get(const ygString& path,const ygString& key,size_t *size)
	{
	*size = 0;
	ygNetClientKey * keystruct = getKeyStruct(path,key,false);
	if (!keystruct)
		return NULL;
	void * retData = NULL;
	// size and data are examined together under dataMutex.  Otherwise a
	// concurrent saveKeyData()/processKeyData() could change them between
	// the emptiness check and the copy.
	p_->dataMutex.lock();
	if (keystruct->size > 0)
		{
		retData = malloc(keystruct->size);
		if (retData)
			{
			memcpy(retData, keystruct->data, keystruct->size);
			*size = keystruct->size;
			}
		}
	p_->dataMutex.unlock();
	return retData;
	}


// Find the value record for path/key, creating an empty one if needed.
// A new record gets data NULL, size and bufferSize 0, and the given `owned`
// flag.  An existing record keeps the flag it was created with.
// The map is accessed under mapMutex.
ygNetClientKey * ygNetClient::getKeyStruct(const ygString& path, const ygString& key,bool owned)
	{
	p_->mapMutex.lock();
	// operator[] inserts a NULL slot for a missing path and/or key.  Binding
	// a reference lets a freshly built record be stored straight into it.
	ygNetClientKey *& slot = p_->paths[path][key];
	if (slot == NULL)
		{
		ygNetClientKey * fresh = new ygNetClientKey;
		fresh->data = NULL;
		fresh->size = 0;
		fresh->bufferSize = 0;
		fresh->owned = owned;
		slot = fresh;
		}
	// Copy the pointer out while still holding the lock.
	ygNetClientKey * result = slot;
	p_->mapMutex.unlock();
	return result;
	}


void ygNetClient::put(const ygString& path, const ygString& key, void *data, size_t size)
	{
	saveKeyData(path,key,data,size);
	sendKeyDataTCP(path,key,data,size);
	}


void ygNetClient::putUnreliably(const ygString& path, const ygString& key, void *data, size_t size)
	{
	saveKeyData(path,key,data,size);
	sendKeyDataUDP(path,key,data,size);
	}


void ygNetClient::saveKeyData(const ygString& path, const ygString& key, void *data, size_t size)
	{
	ygNetClientKey * keystruct = getKeyStruct(path,key,true);
	if (keystruct)
		{
		p_->dataMutex.lock();
		if (size > keystruct->bufferSize)
			{
			keystruct->bufferSize = size;
			keystruct->data = realloc(keystruct->data,keystruct->bufferSize);
			}
		memcpy(keystruct->data,data,size);
		keystruct->size = size;
		p_->dataMutex.unlock();
		}
	}


void ygNetClient::sendKeyDataTCP(const ygString& path, const ygString& key, void *data, size_t size, int sendType)
	{
	if (p_->tcpSocket)
		{
		int totalSize = sizeof(int) + 
				sizeof(int) + path.length()+1 +
				sizeof(int) + key.length()+1 +
				sizeof(int) + size;
		if (p_->tcpWritePacker.checkspace(totalSize) ==
		    CAVERNnet_datapack_c::FAILED)
			flushTCPWrite();
		p_->tcpWritePacker.packInt(sendType);
		p_->tcpWritePacker.packInt(path.length()+1);
		p_->tcpWritePacker.pack(path.c_str(), path.length()+1);
		p_->tcpWritePacker.packInt(key.length()+1);
		p_->tcpWritePacker.pack(key.c_str(), key.length()+1);
		p_->tcpWritePacker.packInt(size);
		p_->tcpWritePacker.pack((char *)data, size);
		}
	else
		sendKeyDataUDP(path, key, data, size, sendType);
	}


void ygNetClient::flushTCPWrite(void)
	{
	if ((p_->tcpSocket) && (p_->tcpWritePacker.getBufferFilledSize() > 0))
		{
		int nbytes = p_->tcpWritePacker.getBufferFilledSize();
		if (p_->debugTCPWrite)
			cvrnPrintf("ygNetClient: writing %d bytes via TCP\n", nbytes);
		p_->tcpSocket->write(p_->tcpWritePacker.getBuffer(),&nbytes);
		p_->tcpWritePacker.initPack(p_->tcpWriteBuffer,
					    YG_NET_MAX_BUFFER_SIZE);
		}
	}


void ygNetClient::sendKeyDataUDP(const ygString& path, const ygString& key, void *data, size_t size, int sendType)
	{
	if (p_->udpSocket)
		{
		int totalSize = sizeof(int) + 
				sizeof(int) + path.length()+1 +
				sizeof(int) + key.length()+1 +
				sizeof(int) + size;
		if (p_->udpWritePacker.checkspace(totalSize) ==
		    CAVERNnet_datapack_c::FAILED)
			flushUDPWrite();
		p_->udpWritePacker.packInt(sendType);
		p_->udpWritePacker.packInt(path.length()+1);
		p_->udpWritePacker.pack(path.c_str(), path.length()+1);
		p_->udpWritePacker.packInt(key.length()+1);
		p_->udpWritePacker.pack(key.c_str(), key.length()+1);
		p_->udpWritePacker.packInt(size);
		p_->udpWritePacker.pack((char *)data, size);
		flushUDPWrite();
		}
	}


// Request path/key from the server.  Currently a no-op: the implementation
// below, which would send a UDP YG_NET_FETCH_KEY record, is compiled out
// with #if 0.
// NOTE(review): the disabled code computes totalSize but never uses it.
void ygNetClient::fetch(const ygString& path,const ygString& key)
	{
#if 0
	if (p_->udpSocket)
		{
		int totalSize = sizeof(int) +
				sizeof(int) + path.length()+1 +
				sizeof(int) + key.length()+1;
		flushUDPWrite();
		p_->udpWritePacker.packInt(YG_NET_FETCH_KEY);
		p_->udpWritePacker.packInt(path.length()+1);
		p_->udpWritePacker.pack(path.c_str(), path.length()+1);
		p_->udpWritePacker.packInt(key.length()+1);
		p_->udpWritePacker.pack(key.c_str(), key.length()+1);
		flushUDPWrite();
		}
#endif
	}


void ygNetClient::flushUDPWrite(void)
	{
	if ((p_->udpSocket) && (p_->udpWritePacker.getBufferFilledSize() > 0))
		{
		int nbytes = p_->udpWritePacker.getBufferFilledSize();
		if (p_->debugUDPWrite)
			cvrnPrintf("ygNetClient: writing %d bytes via UDP\n", nbytes);
		p_->udpSocket->send(p_->udpWritePacker.getBuffer(),nbytes);
		p_->udpWritePacker.initPack(p_->udpWriteBuffer,
					    YG_NET_MAX_BUFFER_SIZE);
		}
	}


void ygNetClient::flushWrites(void)
	{
	flushTCPWrite();
	flushUDPWrite();
	}


// Drain every packet currently available (TCP first, then UDP).  Each
// YG_NET_PUT_KEY record is stored locally, and the trigger() callback, if
// set, is invoked for it.
// Parsing a packet stops at its end, at an unrecognized record type, or at
// a record processKeyData() rejects.  Records carry no overall length, so
// nothing after such a point can be located; continuing would reinterpret
// payload bytes as record types.
void ygNetClient::process(void)
	{
	while (receivePacket())
		{
		while (true)
			{
			int packetType = -1;
			if (p_->packer.unpackInt(&packetType) ==
			    CAVERNnet_datapack_c::FAILED)
				break;		// end of this packet
			if (packetType != YG_NET_PUT_KEY)
				break;		// unknown record: its extent is unknown
			ygString path, key;
			if (!processKeyData(path,key,packetType))
				break;		// malformed record: packer is misaligned
			if (p_->callback)
				(*p_->callback)(path.c_str(),key.c_str(),
						p_->callbackData);
			}
		}
	}


// Load the next incoming packet into p_->packer, preferring the TCP socket.
// UDP is polled only when TCP yields nothing.  Returns false when neither
// socket produced a packet.
bool ygNetClient::receivePacket(void)
	{
	return receiveTCPPacket() || receiveUDPPacket();
	}


bool ygNetClient::processKeyData(ygString& path,ygString& key,int packetType)
	{
	int length;
	char *tmpstr;
/* Get path string length & data */
	p_->packer.unpackInt(&length);
	tmpstr = (char *) malloc(length);
	p_->packer.unpack(tmpstr, length);
	path.clear();
	path.append(tmpstr);
	free(tmpstr);
/* Get key string length & data */
	p_->packer.unpackInt(&length);
	tmpstr = (char *) malloc(length);
	p_->packer.unpack(tmpstr, length);
	key.clear();
	key.append(tmpstr);
	free(tmpstr);
/* Get data length & data */
	p_->packer.unpackInt(&length);
	ygNetClientKey * keystruct = getKeyStruct(path,key,false);
	if (keystruct)
		{
		p_->dataMutex.lock();
		if (length > keystruct->bufferSize)
			{
			keystruct->bufferSize = length;
			keystruct->data = realloc(keystruct->data,keystruct->bufferSize);
			}
		p_->packer.unpack((char *)keystruct->data, length);
		keystruct->size = length;
		p_->dataMutex.unlock();
		return true;
		}
	else
		return false;
	}


void ygNetClient::downloadViaTCP(const ygString& server,int port)
	{
	cvrnPrintf("ygNetClient::downloadViaTCP: connecting to %s:%d\n",
		   server.c_str(), port);
	CAVERNnet_tcpReflectorClient_c *sock = new CAVERNnet_tcpReflectorClient_c;
	if (sock->connectToServer(server.c_str(),port) ==
			CAVERNnet_tcpReflectorClient_c::FAILED)
		{
		cerr << "ERROR: ygNetClient::downloadViaTCP: failed to connect"
			" to cache server " << server << ":" << port << endl;
		return;
		}
	int nbytes;
	char buffer[64];
	p_->packer.initPack(buffer, sizeof(buffer));	
	p_->packer.packInt(YG_NET_REQUEST_ALL);
	nbytes = p_->packer.getBufferFilledSize();
	sock->write(p_->packer.getBuffer(),&nbytes);
	char * buf = NULL;
	if (sock->read(&buf, &nbytes, CAVERNnet_tcpReflectorClient_c::BLOCKING)
			!= CAVERNnet_tcpReflectorClient_c::OK)
		cerr << "WARNING: ygNetClient::downloadViaTCP: socket read failed\n";
	else
		{
		p_->packer.initUnpack(buf,nbytes);
		while (true)
			{
			int packetType = -1;
			p_->packer.unpackInt(&packetType);
			if (packetType == YG_NET_DOWNLOAD_DONE)
				break;
			else if (packetType == YG_NET_PUT_KEY)
				{
				ygString path, key;
				processKeyData(path,key);
				if (p_->debugDownload)
					cvrnPrintf("ygNetClient::downloadViaTCP: received %s/%s\n",
						   path.c_str(), key.c_str());
				}
			}
		}
	if (buf)
		delete buf;
	sock->close();
	cvrnPrintf("ygNetClient::downloadViaTCP: download done\n");
	}


void ygNetClient::uploadViaTCP(void)
	{
	cvrnPrintf("ygNetClient::uploadViaTCP: sending all keys to %s:%d\n",
		   p_->server.c_str(), p_->port);
	ygNetClientPathmap::const_iterator iter;
	for (iter = p_->paths.begin(); iter != p_->paths.end(); ++iter)
		{
		ygNetClientKeymap::const_iterator iter2 = iter->second.begin();
		if (iter2->second->owned)
			{
			for ( ; iter2 != iter->second.end(); ++iter2)
				sendKeyDataTCP(iter->first, iter2->first,
					    iter2->second->data, iter2->second->size,
					    YG_NET_PUT_KEY);
			}
		}
	cvrnPrintf("ygNetClient::uploadViaTCP: upload finished\n");
	}


// If the TCP reflector socket has a message waiting, read it (blocking
// until complete) into p_->tcpReadBuffer and point p_->packer at it.
// Returns true when a message was loaded.  The previous message's buffer
// is released first, so the packer must be finished with it (process()
// consumes each packet before asking for the next).
bool ygNetClient::receiveTCPPacket(void)
	{
	if (!p_->tcpSocket ||
	    (p_->tcpSocket->isReadyToRead() !=
	     CAVERNnet_tcpReflectorClient_c::READY_TO_READ))
		return false;
	// The buffer comes from the reflector client's read().
	// NOTE(review): assumes it is allocated with new char[], hence delete[].
	// The pointer is cleared at once so it never dangles.
	delete [] p_->tcpReadBuffer;
	p_->tcpReadBuffer = NULL;
	int numBytes = 0;
	if (p_->tcpSocket->read(&p_->tcpReadBuffer, &numBytes,
				CAVERNnet_tcpReflectorClient_c::BLOCKING)
			!= CAVERNnet_tcpReflectorClient_c::OK)
		{
		p_->tcpReadBuffer = NULL;
		return false;
		}
	if (p_->debugTCPRead)
		cvrnPrintf("ygNetClient: received %d bytes via TCP\n", numBytes);
	p_->packer.initUnpack(p_->tcpReadBuffer,numBytes);
	return true;
	}


bool ygNetClient::receiveUDPPacket(void)
	{
	if (p_->udpSocket)
		{
		int numBytes = p_->udpSocket->receive((char*)p_->udpReadBuffer,
							YG_NET_MAX_BUFFER_SIZE);
		if (numBytes > 0)
			{
			if (p_->debugUDPRead)
				cvrnPrintf("ygNetClient: received %d bytes via UDP\n", numBytes);
			p_->packer.initUnpack((char *)p_->udpReadBuffer, numBytes);
			return true;
			}
		}
	return false;
	}
