/*
 **************************************************************************
 *
 * Boot-ROM-Code to load an operating system across a TCP/IP network.
 *
 * Module:  resolve.c
 * Purpose: Implement a simple DNS name resolver
 * Entries: resolve, res_config, init_res
 *
 **************************************************************************
 *
 * Copyright (C) 1995-1998 Gero Kuhlmann <gero@gkminix.han.de>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */


#include <general.h>
#include <net.h>
#include <arpa.h>
#include <romlib.h>
#include "./arpapriv.h"
#include "./resolve.h"



/*
 **************************************************************************
 * 
 * Local error codes, stored in recv_err by recv_answer(). An enum is used
 * instead of separate #defines so the related codes form one named set
 * that a debugger can display.
 */
enum {
  ERR_NOERR = 0,		/* no error			*/
  ERR_TIMEOUT = 1,		/* timeout error		*/
  ERR_INVALID = 2,		/* invalid packet		*/
  ERR_SERVER = 3		/* server error			*/
};



/*
 **************************************************************************
 * 
 * Global variables (file scope, only present when DNS support is built):
 */
#ifndef NODNS
static struct dnshdr *dnsbuf;		/* DNS packet buffer		*/
static unsigned char *namebuf;		/* name decoded by ndecode()	*/
static unsigned char *reqbuf;		/* labelized request name	*/
static unsigned char *mydomain;		/* my own domain name		*/
static t_ipaddr servers[MAX_NS];	/* known name servers		*/
static t_ipaddr ip_last = IP_ANY;	/* last IP address from server	*/
static int ns_num;			/* number of name servers	*/
static int dns_xid;			/* Transaction ID		*/
static int dns_len;			/* length of data in dnsbuf	*/
static int req_len;			/* length of name in reqbuf	*/
static int recv_err;			/* receive error (ERR_xxx)	*/
#endif



/*
 **************************************************************************
 * 
 * Convert an internet address in ascii form into binary network form
 */
static t_ipaddr conv_ip(name)
char *name;
{
  register char *cp = name;
  unsigned char ip[4];
  int dots;
  int num, i;

  dots = 0;
  while (dots < 4) {
	i = 0;
	num = 0;
	while (*cp >= '0' && *cp <= '9') {
		num = num * 10 + (*cp++ - '0');
		i++;
	}
	if (i == 0 || num > 255 ||
	    (dots < 3 && *cp != '.') ||
	    (dots == 3 && *cp != '\0' && *cp != ':'))
		return(IP_ANY);
	ip[dots++] = num;
	cp++;
  }
  return(ntohl(*((t_ipaddr *)ip)));
}



#ifndef NODNS
/*
 **************************************************************************
 * 
 * Append a domain name to the request buffer in label form.
 *
 * *buf points at the position in reqbuf where the first length byte goes.
 * On success it is advanced past the last copied label. No terminating
 * root label is written here.
 *
 * The name ends at a NUL or at a ':' (port suffix). Every label must start
 * with a letter. After that it may contain letters, digits and hyphens.
 * Upper case letters are folded to lower case. Output stops short of
 * reqbuf + MAX_NAME_LEN - 1, so the caller still has room for the root
 * label.
 *
 * Returns the number of dots in the name. Returns -1 if NAME is NULL, if a
 * label is empty, too long or contains an invalid character, or if the
 * result does not fit.
 */
static int copy_name(buf, name)
unsigned char **buf;
unsigned char *name;
{
  register unsigned char *src;
  unsigned char *lenbyte;		/* length byte of the current label */
  unsigned char *dst;
  unsigned char ch;
  int ndots = 0;
  int len = 0;

  lenbyte = *buf;
  dst = lenbyte + 1;
  if (name == NULL)
	return(-1);

  for (src = name; *src != '\0' && *src != ':'; src++) {
	if (dst >= reqbuf + MAX_NAME_LEN - 1)
		return(-1);
	ch = *src;
	if (ch == '.') {
		/* Close the current label and reserve a new length byte */
		if (len == 0 || len > MAX_LABEL_LEN)
			return(-1);
		ndots++;
		*lenbyte = len;
		lenbyte = dst++;
		len = 0;
		continue;
	}
	if (ch >= 'A' && ch <= 'Z')
		ch += 'a' - 'A';
	if (!(ch >= 'a' && ch <= 'z')) {
		/* Non-letters: only digits or '-', and never first in a label */
		if (len == 0 || !((ch >= '0' && ch <= '9') || ch == '-'))
			return(-1);
	}
	*dst++ = ch;
	len++;
  }

  if (len == 0 || len > MAX_LABEL_LEN)
	return(-1);
  *lenbyte = len;
  *buf = dst;
  return(ndots);
}



/*
 **************************************************************************
 * 
 * Build the labelized request name in reqbuf and set req_len to its total
 * length, including the terminating root label. If NAME contains no dot,
 * the local domain in mydomain is appended. The syntax checks are done by
 * copy_name(). Returns TRUE on success and FALSE if the name (or the
 * appended domain) is invalid.
 */
static int create_name(name)
char *name;
{
  unsigned char *wp = reqbuf;
  int ndots;

  ndots = copy_name(&wp, (unsigned char *)name);
  if (ndots < 0)
	return(FALSE);
  if (ndots == 0 && copy_name(&wp, mydomain) < 0)
	return(FALSE);

  *wp++ = '\0';			/* zero length label: the root domain */
  req_len = wp - reqbuf;
  return(TRUE);
}



#ifdef NSDEBUG
/*
 **************************************************************************
 * 
 * Debug helper: print a labelized name, each label followed by a '.'.
 * The output therefore ends with a trailing dot.
 *
 * NOTE(review): "%ls" is passed a pointer and a length. This looks like the
 * romlib printf's counted-string conversion, not the C standard wide-string
 * conversion. Confirm before building with a standard printf.
 */
static void print_name(cp)
unsigned char *cp;
{
  int len;

  for (len = *cp; len != 0; len = *cp) {
	cp++;
	printf("%ls.", cp, len);
	cp += len;
  }
}
#endif



/*
 **************************************************************************
 * 
 * Send a query record for the name in reqbuf (type A, class IN) to the
 * currently opened UDP peer, using transaction ID dns_xid.
 */
static void send_query()
{
  register struct dnshdr *dp = dnsbuf;
  unsigned char *qp;
  unsigned short val;
  int i;

  /*
   * Set up the request header. This is a rather simple resolver: if the
   * server has no authoritative answer, it must support recursion. Setting
   * the recursion flag saves us from doing all the name server and alias
   * lookup ourselves. The drawback is that some name servers don't support
   * recursion, or have it intentionally disabled...
   */
  memset(dp, 0, DNS_UDP_LEN);
  dp->xid = htons(dns_xid);
  dp->flags = htons(HEADER_RD | OPCODE_QUERY);
  dp->qdcount = htons(1);

  /*
   * Set up the request record: copy the labelized name, then append the
   * type and class values. req_len is odd for many names, so the 16-bit
   * type and class fields are generally not aligned. They are therefore
   * stored with memcpy() rather than through an unsigned short pointer.
   */
  qp = (unsigned char *)dp + sizeof(struct dnshdr);
  memcpy(qp, reqbuf, req_len);
  qp += req_len;
  val = htons(DNS_TYPE_A);
  memcpy(qp, &val, sizeof(val));
  qp += sizeof(val);
  val = htons(DNS_CLASS_IN);
  memcpy(qp, &val, sizeof(val));
  qp += sizeof(val);
#ifdef NSDEBUG
  printf("Sending DNS request for ");
  print_name(reqbuf);
  printf("\n");
#endif

  /* Finally send the request */
  i = qp - (unsigned char *)dp;
  (void)udp_write((char *)dp, i);
}



/*
 **************************************************************************
 * 
 * Decode a string in dnsbuf to labelized form while uncompressing it
 */
static unsigned char *ndecode(name)
unsigned char *name;
{
  register unsigned char *endp = NULL;
  register unsigned char *cp = name;
  unsigned char *bp;
  int i;

  bp = namebuf;
  do {
	if ((*cp & COMPR_MASK) == COMPR_MASK) {
		if (endp == NULL)
			endp = cp + 2;
		i = ntohs(*((unsigned short *)cp)) & OFFSET_MASK;
		cp = (unsigned char *)dnsbuf + i;
	} else for (i = *cp + 1; i > 0 && bp < namebuf + MAX_NAME_LEN; i--)
		*bp++ = *cp++;
	if (cp >= (unsigned char *)dnsbuf + dns_len)
		return(NULL);
  } while (*cp);

  return(endp == NULL ? ++cp : endp);
}



/*
 **************************************************************************
 * 
 * Decode a received DNS packet
 */
static t_ipaddr recv_answer()
{
#define rpp ((struct rr *)cp)

  register unsigned char *cp;
  int flags, recnum;

  /* Wait for a UDP packet and return with error if timeout */
  dns_len = udp_read((char *)dnsbuf, DNS_UDP_LEN, DNS_TIMEOUT, 0);
  if (dns_len == 0) {
	recv_err = ERR_TIMEOUT;
	return(IP_ANY);
  }

  /* Check for correct header and simply skip invalid packets */
  flags = ntohs(dnsbuf->flags);
  if (dns_len < 0 || dns_len > DNS_UDP_LEN ||
      (flags & HEADER_QR) != HEADER_QR ||
      (flags & HEADER_OPCODE) != OPCODE_QUERY ||
      ntohs(dnsbuf->xid) != dns_xid) {
	recv_err = ERR_INVALID;
	return(IP_ANY);
  }

  /* Decode the packet */
  if ((flags & HEADER_RCODE) == RCODE_NOERR) {
	/* Skip all query records */
	cp = (unsigned char *)dnsbuf + sizeof(struct dnshdr);
	if (ntohs(dnsbuf->qdcount) != 1 || (cp = ndecode(cp)) == NULL ||
	    memcmp(reqbuf, namebuf, req_len)) {
		recv_err = ERR_INVALID;
		return(IP_ANY);
	}
	cp += 4;

	/* Decode the first address record */
	for (recnum = 0; recnum < ntohs(dnsbuf->ancount); recnum++) {
		if ((cp = ndecode(cp)) == NULL) {
			recv_err = ERR_INVALID;
			return(IP_ANY);
		}
#ifdef NSDEBUG
		printf("Received DNS answer for ");
		print_name(namebuf);
		printf("\n");
#endif
		if (rpp->type == htons(DNS_TYPE_A) &&
		    rpp->class == htons(DNS_CLASS_IN) &&
		    rpp->rdlength >= htons(IP_ALEN) &&
		    !memcmp(reqbuf, namebuf, req_len)) {
			recv_err = ERR_NOERR;
			return(ntohl(*((t_ipaddr *)(cp + sizeof(struct rr)))));
		}
		cp += ntohs(rpp->rdlength) + sizeof(struct rr);
	}
  }

  /* Decode error condition */
  flags &= HEADER_RCODE;
  recv_err = ERR_SERVER;
  if (flags != RCODE_NAME) {
	if (flags == RCODE_NOERR)
		recv_err = ERR_INVALID;
	else
		printf("DNS: server error %x\n", flags);
  }
  return(IP_ANY);

#undef rrp
}
#endif /* NODNS */



/*
 **************************************************************************
 * 
 * Resolve a host name using DNS
 */
t_ipaddr resolve(name)
char *name;
{
  t_ipaddr ip;
#ifndef NODNS
  int ns, retry;
#endif

  /* First check if the name is already in IP number form */
  if ((ip = conv_ip(name)) != IP_ANY)
	return(ip);

#ifndef NODNS
  /* For local addresses we don't have to go through the resolver */
  if (!memcmp(name, "localhost", 9))
	return(IP_LOCALHOST);
  else if (!memcmp(name, "broadcast", 9))
	return(IP_BROADCAST);

  /*
   * If there are no domain name servers initialized, we can't continue. Other-
   * wise create a labelized domain name to resolve.
   */
  if (ns_num == 0 || !create_name(name)) {
	printf("DNS: invalid host name\n");
	return(IP_ANY);
  }

  /*
   * If this request is for the same name as the previous request, we can
   * just return the IP number and don't need to ask the server.
   */
  if (req_len > 0 && ip_last != IP_ANY && !memcmp(reqbuf, namebuf, req_len))
	return(ip_last);

  /* Now query each name server in turn */
  for (ns = 0; ns < ns_num; ns++) {
	dns_xid = (int)(get_ticks() + random());
	if (!udp_open(servers[ns], DNS_C_PORT, DNS_S_PORT)) {
		printf("DNS: ARP timeout\n");
		break;
	}
	retry = 0;
	while (retry++ < DNS_RETRY) {
		send_query();
		if ((ip = recv_answer()) != IP_ANY) {
			ip_last = ip;
			return(ip);
		}
		/* If a server error occurred, retries are useless */
		if (recv_err == ERR_SERVER)
			break;
	}
	if (recv_err == ERR_TIMEOUT)
		printf("DNS: timeout\n");
	else if (recv_err == ERR_INVALID)
		printf("DNS: invalid packets\n");
  }
#endif

  printf("DNS: can't resolve host name\n");
  return(IP_ANY);
}



/*
 **************************************************************************
 * 
 * Initialize name server list and current domain from BOOTP record
 */
void res_config()
{
#ifndef NODNS
  register unsigned char *cp;

  /* Copy name server information */
  if ((cp = get_vend(VEND_DNS)) != NULL) {
	t_ipaddr *ip = (t_ipaddr *)(cp + 1);
	int i = *cp;
	ns_num = 0;
	while (ns_num < MAX_NS && i >= IP_ALEN) {
		servers[ns_num++] = ntohl(*ip++);
		i -= IP_ALEN;
	}
  }

  /* Copy domain name information */
  if ((cp = get_vend(VEND_DOMAIN)) != NULL) {
	int i = *cp++;
	if (i < (MAX_NAME_LEN - 1)) {
		memcpy(mydomain, cp, i);
		mydomain[i] = '\0';
	}
  }
#endif
}



#ifndef NODNS
/*
 **************************************************************************
 * 
 * Initialize name resolver. All buffers are carved out of one allocation:
 *
 *   dnsbuf    DNS_UDP_LEN bytes       DNS packet buffer
 *   namebuf   MAX_NAME_LEN bytes      decoded name
 *   reqbuf    MAX_NAME_LEN bytes      labelized request name
 *   mydomain  MAX_NAME_LEN + 2 bytes  own domain name
 *
 * Returns TRUE on success and FALSE if the memory could not be allocated.
 */
int init_resolve()
{
  int i;

  /* Set name of module for error messages */
  arpa_module_name = "resolve";

  /* Allocate space for buffers */
  i = DNS_UDP_LEN + 3 * MAX_NAME_LEN + 2;
  if ((dnsbuf = (struct dnshdr *)malloc(i)) == NULL)
	return(FALSE);

  /*
   * malloc() is not assumed to clear memory. Zero the whole block so all
   * name buffers start out as empty strings. res_config() only sets
   * mydomain when the BOOTP record has a domain name option. However,
   * create_name() hands mydomain to copy_name() for every host name that
   * has no dot.
   */
  memset(dnsbuf, 0, i);
  namebuf = (unsigned char *)dnsbuf + DNS_UDP_LEN;
  reqbuf = namebuf + MAX_NAME_LEN;
  mydomain = reqbuf + MAX_NAME_LEN;

  return(TRUE);
}
#endif
