]> pilppa.org Git - linux-2.6-omap-h63xx.git/blobdiff - net/ipv4/udp.c
[XFRM]: Allow packet drops during larval state resolution.
[linux-2.6-omap-h63xx.git] / net / ipv4 / udp.c
index 66026df1cc7639bcba7f9b0232b8a4267107e7b4..4c7e95fa090d181234e3dbb1a2d1934a259c317f 100644 (file)
@@ -118,15 +118,15 @@ static int udp_port_rover;
  * Note about this hash function :
  * Typical use is probably daddr = 0, only dport is going to vary hash
  */
-static inline unsigned int hash_port_and_addr(__u16 port, __be32 addr)
+static inline unsigned int udp_hash_port(__u16 port)
 {
-       addr ^= addr >> 16;
-       addr ^= addr >> 8;
-       return port ^ addr;
+       return port;
 }
 
 static inline int __udp_lib_port_inuse(unsigned int hash, int port,
-       __be32 daddr, struct hlist_head udptable[])
+                                      const struct sock *this_sk,
+                                      struct hlist_head udptable[],
+                                      const struct udp_get_port_ops *ops)
 {
        struct sock *sk;
        struct hlist_node *node;
@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
                inet = inet_sk(sk);
                if (inet->num != port)
                        continue;
-               if (inet->rcv_saddr == daddr)
+               if (this_sk) {
+                       if (ops->saddr_cmp(sk, this_sk))
+                               return 1;
+               } else if (ops->saddr_any(sk))
                        return 1;
        }
        return 0;
@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
  *  @snum:        port number to look up
  *  @udptable:    hash list table, must be of UDP_HTABLE_SIZE
  *  @port_rover:  pointer to record of last unallocated port
- *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
+ *  @ops:         AF-dependent address operations
  */
 int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                       struct hlist_head udptable[], int *port_rover,
-                      int (*saddr_comp)(const struct sock *sk1,
-                                        const struct sock *sk2 )    )
+                      const struct udp_get_port_ops *ops)
 {
        struct hlist_node *node;
        struct hlist_head *head;
@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
                        int size;
 
-                       hash = hash_port_and_addr(result,
-                                       inet_sk(sk)->rcv_saddr);
+                       hash = ops->hash_port_and_rcv_saddr(result, sk);
                        head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
                        if (hlist_empty(head)) {
                                if (result > sysctl_local_port_range[1])
@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                                result = sysctl_local_port_range[0]
                                        + ((result - sysctl_local_port_range[0]) &
                                           (UDP_HTABLE_SIZE - 1));
-                       hash = hash_port_and_addr(result, 0);
+                       hash = udp_hash_port(result);
                        if (__udp_lib_port_inuse(hash, result,
-                                                0, udptable))
+                                                NULL, udptable, ops))
                                continue;
-                       if (!inet_sk(sk)->rcv_saddr)
+                       if (ops->saddr_any(sk))
                                break;
 
-                       hash = hash_port_and_addr(result,
-                                       inet_sk(sk)->rcv_saddr);
+                       hash = ops->hash_port_and_rcv_saddr(result, sk);
                        if (! __udp_lib_port_inuse(hash, result,
-                               inet_sk(sk)->rcv_saddr, udptable))
+                                                  sk, udptable, ops))
                                break;
                }
                if (i >= (1 << 16) / UDP_HTABLE_SIZE)
@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
 gotit:
                *port_rover = snum = result;
        } else {
-               hash = hash_port_and_addr(snum, 0);
+               hash = udp_hash_port(snum);
                head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
 
                sk_for_each(sk2, node, head)
@@ -231,12 +231,11 @@ gotit:
                            (!sk2->sk_reuse || !sk->sk_reuse) &&
                            (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                             sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                           (*saddr_comp)(sk, sk2))
+                           ops->saddr_cmp(sk, sk2))
                                goto fail;
 
-               if (inet_sk(sk)->rcv_saddr) {
-                       hash = hash_port_and_addr(snum,
-                                                 inet_sk(sk)->rcv_saddr);
+               if (!ops->saddr_any(sk)) {
+                       hash = ops->hash_port_and_rcv_saddr(snum, sk);
                        head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
 
                        sk_for_each(sk2, node, head)
@@ -248,7 +247,7 @@ gotit:
                                     !sk->sk_bound_dev_if ||
                                     sk2->sk_bound_dev_if ==
                                     sk->sk_bound_dev_if) &&
-                                   (*saddr_comp)(sk, sk2))
+                                   ops->saddr_cmp(sk, sk2))
                                        goto fail;
                }
        }
@@ -266,12 +265,12 @@ fail:
 }
 
 int udp_get_port(struct sock *sk, unsigned short snum,
-                       int (*scmp)(const struct sock *, const struct sock *))
+                const struct udp_get_port_ops *ops)
 {
-       return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp);
+       return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops);
 }
 
-int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
+static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
 {
        struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
 
@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
                   inet1->rcv_saddr == inet2->rcv_saddr      ));
 }
 
+static int ipv4_rcv_saddr_any(const struct sock *sk)
+{
+       return !inet_sk(sk)->rcv_saddr;
+}
+
+static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr)
+{
+       addr ^= addr >> 16;
+       addr ^= addr >> 8;
+       return port ^ addr;
+}
+
+static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port,
+                                                const struct sock *sk)
+{
+       return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr);
+}
+
+const struct udp_get_port_ops udp_ipv4_ops = {
+       .saddr_cmp = ipv4_rcv_saddr_equal,
+       .saddr_any = ipv4_rcv_saddr_any,
+       .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr,
+};
+
 static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
 {
-       return udp_get_port(sk, snum, ipv4_rcv_saddr_equal);
+       return udp_get_port(sk, snum, &udp_ipv4_ops);
 }
 
 /* UDP is nearly always wildcards out the wazoo, it makes no sense to try
@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport,
        unsigned int hash, hashwild;
        int score, best = -1, hport = ntohs(dport);
 
-       hash = hash_port_and_addr(hport, daddr);
-       hashwild = hash_port_and_addr(hport, 0);
+       hash = ipv4_hash_port_and_addr(hport, daddr);
+       hashwild = udp_hash_port(hport);
 
        read_lock(&udp_hash_lock);
 
@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb,
        struct sock *sk, *skw, *sknext;
        int dif;
        int hport = ntohs(uh->dest);
-       unsigned int hash = hash_port_and_addr(hport, daddr);
-       unsigned int hashwild = hash_port_and_addr(hport, 0);
+       unsigned int hash = ipv4_hash_port_and_addr(hport, daddr);
+       unsigned int hashwild = udp_hash_port(hport);
 
        dif = skb->dev->ifindex;