#include "Net.h"

#define WIN32_LEAN_AND_MEAN

#include <WinSock2.h>
#include <WS2tcpip.h>

#pragma comment(lib, "Ws2_32.lib")

#include <SDL.h>

#include "Log.h"

/* Service (port) string passed to getaddrinfo by both the server and the client. */
#define DEFAULT_PORT "20904"

/* NOTE(review): NET_BUFFERSIZE is not referenced in the visible part of this file - confirm it is still needed. */
#define NET_BUFFERSIZE 65535
/* Size in bytes of the shared receive buffer; Net_SendPacket also skips packets of this size or larger. */
#define NET_RWBUFSIZE 16384

/*
 * Single receive buffer shared by all connections: allocated in Net_Init,
 * released in Net_Cleanup, and filled by Net_ReadNextPacket, which hands this
 * same pointer back to its caller.
 */
static char *readbuf;

/*
 * Builds an IPv4 socket address from a Netaddr.
 *
 * The address and port are copied verbatim, so Netaddr has to hold them in
 * the same representation sockaddr_in uses. The structure is zeroed first so
 * that the sin_zero padding never carries uninitialized stack bytes into the
 * socket calls that consume the result.
 */
static struct sockaddr_in NetaddrToSockaddr(Netaddr netaddr)
{
	struct sockaddr_in result;

	ZeroMemory(&result, sizeof(result));
	result.sin_family = AF_INET;
	result.sin_addr.S_un.S_addr = netaddr.addr;
	result.sin_port = netaddr.port;
	return result;
}

/*
 * Extracts the IPv4 address and port of a sockaddr_in into a Netaddr.
 * Both values are copied as-is, without any byte-order conversion.
 */
static Netaddr SockaddrToNetaddr(struct sockaddr_in sockaddr)
{
	Netaddr converted;

	converted.port = sockaddr.sin_port;
	converted.addr = sockaddr.sin_addr.S_un.S_addr;

	return converted;
}

int Net_Init(void)
{
	WSADATA wsaData;
	int iResult;

	iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
	if (iResult != 0)
	{
		Msg_Error("Failed to initialize networking API: %d", iResult);
		return 0;
	}

	readbuf = SDL_malloc(NET_RWBUFSIZE);

	return 1;
}

void Net_Cleanup(void)
{
	SDL_free(readbuf);
	WSACleanup();
}

/*
 * Opens a non-blocking IPv4 UDP socket bound to DEFAULT_PORT on the wildcard
 * address and stores it in *conn as a server connection.
 *
 * conn is only written on success: it is zeroed, then connType and socket are
 * set. Returns 1 on success, 0 on failure. On failure the socket (if one was
 * created) is closed and the address list is freed.
 */
int Net_OpenServer(struct NetConnection *conn)
{
	int iResult;
	int ok = 0;
	u_long mode = 1; /* non-zero enables non-blocking mode for FIONBIO */
	struct addrinfo *result = NULL, hints;
	SOCKET listenSocket = INVALID_SOCKET;

	ZeroMemory(&hints, sizeof(hints));
	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_DGRAM;
	hints.ai_protocol = IPPROTO_UDP;
	hints.ai_flags = AI_PASSIVE;

	iResult = getaddrinfo(NULL, DEFAULT_PORT, &hints, &result);
	if (iResult != 0)
	{
		Msg_Error("getaddrinfo failed: %s", gai_strerrorA(iResult));
		return 0;
	}

	listenSocket = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
	if (listenSocket == INVALID_SOCKET)
	{
		Msg_Error("Failed to open socket: %d", WSAGetLastError());
		goto done;
	}

	iResult = ioctlsocket(listenSocket, FIONBIO, &mode);
	if (iResult == SOCKET_ERROR)
	{
		Msg_Error("Failed to set socket non-blocking mode: %d", WSAGetLastError());
		goto done;
	}

	iResult = bind(listenSocket, result->ai_addr, (int)result->ai_addrlen);
	if (iResult == SOCKET_ERROR)
	{
		Msg_Error("Failed to bind: %d", WSAGetLastError());
		goto done;
	}

	ZeroMemory(conn, sizeof(struct NetConnection));
	conn->connType = NET_CONN_SERVER;
	conn->socket = listenSocket;
	ok = 1;

done:
	/* Single exit path: errors are logged before jumping here, so the cleanup calls cannot disturb the reported code. */
	freeaddrinfo(result);
	if (!ok && listenSocket != INVALID_SOCKET)
	{
		closesocket(listenSocket);
	}

	return ok;
}

/*
 * Resolves addrString (with DEFAULT_PORT), opens a non-blocking IPv4 UDP
 * socket and connects it to the first resolved address.
 *
 * conn is only written on success: it is zeroed, marked as a connected
 * client, and given the socket and the server address. Returns 1 on success,
 * 0 on failure. On failure the socket (if one was created) is closed and the
 * address list is freed.
 */
int Net_ConnectClient(struct NetConnection *conn, const char *addrString)
{
	int iResult;
	int ok = 0;
	u_long mode = 1; /* non-zero enables non-blocking mode for FIONBIO */
	struct addrinfo *result = NULL, hints;
	const struct sockaddr_in *sockaddr;
	SOCKET connectSocket = INVALID_SOCKET;

	ZeroMemory(&hints, sizeof(hints));
	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_DGRAM;
	hints.ai_protocol = IPPROTO_UDP;

	iResult = getaddrinfo(addrString, DEFAULT_PORT, &hints, &result);
	if (iResult != 0)
	{
		Msg_Error("Failed to resolve server address: %d", iResult);
		return 0;
	}

	/*
	 * Only the first address returned by getaddrinfo is tried; the remaining
	 * entries are not used as fallbacks if this one fails.
	 */
	connectSocket = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
	if (connectSocket == INVALID_SOCKET)
	{
		Msg_Error("socket failed: %d", WSAGetLastError());
		goto done;
	}

	iResult = ioctlsocket(connectSocket, FIONBIO, &mode);
	if (iResult == SOCKET_ERROR)
	{
		Msg_Error("Failed to set socket non-blocking mode: %d", WSAGetLastError());
		goto done;
	}

	iResult = connect(connectSocket, result->ai_addr, (int)result->ai_addrlen);
	if (iResult == SOCKET_ERROR)
	{
		/* Read the error code before closesocket gets a chance to replace it. */
		Msg_Error("Unable to connect to server: %d", WSAGetLastError());
		goto done;
	}

	/* hints restricts the lookup to AF_INET, so ai_addr points at a sockaddr_in. */
	sockaddr = (const struct sockaddr_in *)result->ai_addr;

	ZeroMemory(conn, sizeof(struct NetConnection));
	conn->connType = NET_CONN_CLIENT;
	conn->status = NET_STATUS_CONNECTED;
	conn->socket = connectSocket;
	conn->serverAddr = SockaddrToNetaddr(*sockaddr);

	Msg_Info("Connected to server at %s", Net_AddrToString(conn->serverAddr));
	ok = 1;

done:
	/* Single exit path so neither the address list nor the socket can leak. */
	freeaddrinfo(result);
	if (!ok && connectSocket != INVALID_SOCKET)
	{
		closesocket(connectSocket);
	}

	return ok;
}

void Net_Close(struct NetConnection *conn)
{
	int iResult;
	iResult = shutdown(conn->socket, SD_SEND);
	if (iResult == SOCKET_ERROR)
	{
		Msg_Error("shutdown failed: %d:", WSAGetLastError());
	}

	closesocket(conn->socket);
}

/* Returns the value of SDL_GetTicks() converted to int. */
int Net_CurrentTimeMs(void)
{
	return (int)SDL_GetTicks();
}

/*
 * Formats addr as dotted-quad "a.b.c.d:port".
 *
 * Netaddr stores the address and port exactly as they appear in sockaddr_in
 * (network byte order), so both are converted to host order before printing;
 * printing the raw port would show it byte-swapped on little-endian hosts.
 *
 * Returns a pointer to a static buffer that is overwritten by the next call.
 */
char *Net_AddrToString(Netaddr addr)
{
	static char strbuf[64];
	const unsigned long hostAddr = ntohl((u_long)addr.addr);
	const int hostPort = ntohs((u_short)addr.port);

	/* After ntohl the first octet is the most significant byte on any host. */
	SDL_snprintf(strbuf, sizeof(strbuf), "%d.%d.%d.%d:%d",
		(int)((hostAddr >> 24) & 0xff),
		(int)((hostAddr >> 16) & 0xff),
		(int)((hostAddr >> 8) & 0xff),
		(int)(hostAddr & 0xff),
		hostPort
	);
	return strbuf;
}

/*
 * Sends size bytes from data to destAddr as one datagram on conn's socket.
 *
 * Nothing is sent unless conn->status is NET_STATUS_CONNECTED. Packets of
 * NET_RWBUFSIZE bytes or more are skipped with a warning. A sendto failure is
 * logged as a warning; the function reports nothing back to the caller.
 */
void Net_SendPacket(struct NetConnection *conn, Netaddr destAddr, void *data, int size)
{
	struct sockaddr_in target;
	int sent;

	if (conn->status != NET_STATUS_CONNECTED)
	{
		return;
	}

	if (size >= NET_RWBUFSIZE)
	{
		Msg_Warning("Packet of %d bytes was skipped because too large", size);
		return;
	}

	target = NetaddrToSockaddr(destAddr);
	sent = sendto(conn->socket, data, size, 0, (struct sockaddr *)&target, sizeof(target));
	if (sent != SOCKET_ERROR)
	{
		return;
	}

	Msg_Warning("Failed to send packet: %d", WSAGetLastError());
}

/*
 * Tries to read one datagram from conn's socket into the shared read buffer.
 *
 * On success (at least one byte received) returns 1 and sets:
 *   *data     - pointer to the shared read buffer (valid until the next call),
 *   *size     - number of bytes received,
 *   *fromAddr - sender address,
 * and marks conn as NET_STATUS_CONNECTED.
 *
 * Returns 0 when there is nothing to process: no data pending, an empty
 * datagram, or an error. Some errors also set conn->status to
 * NET_STATUS_DISCONNECTED (see the switch below). The output parameters are
 * left untouched when 0 is returned.
 */
int Net_ReadNextPacket(struct NetConnection *conn, Netaddr *fromAddr, void **data, int *size)
{
	struct sockaddr_in addr;
	int addrlen = sizeof(addr);
	int rc;
	int err;

	rc = recvfrom(conn->socket, readbuf, NET_RWBUFSIZE, 0, (struct sockaddr *)&addr, &addrlen);
	if (rc > 0)
	{
		*data = readbuf;
		*size = rc;
		*fromAddr = SockaddrToNetaddr(addr);
		conn->status = NET_STATUS_CONNECTED;

		return 1;
	}

	if (rc == 0)
	{
		/*
		 * A zero-length datagram is not an error. WSAGetLastError() is not
		 * reset by successful calls, so consulting it here would act on a
		 * stale code from some earlier call (and could wrongly mark the
		 * connection as disconnected).
		 */
		return 0;
	}

	err = WSAGetLastError();
	switch (err)
	{
	case WSAENOTCONN:
	case WSAEWOULDBLOCK:
		/* Nothing needs to be read, continue on */
		break;
	case WSAEINTR:
		/* The call was interrupted; treated as the connection having been closed */
		conn->status = NET_STATUS_DISCONNECTED;
		break;
	case WSAECONNRESET:
		/* The other side of the connection disconnected */
		if (conn->connType == NET_CONN_SERVER)
		{
			/* A server keeps running when one peer goes away. */
			Msg_Warning("Client disconnected");
		}
		else
		{
			Msg_Warning("Connection to server closed");
			conn->status = NET_STATUS_DISCONNECTED;
		}
		break;
	default:
		Msg_Warning("Failed to recieve data: %d", err);
		break;
	}

	return 0;
}
