/*
 * copyright (c) 2000 
 * the regents of the university of michigan 
 * all rights reserved 
 *
 * permission is granted to use, copy, create derivative works and  
 * redistribute this software and such derivative works for any purpose,  
 * so long as the name of the university of michigan is not used in any  
 * advertising or publicity pertaining to the use or distribution of this  
 * software without specific, written prior authorization.  if the above  
 * copyright notice or any other identification of the university of  
 * michigan is included in any copy of any portion of this software, then  
 * the disclaimer below must also be included. 
 * 
 * this software is provided as is, without representation from the  
 * university of michigan as to its fitness for any purpose, and without  
 * warranty by the university of michigan of any kind, either express or  
 * implied, including without limitation the implied warranties of  
 * merchantability and fitness for a particular purpose.  the regents of  
 * the university of michigan shall not be liable for any damages,  
 * including special, indirect, incidental, or consequential damages, with  
 * respect to any claim arising out or in connection with the use of the  
 * software, even if it has been or is hereafter advised of the  
 * possibility of such damages.
 */

#include <sys/time.h>		/* timeval, time_t */
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/resource.h>
#include <setjmp.h>		/* jmp_buf et al */
#include <sys/socket.h>		/* basics, SO_ and AF_ defs, sockaddr, ... */
#include <netinet/in.h>		/* sockaddr_in, htons, in_addr */
#include <netinet/in_systm.h>	/* misc crud that netinet/ip.h references */
#include <netinet/ip.h>		/* IPOPT_LSRR, header stuff */
#include <netdb.h>		/* hostent, gethostby*, getservby* */
#include <arpa/inet.h>		/* inet_ntoa */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>		/* strcpy, strchr, yadda yadda */
#include <errno.h>
#include <signal.h>
#include <fcntl.h>		/* O_WRONLY et al */
#include <unistd.h>

#ifndef howmany
#define howmany(x, y)   (((x) + ((y) - 1))/(y))
#endif

int doconnect(struct in_addr addr, short port);

/* Globals */

char *host = "mountainview.citi.umich.edu";
int loadtime = 100;
short port = 8080;
int reps = 100;

void
usage()
{
  fprintf(stderr, "tcp [-h host] [-p port] [-r #conns] [-t time]\n"
	  "\t -h %s\n"
	  "\t -p %d\n"
	  "\t -r %d\n"
	  "\t -t %d\n", host, port, reps, loadtime);
  exit(1);
}

int
main(int argc, char **argv)
{
    int ch, i, *fds, size, reopen;
    struct hostent *hp;
    struct in_addr addr;
    struct rlimit flim;
    fd_set *fderr;
    struct timeval timeout;
    time_t tend, tnow;

    /* Init */
    host = "mountainview.citi.umich.edu";
    loadtime = 100;
    port = 8080;
    reps = 100;

    while ((ch = getopt(argc, argv, "h:p:r:t:")) != -1)
	switch((char)ch) {
	case 'h':
	    host = optarg;
	    break;
	case 'p':
	    port = atoi(optarg);
	    break;
	case 't':
	    loadtime = atoi(optarg);
	    if (loadtime <= 0) {
		fprintf(stderr, "%s: -t requires time > 0\n",
			argv[0]);
		exit (1);
	    }
	    break;
	case 'r':
	    reps = atoi(optarg);
	    if (reps <= 0) {
		fprintf(stderr, "%s: -r requires value > 0\n",
			argv[0]);
		exit (1);
	    }
	    break;
	default:
	    usage();
	}

    /* Tweak our resources */
    if (getrlimit(RLIMIT_NOFILE, &flim) == -1) {
	perror("getrlimit");
	exit(1);
    }
    if (flim.rlim_cur < reps + 3) {
	flim.rlim_cur = reps + 3;
	if (setrlimit(RLIMIT_NOFILE, &flim) == -1) {
	    perror("setrlimit");
	    exit(1);
	}
    }

    /* Resolve the name */
    if (!inet_aton(host, &addr)) {
	hp = gethostbyname(host);
	if (!hp) {
	    fprintf(stderr, "%s: %s not a valid addr\n",
		    argv[0], host);
	    exit(1);
	}
	bcopy(hp->h_addr, &addr, sizeof(addr));
    }

    fds = (int *)calloc(reps, sizeof(int));
    if (!fds) {
	fprintf(stderr, "%s: out of memory\n");
	exit(1);
    }


    fprintf(stdout, "%s: starting %d load-adding connections to %s:%d\n",
	    argv[0], reps, host, port);
    for (i = 0; i < reps; i++)
	if ((fds[i] = doconnect(addr, port)) == -1)
	    break;

    if (i < reps)
	goto out;

    fprintf(stdout, "%s: added load, sleeping for %d seconds\n",
	    argv[0], loadtime);
    size = howmany(fds[reps-1], NFDBITS) * sizeof(fd_set);
    fderr = malloc(size);
    if (!fderr) {
	perror("malloc");
	goto out;
    }
    reopen = 0;
    tnow = time(NULL);
    tend = time(NULL) + loadtime;
    while (time(NULL) < tend) {
	int maxfd = 0;

	bzero(fderr, size);
	for (i = 0; i < reps; i++) {
	    FD_SET(fds[i], fderr);
	    if (fds[i] > maxfd)
		maxfd = fds[i];
	}

	timeout.tv_sec = tend - time(NULL);
	timeout.tv_usec = 0;
	if (select(maxfd + 1, fderr, NULL, NULL, &timeout) == -1) {
	    perror("select");
	    continue;
	}
	for (i = 0; i < reps; i++) {
	    if (FD_ISSET(fds[i], fderr)) {
		close (fds[i]);
		if ((fds[i] = doconnect(addr, port)) == -1) {
		    fds[i] = 0;
		    break;
		}
		reopen++;
	    }
	}

	/* Report when all connections have been reopened */
	if (reopen >= reps) {
	    fprintf(stdout,
		    "%s: reopened %d connections after %d seconds\n",
		    argv[0], reopen, time(NULL) - tnow);
	    reopen = 0;
	    tnow = time(NULL);
	}

	/* Spread things out */
	sleep(1);
    }    

 out:
    for (i = 0; i < reps; i++) {
	if (fds[i] == -1)
	    break;
	close(fds[i]);
    }
    return 1;
}

int
doconnect(struct in_addr addr, short port)
{
  struct sockaddr_in sa;
  int fd;

  fd = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (fd < 0) {
    perror("socket");
    return -1;
  }

  bzero(&sa, sizeof(sa));
  sa.sin_family = AF_INET;
  sa.sin_addr = addr;
  sa.sin_port = htons (port);

  if (connect (fd, (struct sockaddr *)&sa, sizeof (sa)) == -1) {
    perror("connect");
    close (fd);
    return -1;
  }

  return fd;
}
