From a40be0857c7bf48e39f815417b0b5293cd8ed1aa Mon Sep 17 00:00:00 2001
From: Tyler Julian <tjulian@uber.com>
Date: Tue, 10 Jan 2017 19:11:56 -0800
Subject: [PATCH] libstd/net: Add `peek` APIs to UdpSocket and TcpStream

These methods enable socket reads without side-effects. That is,
repeated calls to peek() return identical data. This is accomplished
by providing the POSIX flag MSG_PEEK to the underlying socket read
operations.

This also moves the current implementation of recv_from out of the
platform-independent sys_common and into respective sys/windows and
sys/unix implementations. This allows for more platform-dependent
implementations.
---
 src/liblibc                   |  2 +-
 src/libstd/lib.rs             |  1 +
 src/libstd/net/tcp.rs         | 54 +++++++++++++++++++
 src/libstd/net/udp.rs         | 97 +++++++++++++++++++++++++++++++++++
 src/libstd/sys/unix/net.rs    | 45 ++++++++++++++--
 src/libstd/sys/windows/c.rs   |  1 +
 src/libstd/sys/windows/net.rs | 44 +++++++++++++++-
 src/libstd/sys_common/net.rs  | 24 +++++----
 8 files changed, 251 insertions(+), 17 deletions(-)

diff --git a/src/liblibc b/src/liblibc
index 7d57bdcdbb565..cb7f66732175e 160000
--- a/src/liblibc
+++ b/src/liblibc
@@ -1 +1 @@
-Subproject commit 7d57bdcdbb56540f37afe5a934ce12d33a6ca7fc
+Subproject commit cb7f66732175e6171587ed69656b7aae7dd2e6ec
diff --git a/src/libstd/lib.rs b/src/libstd/lib.rs
index 9557c520c5071..3c06409e3b18e 100644
--- a/src/libstd/lib.rs
+++ b/src/libstd/lib.rs
@@ -275,6 +275,7 @@
 #![feature(oom)]
 #![feature(optin_builtin_traits)]
 #![feature(panic_unwind)]
+#![feature(peek)]
 #![feature(placement_in_syntax)]
 #![feature(prelude_import)]
 #![feature(pub_restricted)]
diff --git a/src/libstd/net/tcp.rs b/src/libstd/net/tcp.rs
index ed1f08f9c9090..ba6160cc72331 100644
--- a/src/libstd/net/tcp.rs
+++ b/src/libstd/net/tcp.rs
@@ -296,6 +296,29 @@ impl TcpStream {
         self.0.write_timeout()
     }
 
+    /// Receives data on the socket from the remote adress to which it is
+    /// connected, without removing that data from the queue. On success,
+    /// returns the number of bytes peeked.
+    ///
+    /// Successive calls return the same data. This is accomplished by passing
+    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// #![feature(peek)]
+    /// use std::net::TcpStream;
+    ///
+    /// let stream = TcpStream::connect("127.0.0.1:8000")
+    ///                        .expect("couldn't bind to address");
+    /// let mut buf = [0; 10];
+    /// let len = stream.peek(&mut buf).expect("peek failed");
+    /// ```
+    #[unstable(feature = "peek", issue = "38980")]
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.0.peek(buf)
+    }
+
     /// Sets the value of the `TCP_NODELAY` option on this socket.
     ///
     /// If set, this option disables the Nagle algorithm. This means that
@@ -1405,4 +1428,35 @@ mod tests {
             Err(e) => panic!("unexpected error {}", e),
         }
     }
+
+    #[test]
+    fn peek() {
+        each_ip(&mut |addr| {
+            let (txdone, rxdone) = channel();
+
+            let srv = t!(TcpListener::bind(&addr));
+            let _t = thread::spawn(move|| {
+                let mut cl = t!(srv.accept()).0;
+                cl.write(&[1,3,3,7]).unwrap();
+                t!(rxdone.recv());
+            });
+
+            let mut c = t!(TcpStream::connect(&addr));
+            let mut b = [0; 10];
+            for _ in 1..3 {
+                let len = c.peek(&mut b).unwrap();
+                assert_eq!(len, 4);
+            }
+            let len = c.read(&mut b).unwrap();
+            assert_eq!(len, 4);
+
+            t!(c.set_nonblocking(true));
+            match c.peek(&mut b) {
+                Ok(_) => panic!("expected error"),
+                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
+                Err(e) => panic!("unexpected error {}", e),
+            }
+            t!(txdone.send(()));
+        })
+    }
 }
diff --git a/src/libstd/net/udp.rs b/src/libstd/net/udp.rs
index f8a5ec0b3791e..2f28f475dc88b 100644
--- a/src/libstd/net/udp.rs
+++ b/src/libstd/net/udp.rs
@@ -83,6 +83,30 @@ impl UdpSocket {
         self.0.recv_from(buf)
     }
 
+    /// Receives data from the socket, without removing it from the queue.
+    ///
+    /// Successive calls return the same data. This is accomplished by passing
+    /// `MSG_PEEK` as a flag to the underlying `recvfrom` system call.
+    ///
+    /// On success, returns the number of bytes peeked and the address from
+    /// whence the data came.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// #![feature(peek)]
+    /// use std::net::UdpSocket;
+    ///
+    /// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
+    /// let mut buf = [0; 10];
+    /// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf)
+    ///                                         .expect("Didn't receive data");
+    /// ```
+    #[unstable(feature = "peek", issue = "38980")]
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.0.peek_from(buf)
+    }
+
     /// Sends data on the socket to the given address. On success, returns the
     /// number of bytes written.
     ///
@@ -579,6 +603,37 @@ impl UdpSocket {
         self.0.recv(buf)
     }
 
+    /// Receives data on the socket from the remote adress to which it is
+    /// connected, without removing that data from the queue. On success,
+    /// returns the number of bytes peeked.
+    ///
+    /// Successive calls return the same data. This is accomplished by passing
+    /// `MSG_PEEK` as a flag to the underlying `recv` system call.
+    ///
+    /// # Errors
+    ///
+    /// This method will fail if the socket is not connected. The `connect` method
+    /// will connect this socket to a remote address.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// #![feature(peek)]
+    /// use std::net::UdpSocket;
+    ///
+    /// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
+    /// socket.connect("127.0.0.1:8080").expect("connect function failed");
+    /// let mut buf = [0; 10];
+    /// match socket.peek(&mut buf) {
+    ///     Ok(received) => println!("received {} bytes", received),
+    ///     Err(e) => println!("peek function failed: {:?}", e),
+    /// }
+    /// ```
+    #[unstable(feature = "peek", issue = "38980")]
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.0.peek(buf)
+    }
+
     /// Moves this UDP socket into or out of nonblocking mode.
     ///
     /// On Unix this corresponds to calling fcntl, and on Windows this
@@ -869,6 +924,48 @@ mod tests {
         assert_eq!(b"hello world", &buf[..]);
     }
 
+    #[test]
+    fn connect_send_peek_recv() {
+        each_ip(&mut |addr, _| {
+            let socket = t!(UdpSocket::bind(&addr));
+            t!(socket.connect(addr));
+
+            t!(socket.send(b"hello world"));
+
+            for _ in 1..3 {
+                let mut buf = [0; 11];
+                let size = t!(socket.peek(&mut buf));
+                assert_eq!(b"hello world", &buf[..]);
+                assert_eq!(size, 11);
+            }
+
+            let mut buf = [0; 11];
+            let size = t!(socket.recv(&mut buf));
+            assert_eq!(b"hello world", &buf[..]);
+            assert_eq!(size, 11);
+        })
+    }
+
+    #[test]
+    fn peek_from() {
+        each_ip(&mut |addr, _| {
+            let socket = t!(UdpSocket::bind(&addr));
+            t!(socket.send_to(b"hello world", &addr));
+
+            for _ in 1..3 {
+                let mut buf = [0; 11];
+                let (size, _) = t!(socket.peek_from(&mut buf));
+                assert_eq!(b"hello world", &buf[..]);
+                assert_eq!(size, 11);
+            }
+
+            let mut buf = [0; 11];
+            let (size, _) = t!(socket.recv_from(&mut buf));
+            assert_eq!(b"hello world", &buf[..]);
+            assert_eq!(size, 11);
+        })
+    }
+
     #[test]
     fn ttl() {
         let ttl = 100;
diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs
index ad287bbec3889..5efddca110f05 100644
--- a/src/libstd/sys/unix/net.rs
+++ b/src/libstd/sys/unix/net.rs
@@ -10,12 +10,13 @@
 
 use ffi::CStr;
 use io;
-use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM};
+use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
+use mem;
 use net::{SocketAddr, Shutdown};
 use str;
 use sys::fd::FileDesc;
 use sys_common::{AsInner, FromInner, IntoInner};
-use sys_common::net::{getsockopt, setsockopt};
+use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
 use time::Duration;
 
 pub use sys::{cvt, cvt_r};
@@ -155,8 +156,46 @@ impl Socket {
         self.0.duplicate().map(Socket)
     }
 
+    fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
+        let ret = cvt(unsafe {
+            libc::recv(self.0.raw(),
+                       buf.as_mut_ptr() as *mut c_void,
+                       buf.len(),
+                       flags)
+        })?;
+        Ok(ret as usize)
+    }
+
     pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
-        self.0.read(buf)
+        self.recv_with_flags(buf, 0)
+    }
+
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, MSG_PEEK)
+    }
+
+    fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
+                            -> io::Result<(usize, SocketAddr)> {
+        let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
+        let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
+
+        let n = cvt(unsafe {
+            libc::recvfrom(self.0.raw(),
+                        buf.as_mut_ptr() as *mut c_void,
+                        buf.len(),
+                        flags,
+                        &mut storage as *mut _ as *mut _,
+                        &mut addrlen)
+        })?;
+        Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
+    }
+
+    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, 0)
+    }
+
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, MSG_PEEK)
     }
 
     pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs
index dc7b2fc9a6bab..9f03f5c9717fc 100644
--- a/src/libstd/sys/windows/c.rs
+++ b/src/libstd/sys/windows/c.rs
@@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
 pub const IP_DROP_MEMBERSHIP: c_int = 13;
 pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
 pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
+pub const MSG_PEEK: c_int = 0x2;
 
 #[repr(C)]
 pub struct ip_mreq {
diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs
index aca6994503ff8..adf6210d82e89 100644
--- a/src/libstd/sys/windows/net.rs
+++ b/src/libstd/sys/windows/net.rs
@@ -147,12 +147,12 @@ impl Socket {
         Ok(socket)
     }
 
-    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
+    fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
         // On unix when a socket is shut down all further reads return 0, so we
         // do the same on windows to map a shut down socket to returning EOF.
         let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
         unsafe {
-            match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) {
+            match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
                 -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
                 -1 => Err(last_error()),
                 n => Ok(n as usize)
@@ -160,6 +160,46 @@ impl Socket {
         }
     }
 
+    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, 0)
+    }
+
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, c::MSG_PEEK)
+    }
+
+    fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
+                            -> io::Result<(usize, SocketAddr)> {
+        let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
+        let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
+        let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
+
+        // On unix when a socket is shut down all further reads return 0, so we
+        // do the same on windows to map a shut down socket to returning EOF.
+        unsafe {
+            match c::recvfrom(self.0,
+                              buf.as_mut_ptr() as *mut c_void,
+                              len,
+                              flags,
+                              &mut storage as *mut _ as *mut _,
+                              &mut addrlen) {
+                -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
+                    Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
+                },
+                -1 => Err(last_error()),
+                n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
+            }
+        }
+    }
+
+    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, 0)
+    }
+
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, c::MSG_PEEK)
+    }
+
     pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
         let mut me = self;
         (&mut me).read_to_end(buf)
diff --git a/src/libstd/sys_common/net.rs b/src/libstd/sys_common/net.rs
index 10ad61f4c800c..3cdeb51194575 100644
--- a/src/libstd/sys_common/net.rs
+++ b/src/libstd/sys_common/net.rs
@@ -91,7 +91,7 @@ fn sockname<F>(f: F) -> io::Result<SocketAddr>
     }
 }
 
-fn sockaddr_to_addr(storage: &c::sockaddr_storage,
+pub fn sockaddr_to_addr(storage: &c::sockaddr_storage,
                     len: usize) -> io::Result<SocketAddr> {
     match storage.ss_family as c_int {
         c::AF_INET => {
@@ -222,6 +222,10 @@ impl TcpStream {
         self.inner.timeout(c::SO_SNDTIMEO)
     }
 
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.inner.peek(buf)
+    }
+
     pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
         self.inner.read(buf)
     }
@@ -441,17 +445,11 @@ impl UdpSocket {
     }
 
     pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
-        let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() };
-        let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
-        let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
+        self.inner.recv_from(buf)
+    }
 
-        let n = cvt(unsafe {
-            c::recvfrom(*self.inner.as_inner(),
-                        buf.as_mut_ptr() as *mut c_void,
-                        len, 0,
-                        &mut storage as *mut _ as *mut _, &mut addrlen)
-        })?;
-        Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.inner.peek_from(buf)
     }
 
     pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
@@ -578,6 +576,10 @@ impl UdpSocket {
         self.inner.read(buf)
     }
 
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.inner.peek(buf)
+    }
+
     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
         let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
         let ret = cvt(unsafe {