//
// A virtual class modelling connections.
//
// Written by: Mick Dwyer, 21 Jan, 2001
//
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <time.h>
#include <sys/errno.h>

#include "str.h"
#include "config.h"
#include "scenario.h"
#include "String.h"

#include "connection.h"


#ifdef HAVE_SSL

#include <openssl/err.h>

#endif /* HAVE_SSL */

/*
 * NAME:        fdConnect
 * ACTION:      Create a TCP socket, optionally bind it to the address of
 *              the session's alternate interface, and connect it to the
 *              server.  The destination port is sess->scenario->port when
 *              that is non-zero, otherwise sess->port; the destination
 *              address is sess->ip_addr.
 * RETURNS:     the connected file descriptor, or 0 on any failure (the
 *              socket is closed before returning).  Callers test for 0,
 *              so that contract is kept as is.
 */
int Connection::fdConnect(void)
{
    int SearchFD;
    int GotIt;
    int retry;
    int SavedErrno = 0;
    struct sockaddr_in MyAddr, ServerAddr;

    /* open socket */
    SearchFD = socket(PF_INET, SOCK_STREAM, 0);
    if (SearchFD < 0)
    {
        sess->scenario->Log(itoa(sess->pin) + ": Socket creation failed");
        return 0;
    }

    /* Are we connecting from an alternate interface? */
    if (sess->Interface() && sess->Interface()->Ifreq())
    {
#if 1
        /* set up the local address (port left at 0) */
        memset(&MyAddr, 0, sizeof(MyAddr));
        MyAddr.sin_family = AF_INET;
        MyAddr.sin_addr.s_addr = sess->Interface()->Ipaddr();
        GotIt = bind(SearchFD, (struct sockaddr *)&MyAddr, sizeof(MyAddr));
        if (GotIt == -1)
        {
            /* capture errno before building the message can disturb it */
            SavedErrno = errno;
            sess->scenario->Log(itoa(sess->pin) + ": bind failed (" + itoa(SavedErrno) +")");
            close(SearchFD);
            return 0;
        }
#else
        ioctl(SearchFD, SIOCSIFADDR, sess->Interface()->Ifreq());
#endif
    }

    /* set up the server address */
    memset(&ServerAddr, 0, sizeof(ServerAddr));
    ServerAddr.sin_family = AF_INET;

    /* set up for the various destination port possibilities: a port
       configured on the scenario overrides the session's port */
    if (sess->scenario->port != 0)
    {
        ServerAddr.sin_port = htons(sess->scenario->port);
    }
    else
    {
        ServerAddr.sin_port = htons(sess->port);
    }

    /* sess->ip_addr is obtained from inet_addr() which returns
       numbers in network order. */
    ServerAddr.sin_addr.s_addr = sess->ip_addr;

    /* connect, retrying while the call is interrupted by a signal.
       NOTE(review): POSIX says a connect() interrupted by a signal carries
       on asynchronously and a second call may fail with EALREADY; confirm
       whether this retry is adequate on the target platforms. */
    do {
        GotIt = ::connect(SearchFD, (struct sockaddr *)&ServerAddr,
                    sizeof(ServerAddr)) >= 0;

        /* remember why connect failed *now*: the ioctl() and stdio calls
           below are free to overwrite errno before it gets reported */
        SavedErrno = GotIt ? 0 : errno;

        retry = (!GotIt && SavedErrno == EINTR);
    //} while (!Shutdown && retry);
    } while (retry);

    if (sess->Interface() && sess->Interface()->Ifreq())
        ioctl(SearchFD, SIOCSIFDSTADDR, sess->Interface()->Ifreq());

    if (!GotIt)
    {
        if (sess->Interface()) /* we specified the name of the socket */
            fprintf(stderr, "failed to connect (%s) - %s (%d) \n", inet_ntoa(ServerAddr.sin_addr), strerror(SavedErrno), SavedErrno);
        else
            fprintf(stderr, "failed to connect (%s) - %s  (%d)\n", inet_ntoa(ServerAddr.sin_addr), strerror(SavedErrno), SavedErrno);
        sess->scenario->Log(itoa(sess->pin) + ": connect failed (" + itoa(SavedErrno) +" - " + strerror(SavedErrno) + ")");
        close(SearchFD);
        return 0;
    }
    return SearchFD;
}

/*
 *
 * connection in the clear.
 *
 */

/* Read a line from the stdio stream f into b (buffer size n) via fgets();
   the fgets() result is handed back unchanged. */
char * Clear_Connection::gets(char *b, int n)
{
	char *line = fgets(b, n, f);

	return line;
}

/* Read up to n bytes from the stdio stream f into b; returns the number of
   bytes fread() reports as read. */
int Clear_Connection::read(char *b, int n)
{
	const size_t got = fread(b, 1, n, f);

	return got;
}

/* Write n bytes from s to the stdio stream f; returns the number of bytes
   fwrite() reports as written. */
int Clear_Connection::write(const char *s, int n)
{
	const size_t put = fwrite(s, 1, n, f);

	return put;
}

/* Push any output buffered in the stdio stream f out to the socket.
   The fflush() result is deliberately ignored. */
void Clear_Connection::flush()
{
	(void) fflush(f);
}

/* Build an unconnected clear-text connection for session s. */
Clear_Connection::Clear_Connection(Session *s)
{
	/* no stdio stream yet; connect() is what assigns f */
	f = NULL;
	sess = s;
	name = new String("Clear");
}

/*
 * Release the connection name and, if a stream was opened by connect(),
 * close it.
 */
Clear_Connection::~Clear_Connection()
{
	delete name;

	if (f) {
		/* fclose() flushes pending output and closes the underlying
		   descriptor itself.  Closing the descriptor by hand first would
		   lose buffered data and close the same descriptor number twice,
		   which can hit an unrelated descriptor that reused the number. */
		fclose(f);
		f = NULL;
	}
}

/*
 * NAME:        connect
 * ACTION:      Connect to the server.
 */
int Clear_Connection::connect(void)
{
    int RequestFD;

    RequestFD = fdConnect();

    if (RequestFD == 0) return -1;

    /* make it work with stdio */
    f = fdopen(RequestFD, "r+");

    return 0;
}


#ifdef HAVE_SSL
/*
 *
 * SSL connection class.
 *
 */

/*
 * NAME:        gets
 * ACTION:      fgets()-style line read over SSL.  Data is pulled from the
 *              SSL stream into the internal buffer gets_buf and handed out
 *              from there; gets_b points at the next unread byte and
 *              gets_len is the number of unread bytes left.  Copying stops
 *              after the first '\n' (which is included), when n-1 bytes
 *              have been stored, or when SSL_read() yields no more data.
 *              The result is always NUL terminated.
 * RETURNS:     b, or NULL if nothing at all could be read (or if b is NULL
 *              or n is not positive).
 */
char * SSL_Connection::gets(char *b, int n)
{
	int len = 0;
	char *p;
	char *s;

	/* a non-positive size would turn into a huge length further down */
	if (b == NULL || n <= 0) return NULL;

	s = b;

	n--; /* make room for the terminating NUL */
	while (n != 0)
	{
		if (gets_len == 0) /* refill the buffer */
		{
			int res;

			res = SSL_read(ssl, gets_buf, sizeof(gets_buf));

			/* SSL_read() returns <= 0 both on shutdown and on error; a
			   negative value must never end up in gets_len, where it
			   would be used as a copy length */
			if (res <= 0)
			{
				/* test to see if anything has been read this call to gets */
				if (s == b) /* nothing read */
				{
					return NULL;
				}
				/* something has been read, so break, terminate and return */
				break;
			}

			gets_b = gets_buf;
			gets_len = res;
		}

		len = gets_len; /* set the available len */

		/* if we have more in the buffer than the caller's buffer can hold
		   then only search what will fit in the caller's buffer. */
		if (len > n) len = n;

		/* find the \n or determine that we can read to EOF or until b is full */
		p = (char *)memchr((void *)gets_b, '\n', len);
		if (p != NULL) /* we found a '\n' in the gets_buf */
		{
			p++; /* the newline is part of the returned line */
			len = (int)(p - gets_b);
			memcpy((void *)s, (void *)gets_b, len);
			gets_b = p;
			gets_len -= len;
			s[len] = 0;
			return b;
		}

		/* didn't find a '\n' so we'll copy everything... */
		memcpy((void *)s, (void *)gets_b, len);
		/* update pointers and lengths */
		gets_b += len;
		gets_len -= len;
		s += len;
		n -= len;
	}

	/* caller's buffer is full, or the stream ended mid-line */
	*s = 0;
	return b;
}

/*
 * NAME:        read
 * ACTION:      Read up to n bytes into b.  Bytes left in the internal
 *              buffer by earlier gets() calls are handed out first; if
 *              they do not satisfy the request the remainder is read
 *              directly with SSL_read().
 * RETURNS:     the number of bytes stored in b.  When nothing was buffered
 *              this is the raw SSL_read() result (<= 0 on shutdown/error).
 */
int SSL_Connection::read(char *b, int n)
{
	int l;
	int res;

	/* nothing buffered: go straight to the SSL stream */
	if (gets_len <= 0)
		return SSL_read(ssl, b, n);

	/* a non-positive request must not be used as a copy length */
	if (n <= 0)
		return 0;

	if (gets_len >= n) /* exactly or excess data for read */
	{
		l = n;

		memcpy(b, gets_b, l);
		gets_len -= l;
		gets_b += l;
		return l;
	}

	/* drain what is buffered, then top up from the SSL stream */
	l = gets_len;

	memcpy(b, gets_b, l);
	gets_b += l;
	gets_len = 0;

	res = SSL_read(ssl, b + l, n - l);

	/* the buffered bytes are already in b: a failed or empty SSL_read()
	   must not be added to (and so corrupt) the count we return */
	if (res <= 0)
		return l;

	return l + res;
}

/* Send n bytes from s over the SSL connection; the SSL_write() result is
   returned unchanged. */
int SSL_Connection::write(const char *s, int n)
{
	const int sent = SSL_write(ssl, s, n);

	return sent;
}

/* Flush pending output.  This is a no-op here: write() passes data
   straight to SSL_write() and this class keeps no output buffer of its
   own. */
void SSL_Connection::flush()
{
	/* do nothing. Hmmm this could be a source of problems */
}

/*
 * Build an unconnected SSL connection for session s: initialise the SSL
 * library, then create the client method, context and SSL handle.
 *
 * NOTE(review): SSLv2 is an obsolete protocol and SSLv2_client_method()
 * is not provided by current OpenSSL releases; confirm whether the servers
 * under test really require it before changing the method.
 */
SSL_Connection::SSL_Connection(Session *s) : gets_len(0)
{
	name = new String("SSL");
	sess = s;

	/* the destructor hands this to X509_free() even if connect() was never
	   called, so it must not be left uninitialised */
	server_certificate = NULL;

	/* start with an empty, but valid, read buffer position */
	gets_b = gets_buf;

    /* initialise SSL */
    SSL_load_error_strings();
    SSLeay_add_ssl_algorithms();

    /* construct the context for secure connections */
    method = SSLv2_client_method();
    context = SSL_CTX_new(method);
    ssl = SSL_new(context);
}

/*
 * Shut the SSL session down, release the SSL objects and close the
 * socket that connect() attached to the SSL handle.
 */
SSL_Connection::~SSL_Connection()
{
	int fd = -1;

	delete name;

	if (ssl != NULL)
	{
		/* remember the descriptor now; it cannot be asked for once the
		   handle has been freed */
		fd = SSL_get_fd(ssl);

		/* shutdown */
		SSL_shutdown (ssl);
	}

	/* free the server certificate.
	   NOTE(review): server_certificate is only assigned in connect();
	   make sure it is NULL when connect() never got that far. */
	X509_free (server_certificate);

	/* release the SSL handle */
	if (ssl != NULL)
		SSL_free(ssl);

	/* release the context */
	if (context != NULL)
		SSL_CTX_free(context);

	/* free the error strings */
	ERR_free_strings();

	/* finally close the file descriptor, if one was ever attached
	   (SSL_get_fd() yields a negative value otherwise) */
	if (fd >= 0)
		close(fd);
}

/*
 * NAME:        connect
 * ACTION:      Connect to the server.
 */
int SSL_Connection::connect(void)
{
    int SSLRequestFD;

	/* get a connected socket */
    SSLRequestFD = fdConnect();

    if (SSLRequestFD == 0) return -1;

    SSL_set_fd(ssl, SSLRequestFD);
    SSL_connect(ssl); /* how about some error checking. */
    server_certificate = SSL_get_peer_certificate(ssl);
	SSL_set_read_ahead(ssl, 1); /* set the read-ahead */

    return 0;
}


#endif /* HAVE_SSL */
