/* $USAGI: spd.c,v 1.22 2002/12/13 12:44:03 mk Exp $ */
/*
 * Copyright (C)2001 USAGI/WIDE Project
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * Authors:
 *   Kazunori MIYAZAWA <miyazawa@linux-ip.org> / USAGI
 *   Mitsuru KANDA <mk@linux-ipv6.org> / USAGI
 *
 * Acknowledgements:
 *   Joy Latten <latten@austin.ibm.com>
 */
/*
 * spd.c provides manipulation routines for the IPsec SPD.
 * struct ipsec_sp represents a policy in the IPsec SPD.
 * struct ipsec_sp refers to IPsec SAs through struct sa_index.
 */

#include <linux/config.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/version.h>
#include <linux/init.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/list.h>
#include <linux/socket.h>
#include <linux/skbuff.h>
#include <linux/ipsec.h>
#include <linux/pfkeyv2.h>
#include <net/spd.h>
#include <net/sadb.h>
#include "sockaddr_utils.h"

#ifdef CONFIG_PROC_FS
#include <linux/proc_fs.h>
#endif /* CONFIG_PROC_FS */

/* Size in bytes of the scratch buffer used by spd_get_info() when
 * formatting addresses and ports as text.
 */
#define BUFSIZE 64

/* spd_list : IPsec Security Policy Database(SPD); a list of struct ipsec_sp
 *            linked through their 'entry' member.
 * spd_lock : reader/writer lock for SPD. Lookups and the /proc dump take it
 *            for reading; append, remove and clear take it for writing.
 */
LIST_HEAD(spd_list);
rwlock_t spd_lock = RW_LOCK_UNLOCKED;

/* Values for the proto_any argument of ipsec_selector_compare(). */
#define PROTO_ANY_OFF	0
#define PROTO_ANY_ON	1

/*
 * proto_any == PROTO_ANY_ON means 'match protocol': the two selectors must
 * carry the same protocol value.
 * proto_any == PROTO_ANY_OFF means 'don't care protocol': a protocol value
 * of 0 in either selector matches any protocol; two non-zero protocol
 * values must still be equal.
 *
 * In this file PROTO_ANY_ON is passed by spd_find_by_selector() and
 * spd_remove(), and PROTO_ANY_OFF is passed by ipsec_sp_get().
 *
 */

/*
 * ipsec_selector_compare - decide whether two selectors match.
 *
 * The checks run in this order and stop at the first mismatch:
 *   1. protocol (strict when proto_any is non-zero, otherwise a protocol
 *      of 0 on either side acts as a wildcard),
 *   2. mode (only when CONFIG_IPSEC_TUNNEL is set),
 *   3. source and destination address prefixes,
 *   4. source ports, then destination ports.
 *
 * Returns 0 when everything matches. Returns -EINVAL when a selector is
 * NULL or the protocols do not match, 1 when the mode or an address prefix
 * differs, and the non-zero result of sockaddr_compare_ports() when a port
 * differs.
 */
static int ipsec_selector_compare(struct selector *selector1, struct selector *selector2, u8 proto_any)
{
	int rc;

	if (!selector1 || !selector2) {
		SPD_DEBUG("selector1 or selecotr2 is NULL\n");
		return -EINVAL;
	}

	/* Differing protocols are fatal if the caller asked for a strict
	 * comparison, or if neither side is the 0 wildcard. */
	if (selector1->proto != selector2->proto &&
	    (proto_any || (selector1->proto && selector2->proto))) {
		SPD_DEBUG("unmatch: selector->proto\n");
		SPD_DEBUG("  selector1->proto:%d\n", selector1->proto);
		SPD_DEBUG("  selector2->proto:%d\n", selector2->proto);
		return -EINVAL;
	}

#ifdef CONFIG_IPSEC_TUNNEL
	if (selector1->mode != selector2->mode) {
		SPD_DEBUG("unmatch: selector->mode\n");
		return 1;
	}
#endif

	/* The destination prefix is only examined when the source prefix
	 * already matched. */
	if (sockaddr_prefix_compare((struct sockaddr*)&selector1->src, selector1->prefixlen_s,
				    (struct sockaddr*)&selector2->src, selector2->prefixlen_s) ||
	    sockaddr_prefix_compare((struct sockaddr*)&selector1->dst, selector1->prefixlen_d,
				    (struct sockaddr*)&selector2->dst, selector2->prefixlen_d)) {
		SPD_DEBUG("unmatch:sockaddr_prefix_compare\n");
		return 1;
	}

	/* compare ports, if they are set */
	rc = sockaddr_compare_ports((struct sockaddr*)&selector1->src, (struct sockaddr*)&selector2->src);
	if (rc) {
		SPD_DEBUG("unmatch:sockaddr_compare_port s\n");
		return rc;
	}

	rc = sockaddr_compare_ports((struct sockaddr*)&selector1->dst, (struct sockaddr*)&selector2->dst);
	if (rc) {
		SPD_DEBUG("unmatch:sockaddr_compare_port d\n");
		return rc;
	}

	/* everything matches */
	return 0;
}
/*
 * ipsec_sp_kmalloc - allocate a policy with GFP_KERNEL and initialize it
 * through ipsec_sp_init().
 *
 * Returns the new policy, or NULL when the allocation fails.
 */
struct ipsec_sp *ipsec_sp_kmalloc()
{
	struct ipsec_sp *policy;

	policy = (struct ipsec_sp *)kmalloc(sizeof(struct ipsec_sp), GFP_KERNEL);
	if (policy) {
		ipsec_sp_init(policy);
		return policy;
	}

	SPD_DEBUG("entry couldn\'t be allocated.\n");
	return NULL;
}

/*
 * ipsec_sp_init - reset a policy to its initial state.
 *
 * The whole structure is zeroed, the three SA index pointers are NULL,
 * the lock is unlocked and the reference count is set to 1.
 *
 * Returns 0 on success, -EINVAL when policy is NULL.
 */
int ipsec_sp_init(struct ipsec_sp *policy)
{
	if (policy == NULL) {
		SPD_DEBUG("policy is NULL\n");
		return -EINVAL;
	}

	memset(policy, 0, sizeof(*policy));

	policy->lock = RW_LOCK_UNLOCKED;
	atomic_set(&policy->refcnt,1);

	policy->auth_sa_idx = NULL;
	policy->esp_sa_idx = NULL;
	policy->comp_sa_idx = NULL;

	return 0;
}

void ipsec_sp_kfree(struct ipsec_sp *policy)
{
	if (!policy) {
		SPD_DEBUG("entry is NULL\n");
		return;
	}

	if (atomic_read(&policy->refcnt)) {
		SPD_DEBUG("policy has been referenced\n");
		return;
	}

	if (policy->auth_sa_idx) sa_index_kfree(policy->auth_sa_idx);
	if (policy->esp_sa_idx) sa_index_kfree(policy->esp_sa_idx);
	if (policy->comp_sa_idx) sa_index_kfree(policy->comp_sa_idx);

	kfree(policy);
}

int ipsec_sp_copy(struct ipsec_sp *dst, struct ipsec_sp *src)
{
	int error = 0;

	if (!dst || !src) {
		SPD_DEBUG("dst or src is NULL\n");
		error = -EINVAL;
		goto err;
	}

	memcpy(&dst->selector, &src->selector, sizeof(struct selector));

	if (dst->auth_sa_idx) sa_index_kfree(dst->auth_sa_idx);
	if (dst->esp_sa_idx) sa_index_kfree(dst->esp_sa_idx);
	if (dst->comp_sa_idx) sa_index_kfree(dst->comp_sa_idx);

	if (src->auth_sa_idx) {
		dst->auth_sa_idx = sa_index_kmalloc();
		memcpy(dst->auth_sa_idx, src->auth_sa_idx, sizeof(struct sa_index));
	}

	if (src->esp_sa_idx) {
		dst->esp_sa_idx = sa_index_kmalloc();
		memcpy(dst->esp_sa_idx, src->esp_sa_idx, sizeof(struct sa_index));
	}

	if (src->comp_sa_idx) {
		dst->comp_sa_idx = sa_index_kmalloc();
		memcpy(dst->comp_sa_idx, src->comp_sa_idx, sizeof(struct sa_index));
	}

	dst->policy_action = src->policy_action;

	atomic_set(&dst->refcnt, 1);
err:
	return error;
}

/*
 * ipsec_sp_put - drop one reference to a policy.
 *
 * The reference count is decremented under the policy's write lock. When
 * it reaches zero the lock is released and the policy is freed with
 * ipsec_sp_kfree().
 *
 * Returns 0 on success, -EINVAL when policy is NULL.
 */
int ipsec_sp_put(struct ipsec_sp *policy)
{
	int last_ref;

	if (!policy) {
		SPD_DEBUG("policy is NULL\n");
		return -EINVAL;
	}

	write_lock_bh(&policy->lock);

	SPD_DEBUG("ptr=%p,refcnt=%d\n",
			policy, atomic_read(&policy->refcnt));

	last_ref = atomic_dec_and_test(&policy->refcnt);
	if (last_ref) {
		SPD_DEBUG("ptr=%p,refcnt=%d\n",
			policy, atomic_read(&policy->refcnt));
	}

	write_unlock_bh(&policy->lock);

	/* The policy embeds the lock, so it is only freed after unlocking. */
	if (last_ref)
		ipsec_sp_kfree(policy);

	return 0;
}

void ipsec_sp_release_invalid_sa(struct ipsec_sp *policy, struct ipsec_sa *sa)
{
	if (!policy) {
		SPD_DEBUG("policy is NULL\n");
		return;
	}

	if (policy->auth_sa_idx && policy->auth_sa_idx->sa == sa) {
		ipsec_sa_put(policy->auth_sa_idx->sa);
		policy->auth_sa_idx->sa = NULL;
	}

	if (policy->esp_sa_idx && policy->esp_sa_idx->sa == sa) {
		ipsec_sa_put(policy->esp_sa_idx->sa);
		policy->esp_sa_idx->sa = NULL;
	}

	if (policy->comp_sa_idx && policy->comp_sa_idx->sa == sa) {
		ipsec_sa_put(policy->comp_sa_idx->sa);
		policy->comp_sa_idx->sa = NULL;
	}
}

/*
 * ipsec_sp_get - look up the first policy in the SPD matching a selector.
 *
 * The comparison uses PROTO_ANY_OFF. The SPD is walked under the read side
 * of spd_lock and each candidate is examined under its own read lock. The
 * reference count of the returned policy is incremented; the caller drops
 * it with ipsec_sp_put().
 *
 * Returns the matching policy, or NULL when selector is NULL or nothing
 * matches.
 */
struct ipsec_sp* ipsec_sp_get(struct selector *selector)
{
	struct list_head *cursor = NULL;
	struct ipsec_sp *found = NULL;

	if (!selector) {
		SPD_DEBUG("selector is NULL\n");
		return NULL;
	}

	read_lock(&spd_lock);
	list_for_each(cursor, &spd_list){
		struct ipsec_sp *candidate = list_entry(cursor, struct ipsec_sp, entry);
		int matched;

		read_lock_bh(&candidate->lock);
		matched = !ipsec_selector_compare(selector, &candidate->selector, PROTO_ANY_OFF);
		if (matched) {
			SPD_DEBUG("found matched element\n");
			atomic_inc(&candidate->refcnt);
		}
		read_unlock_bh(&candidate->lock);

		if (matched) {
			found = candidate;
			break;
		}
	}
	read_unlock(&spd_lock);

	return found;
}

int spd_append(struct ipsec_sp *policy)
{
	int error = 0;
	struct ipsec_sp *new = NULL;

	if (!policy) {
		SPD_DEBUG("policy is NULL\n");
		error = -EINVAL;
		goto err;
	}

	new = ipsec_sp_kmalloc();
	if (!new) {
		SPD_DEBUG("ipsec_sp_kmalloc failed\n");
		error = -ENOMEM;
		goto err;
	}

	error = ipsec_sp_init(new);
	if (error) {
		SPD_DEBUG("ipsec_sp_init failed\n");
		goto err;
	}

	error = ipsec_sp_copy(new, policy);
	if (error) {
		SPD_DEBUG("ipsec_sp_copy failed\n");
		goto err;
	}

	write_lock_bh(&spd_lock);
	list_add_tail(&new->entry, &spd_list);
	write_unlock_bh(&spd_lock);
err:
	return error;
}

/*
 * spd_remove - remove the first policy matching a selector from the SPD.
 *
 * The comparison uses PROTO_ANY_ON. The matching entry is unlinked under
 * the write side of spd_lock and the list's reference is dropped with
 * ipsec_sp_put().
 *
 * Returns 0 when an entry was removed, -ESRCH when nothing matched,
 * -EINVAL when selector is NULL.
 */
int spd_remove(struct selector *selector)
{
	int error = -ESRCH;
	struct list_head *cursor = NULL;
	struct list_head *lookahead = NULL;
	struct ipsec_sp *candidate = NULL;

	if (!selector) {
		SPD_DEBUG("selector is NULL\n");
		error = -EINVAL;
		goto out;
	}

	write_lock_bh(&spd_lock);
	list_for_each_safe(cursor, lookahead, &spd_list){
		candidate = list_entry(cursor, struct ipsec_sp, entry);

		write_lock_bh(&candidate->lock);
		if (ipsec_selector_compare(selector, &candidate->selector, PROTO_ANY_ON)) {
			/* no match, keep looking */
			write_unlock_bh(&candidate->lock);
			continue;
		}

		SPD_DEBUG("found matched element\n");
		error = 0;
		list_del(&candidate->entry);
		write_unlock_bh(&candidate->lock);
		ipsec_sp_put(candidate);
		break;
	}
	write_unlock_bh(&spd_lock);

out:
	SPD_DEBUG("error = %d\n", error);
	return error;
}


/*
 * spd_find_by_selector - find the first policy matching a selector.
 *
 * The comparison uses PROTO_ANY_ON. On a match, *policy is set to the
 * entry and its reference count is incremented; the caller drops that
 * reference with ipsec_sp_put(). *policy is left untouched when nothing
 * matches.
 *
 * Note the return convention: -EEXIST means a policy was found, -ESRCH
 * means none was found. -EINVAL is returned when selector or policy is
 * NULL.
 */
int spd_find_by_selector(struct selector *selector, struct ipsec_sp **policy)
{
	int error = -ESRCH;
	struct list_head *pos = NULL;
	struct ipsec_sp *tmp_sp = NULL;

	if (!selector) {
		SPD_DEBUG("selector is NULL\n");
		error = -EINVAL;
		goto err;
	}

	/* The result is written through 'policy', so it must be valid too. */
	if (!policy) {
		SPD_DEBUG("policy is NULL\n");
		error = -EINVAL;
		goto err;
	}

	read_lock(&spd_lock);
	list_for_each(pos, &spd_list){
		tmp_sp = list_entry(pos, struct ipsec_sp, entry);
		read_lock_bh(&tmp_sp->lock);
		if (!ipsec_selector_compare(selector, &tmp_sp->selector, PROTO_ANY_ON)) {
			SPD_DEBUG("found matched element\n");
			error = -EEXIST;
			*policy = tmp_sp;
			atomic_inc(&tmp_sp->refcnt);
			read_unlock_bh(&tmp_sp->lock);
			break;
		}
		read_unlock_bh(&tmp_sp->lock);
	}
	read_unlock(&spd_lock);

err:
	return error;
}

/*
 * spd_clear_db - remove every policy from the SPD.
 *
 * Each entry is unlinked under the write side of spd_lock and the list's
 * reference is dropped with ipsec_sp_put(), the same way spd_remove() does.
 * Calling ipsec_sp_kfree() directly would not work here: it refuses to
 * free a policy whose reference count is non-zero, and a listed policy
 * still holds its initial reference, so the entry would be unlinked and
 * leaked. A policy that is still referenced elsewhere is freed when its
 * last holder calls ipsec_sp_put().
 */
void spd_clear_db()
{
	struct list_head *pos;
	struct list_head *next;
	struct ipsec_sp *policy;

	write_lock_bh(&spd_lock);
	list_for_each_safe(pos, next, &spd_list){
		policy = list_entry(pos, struct ipsec_sp, entry);
		list_del(&policy->entry);
		ipsec_sp_put(policy);
	}
	write_unlock_bh(&spd_lock);
}


#ifdef CONFIG_PROC_FS
static int spd_get_info(char *buffer, char **start, off_t offset, int length)
{
	int error = 0;
	int count = 0;
	int len = 0;
        off_t pos=0;
        off_t begin=0;
        char buf[BUFSIZE]; 
        struct list_head *list_pos = NULL;
        struct ipsec_sp *tmp_sp = NULL;
	
        read_lock_bh(&spd_lock);
        list_for_each(list_pos, &spd_list){
		count = 0;
                tmp_sp = list_entry(list_pos, struct ipsec_sp, entry);
		read_lock_bh(&tmp_sp->lock);

		len += sprintf(buffer + len, "spd:%p\n", tmp_sp);
                memset(buf, 0, BUFSIZE);
                sockaddrtoa((struct sockaddr*)&tmp_sp->selector.src, buf, BUFSIZE);
                len += sprintf(buffer + len, "%s/%u ", buf, tmp_sp->selector.prefixlen_s);
		sockporttoa((struct sockaddr *)&tmp_sp->selector.src, buf, BUFSIZE);
		len += sprintf(buffer + len, "%s ", buf);
                memset(buf, 0, BUFSIZE);
                sockaddrtoa((struct sockaddr*)&tmp_sp->selector.dst, buf, BUFSIZE);
                len += sprintf(buffer + len, "%s/%u ", buf, tmp_sp->selector.prefixlen_d);
		sockporttoa((struct sockaddr *)&tmp_sp->selector.dst, buf, BUFSIZE);
		len += sprintf(buffer + len, "%s ", buf);
		len += sprintf(buffer + len, "%u ", tmp_sp->selector.proto);
#ifdef CONFIG_IPSEC_TUNNEL
		len += sprintf(buffer + len, "%u ", tmp_sp->selector.mode);
#endif
		len += sprintf(buffer + len, "%u\n", tmp_sp->policy_action);

		if (tmp_sp->auth_sa_idx) {
			len += sprintf(buffer + len, "sa(ah):%p ", tmp_sp->auth_sa_idx->sa);
			sockaddrtoa((struct sockaddr*)&tmp_sp->auth_sa_idx->dst, buf, BUFSIZE);
			len += sprintf(buffer + len, "%s/%d ", buf, tmp_sp->auth_sa_idx->prefixlen_d);
			len += sprintf(buffer + len, "%u ",  tmp_sp->auth_sa_idx->ipsec_proto);
			len += sprintf(buffer + len, "0x%x\n", htonl(tmp_sp->auth_sa_idx->spi));
		}

		if (tmp_sp->esp_sa_idx) {
			len += sprintf(buffer + len, "sa(esp):%p ", tmp_sp->esp_sa_idx->sa);
			sockaddrtoa((struct sockaddr*)&tmp_sp->esp_sa_idx->dst, buf, BUFSIZE);
			len += sprintf(buffer + len, "%s/%d ", buf, tmp_sp->esp_sa_idx->prefixlen_d);
			len += sprintf(buffer + len, "%u ",  tmp_sp->esp_sa_idx->ipsec_proto);
			len += sprintf(buffer + len, "0x%x\n", htonl(tmp_sp->esp_sa_idx->spi));
		}

		if (tmp_sp->comp_sa_idx) {
			len += sprintf(buffer + len, "sa(comp):%p ", tmp_sp->comp_sa_idx->sa);
			sockaddrtoa((struct sockaddr*)&tmp_sp->comp_sa_idx->dst, buf, BUFSIZE);
			len += sprintf(buffer + len, "%s/%d ", buf, tmp_sp->comp_sa_idx->prefixlen_d);
			len += sprintf(buffer + len, "%u ",  tmp_sp->comp_sa_idx->ipsec_proto);
			len += sprintf(buffer + len, "0x%x\n", htonl(tmp_sp->comp_sa_idx->spi));
		}

		read_unlock_bh(&tmp_sp->lock);
		len += sprintf(buffer + len, "\n");

                pos=begin+len;
                if (pos<offset) {
                        len=0;
                        begin=pos;
                }
                if (pos>offset+length) {
                        read_unlock_bh(&spd_lock);
                        goto done;
                }
        }       
        read_unlock_bh(&spd_lock);
done:

        *start=buffer+(offset-begin);
        len-=(offset-begin);
        if (len>length)
                len=length;
        if (len<0)
                len=0;
        return len;

	goto err;
err:
	return error;
}
#endif /* CONFIG_PROC_FS */

/*
 * spd_init - set up the SPD.
 *
 * Re-initializes the policy list head and, with CONFIG_PROC_FS, registers
 * the "spd" proc entry served by spd_get_info().
 *
 * Always returns 0.
 */
int spd_init(void)
{
	INIT_LIST_HEAD(&spd_list);

	SPD_DEBUG("spd_list.prev=%p\n", spd_list.prev);
	SPD_DEBUG("spd_list.next=%p\n", spd_list.next);

#ifdef CONFIG_PROC_FS
	proc_net_create("spd", 0, spd_get_info);
#endif /* CONFIG_PROC_FS */

	pr_info("IPsec Security Policy Database (SPD): initialized.\n");

	return 0;
}

/*
 * spd_cleanup - tear down the SPD.
 *
 * With CONFIG_PROC_FS the "spd" proc entry is removed first; then the
 * policy list is emptied with spd_clear_db().
 *
 * Always returns 0.
 */
int spd_cleanup(void)
{
#ifdef CONFIG_PROC_FS
	proc_net_remove("spd");
#endif /* CONFIG_PROC_FS */

	spd_clear_db();

	pr_info("IPsec SPD: cleaned up.\n");

	return 0;
}

