]> pilppa.org Git - linux-2.6-omap-h63xx.git/blobdiff - net/sunrpc/svcsock.c
[PATCH] knfsd: SUNRPC: Support IPv6 addresses in svc_tcp_accept
[linux-2.6-omap-h63xx.git] / net / sunrpc / svcsock.c
index b11669670baa406614f458633cf1daeac1f5c8b9..72831b8a58fb84f4f875d3e7dff5aef0b4fbab18 100644 (file)
@@ -36,6 +36,7 @@
 #include <net/sock.h>
 #include <net/checksum.h>
 #include <net/ip.h>
+#include <net/ipv6.h>
 #include <net/tcp_states.h>
 #include <asm/uaccess.h>
 #include <asm/ioctls.h>
@@ -446,6 +447,43 @@ svc_wake_up(struct svc_serv *serv)
        }
 }
 
+union svc_pktinfo_u {
+       struct in_pktinfo pkti;
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+       struct in6_pktinfo pkti6;
+#endif
+};
+
+static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh)
+{
+       switch (rqstp->rq_sock->sk_sk->sk_family) {
+       case AF_INET: {
+                       struct in_pktinfo *pki = CMSG_DATA(cmh);
+
+                       cmh->cmsg_level = SOL_IP;
+                       cmh->cmsg_type = IP_PKTINFO;
+                       pki->ipi_ifindex = 0;
+                       pki->ipi_spec_dst.s_addr = rqstp->rq_daddr.addr.s_addr;
+                       cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
+               }
+               break;
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+       case AF_INET6: {
+                       struct in6_pktinfo *pki = CMSG_DATA(cmh);
+
+                       cmh->cmsg_level = SOL_IPV6;
+                       cmh->cmsg_type = IPV6_PKTINFO;
+                       pki->ipi6_ifindex = 0;
+                       ipv6_addr_copy(&pki->ipi6_addr,
+                                       &rqstp->rq_daddr.addr6);
+                       cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
+               }
+               break;
+#endif
+       }
+       return;
+}
+
 /*
  * Generic sendto routine
  */
@@ -455,9 +493,8 @@ svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr)
        struct svc_sock *svsk = rqstp->rq_sock;
        struct socket   *sock = svsk->sk_sock;
        int             slen;
-       char            buffer[CMSG_SPACE(sizeof(struct in_pktinfo))];
+       char            buffer[CMSG_SPACE(sizeof(union svc_pktinfo_u))];
        struct cmsghdr *cmh = (struct cmsghdr *)buffer;
-       struct in_pktinfo *pki = (struct in_pktinfo *)CMSG_DATA(cmh);
        int             len = 0;
        int             result;
        int             size;
@@ -470,21 +507,15 @@ svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr)
        slen = xdr->len;
 
        if (rqstp->rq_prot == IPPROTO_UDP) {
-               /* set the source and destination */
-               struct msghdr   msg;
-               msg.msg_name    = &rqstp->rq_addr;
-               msg.msg_namelen = rqstp->rq_addrlen;
-               msg.msg_iov     = NULL;
-               msg.msg_iovlen  = 0;
-               msg.msg_flags   = MSG_MORE;
-
-               msg.msg_control = cmh;
-               msg.msg_controllen = sizeof(buffer);
-               cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
-               cmh->cmsg_level = SOL_IP;
-               cmh->cmsg_type = IP_PKTINFO;
-               pki->ipi_ifindex = 0;
-               pki->ipi_spec_dst.s_addr = rqstp->rq_daddr;
+               struct msghdr msg = {
+                       .msg_name       = &rqstp->rq_addr,
+                       .msg_namelen    = rqstp->rq_addrlen,
+                       .msg_control    = cmh,
+                       .msg_controllen = sizeof(buffer),
+                       .msg_flags      = MSG_MORE,
+               };
+
+               svc_set_cmsg_data(rqstp, cmh);
 
                if (sock_sendmsg(sock, &msg, 0) < 0)
                        goto out;
@@ -763,7 +794,7 @@ svc_udp_recvfrom(struct svc_rqst *rqstp)
        rqstp->rq_addrlen = sizeof(struct sockaddr_in);
 
        /* Remember which interface received this request */
-       rqstp->rq_daddr = skb->nh.iph->daddr;
+       rqstp->rq_daddr.addr.s_addr = skb->nh.iph->daddr;
 
        if (skb_is_nonlinear(skb)) {
                /* we have to copy */
@@ -907,13 +938,30 @@ svc_tcp_data_ready(struct sock *sk, int count)
                wake_up_interruptible(sk->sk_sleep);
 }
 
+static inline int svc_port_is_privileged(struct sockaddr *sin)
+{
+       switch (sin->sa_family) {
+       case AF_INET:
+               return ntohs(((struct sockaddr_in *)sin)->sin_port)
+                       < PROT_SOCK;
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+       case AF_INET6:
+               return ntohs(((struct sockaddr_in6 *)sin)->sin6_port)
+                       < PROT_SOCK;
+#endif
+       default:
+               return 0;
+       }
+}
+
 /*
  * Accept a TCP connection
  */
 static void
 svc_tcp_accept(struct svc_sock *svsk)
 {
-       struct sockaddr_in sin;
+       struct sockaddr_storage addr;
+       struct sockaddr *sin = (struct sockaddr *) &addr;
        struct svc_serv *serv = svsk->sk_server;
        struct socket   *sock = svsk->sk_sock;
        struct socket   *newsock;
@@ -940,8 +988,7 @@ svc_tcp_accept(struct svc_sock *svsk)
        set_bit(SK_CONN, &svsk->sk_flags);
        svc_sock_enqueue(svsk);
 
-       slen = sizeof(sin);
-       err = kernel_getpeername(newsock, (struct sockaddr *) &sin, &slen);
+       err = kernel_getpeername(newsock, sin, &slen);
        if (err < 0) {
                if (net_ratelimit())
                        printk(KERN_WARNING "%s: peername failed (err %d)!\n",
@@ -953,16 +1000,14 @@ svc_tcp_accept(struct svc_sock *svsk)
         * hosts here, but when we get encryption, the IP of the host won't
         * tell us anything.  For now just warn about unpriv connections.
         */
-       if (ntohs(sin.sin_port) >= 1024) {
+       if (!svc_port_is_privileged(sin)) {
                dprintk(KERN_WARNING
                        "%s: connect from unprivileged port: %s\n",
                        serv->sv_name,
-                       __svc_print_addr((struct sockaddr *) &sin, buf,
-                                                               sizeof(buf)));
+                       __svc_print_addr(sin, buf, sizeof(buf)));
        }
        dprintk("%s: connect from %s\n", serv->sv_name,
-               __svc_print_addr((struct sockaddr *) &sin, buf,
-                                sizeof(buf)));
+               __svc_print_addr(sin, buf, sizeof(buf)));
 
        /* make sure that a write doesn't block forever when
         * low on memory
@@ -972,7 +1017,7 @@ svc_tcp_accept(struct svc_sock *svsk)
        if (!(newsvsk = svc_setup_socket(serv, newsock, &err,
                                 (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY))))
                goto failed;
-       memcpy(&newsvsk->sk_remote, &sin, slen);
+       memcpy(&newsvsk->sk_remote, sin, slen);
        newsvsk->sk_remotelen = slen;
 
        svc_sock_received(newsvsk);
@@ -1303,7 +1348,6 @@ int
 svc_recv(struct svc_rqst *rqstp, long timeout)
 {
        struct svc_sock         *svsk = NULL;
-       struct sockaddr_in      *sin = svc_addr_in(rqstp);
        struct svc_serv         *serv = rqstp->rq_server;
        struct svc_pool         *pool = rqstp->rq_pool;
        int                     len, i;
@@ -1400,7 +1444,7 @@ svc_recv(struct svc_rqst *rqstp, long timeout)
        svsk->sk_lastrecv = get_seconds();
        clear_bit(SK_OLD, &svsk->sk_flags);
 
-       rqstp->rq_secure = ntohs(sin->sin_port) < PROT_SOCK;
+       rqstp->rq_secure = svc_port_is_privileged(svc_addr(rqstp));
        rqstp->rq_chandle.defer = svc_defer;
 
        if (serv->sv_stats)