tokio/io/util/buf_reader.rs
1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3
4use pin_project_lite::pin_project;
5use std::io::{self, IoSlice, SeekFrom};
6use std::pin::Pin;
7use std::task::{ready, Context, Poll};
8use std::{cmp, fmt, mem};
9
pin_project! {
    /// The `BufReader` struct adds buffering to any reader.
    ///
    /// It can be excessively inefficient to work directly with an [`AsyncRead`]
    /// instance. A `BufReader` performs large, infrequent reads on the underlying
    /// [`AsyncRead`] and maintains an in-memory buffer of the results.
    ///
    /// `BufReader` can improve the speed of programs that make *small* and
    /// *repeated* read calls to the same file or network socket. It does not
    /// help when reading very large amounts at once, or reading just one or a few
    /// times. It also provides no advantage when reading from a source that is
    /// already in memory, like a `Vec<u8>`.
    ///
    /// When the `BufReader` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufReader` on the same
    /// stream can cause data loss.
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufReader<R> {
        // The wrapped reader. Marked `#[pin]` so that `project()` hands it out
        // as `Pin<&mut R>` and it can be polled in place.
        #[pin]
        pub(super) inner: R,
        // Backing storage for buffered bytes; its length is the buffer capacity
        // chosen in `with_capacity`.
        pub(super) buf: Box<[u8]>,
        // Index into `buf` of the next byte that has not been consumed yet.
        pub(super) pos: usize,
        // End (exclusive) of the valid data in `buf`; the bytes currently
        // buffered are `buf[pos..cap]`, and `pos == cap` means "empty".
        pub(super) cap: usize,
        // Progress of the seek requested through `start_seek` and driven by
        // `poll_complete`.
        pub(super) seek_state: SeekState,
    }
}
36
37impl<R: AsyncRead> BufReader<R> {
38    /// Creates a new `BufReader` with a default buffer capacity. The default is currently 8 KB,
39    /// but may change in the future.
40    pub fn new(inner: R) -> Self {
41        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
42    }
43
44    /// Creates a new `BufReader` with the specified buffer capacity.
45    pub fn with_capacity(capacity: usize, inner: R) -> Self {
46        let buffer = vec![0; capacity];
47        Self {
48            inner,
49            buf: buffer.into_boxed_slice(),
50            pos: 0,
51            cap: 0,
52            seek_state: SeekState::Init,
53        }
54    }
55
56    /// Gets a reference to the underlying reader.
57    ///
58    /// It is inadvisable to directly read from the underlying reader.
59    pub fn get_ref(&self) -> &R {
60        &self.inner
61    }
62
63    /// Gets a mutable reference to the underlying reader.
64    ///
65    /// It is inadvisable to directly read from the underlying reader.
66    pub fn get_mut(&mut self) -> &mut R {
67        &mut self.inner
68    }
69
70    /// Gets a pinned mutable reference to the underlying reader.
71    ///
72    /// It is inadvisable to directly read from the underlying reader.
73    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
74        self.project().inner
75    }
76
77    /// Consumes this `BufReader`, returning the underlying reader.
78    ///
79    /// Note that any leftover data in the internal buffer is lost.
80    pub fn into_inner(self) -> R {
81        self.inner
82    }
83
84    /// Returns a reference to the internally buffered data.
85    ///
86    /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is empty.
87    pub fn buffer(&self) -> &[u8] {
88        &self.buf[self.pos..self.cap]
89    }
90
91    /// Invalidates all data in the internal buffer.
92    #[inline]
93    fn discard_buffer(self: Pin<&mut Self>) {
94        let me = self.project();
95        *me.pos = 0;
96        *me.cap = 0;
97    }
98}
99
100impl<R: AsyncRead> AsyncRead for BufReader<R> {
101    fn poll_read(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &mut ReadBuf<'_>,
105    ) -> Poll<io::Result<()>> {
106        // If we don't have any buffered data and we're doing a massive read
107        // (larger than our internal buffer), bypass our internal buffer
108        // entirely.
109        if self.pos == self.cap && buf.remaining() >= self.buf.len() {
110            let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf));
111            self.discard_buffer();
112            return Poll::Ready(res);
113        }
114        let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
115        let amt = std::cmp::min(rem.len(), buf.remaining());
116        buf.put_slice(&rem[..amt]);
117        self.consume(amt);
118        Poll::Ready(Ok(()))
119    }
120}
121
122impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
123    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
124        let me = self.project();
125
126        // If we've reached the end of our internal buffer then we need to fetch
127        // some more data from the underlying reader.
128        // Branch using `>=` instead of the more correct `==`
129        // to tell the compiler that the pos..cap slice is always valid.
130        if *me.pos >= *me.cap {
131            debug_assert!(*me.pos == *me.cap);
132            let mut buf = ReadBuf::new(me.buf);
133            ready!(me.inner.poll_read(cx, &mut buf))?;
134            *me.cap = buf.filled().len();
135            *me.pos = 0;
136        }
137        Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
138    }
139
140    fn consume(self: Pin<&mut Self>, amt: usize) {
141        let me = self.project();
142        *me.pos = cmp::min(*me.pos + amt, *me.cap);
143    }
144}
145
/// Tracks how far a seek on a `BufReader` has progressed between
/// `start_seek` and the `poll_complete` calls that drive it.
#[derive(Debug, Clone, Copy)]
pub(super) enum SeekState {
    /// `start_seek` has not been called.
    Init,
    /// `start_seek` has been called, but `poll_complete` has not yet been called.
    /// Holds the requested target position.
    Start(SeekFrom),
    /// Waiting for completion of the first `poll_complete` in the `n.checked_sub(remainder).is_none()` branch.
    /// Holds the original relative offset `n`, which is still to be applied by
    /// a second seek once the first one finishes.
    PendingOverflowed(i64),
    /// Waiting for completion of `poll_complete`.
    Pending,
}
157
/// Seeks to an offset, in bytes, in the underlying reader.
///
/// The position used for seeking with `SeekFrom::Current(_)` is the
/// position the underlying reader would be at if the `BufReader` had no
/// internal buffer.
///
/// Seeking always discards the internal buffer, even if the seek position
/// would otherwise fall within it. This guarantees that calling
/// `.into_inner()` immediately after a seek yields the underlying reader
/// at the same position.
///
/// See [`AsyncSeek`] for more details.
///
/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)`
/// where `n` minus the internal buffer length overflows an `i64`, two
/// seeks will be performed instead of one. If the second seek returns
/// `Err`, the underlying reader will be left at the same position it would
/// have if you called `seek` with `SeekFrom::Current(0)`.
impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> {
    /// Records the requested seek target; no seek is issued on the underlying
    /// reader until `poll_complete` is called. Always returns `Ok(())`.
    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
        // We need to call the seek operation multiple times.
        // And we should always call both start_seek and poll_complete,
        // as start_seek alone cannot guarantee that the operation will be completed.
        // poll_complete receives a Context and returns a Poll, so it cannot be called
        // inside start_seek.
        *self.project().seek_state = SeekState::Start(pos);
        Ok(())
    }

    /// Drives the seek recorded by `start_seek` to completion on the
    /// underlying reader.
    ///
    /// On success the internal buffer is discarded and the position reported
    /// by the underlying reader is returned. When no seek was started this
    /// returns `Ok(0)` without touching the underlying reader.
    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        // Take the current state and leave `Init` behind: every early return
        // through `?` below therefore leaves the state machine reset, and the
        // pending paths explicitly store the state they need to resume from.
        let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) {
            SeekState::Init => {
                // 1.x AsyncSeek recommends calling poll_complete before start_seek.
                // We don't have to guarantee that the value returned by
                // poll_complete called without start_seek is correct,
                // so we'll return 0.
                return Poll::Ready(Ok(0));
            }
            SeekState::Start(SeekFrom::Current(n)) => {
                // Number of buffered-but-unconsumed bytes: the underlying
                // reader is this far ahead of the position the caller sees.
                let remainder = (self.cap - self.pos) as i64;
                // it should be safe to assume that remainder fits within an i64 as the alternative
                // means we managed to allocate 8 exbibytes and that's absurd.
                // But it's not out of the realm of possibility for some weird underlying reader to
                // support seeking by i64::MIN so we need to handle underflow when subtracting
                // remainder.
                if let Some(offset) = n.checked_sub(remainder) {
                    // Common case: a single seek, adjusted for the buffered bytes.
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(offset))?;
                } else {
                    // seek backwards by our remainder, and then by the offset
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(-remainder))?;
                    if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                        // Remember `n` so the second seek can be issued once
                        // the first one completes on a later poll.
                        *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                        return Poll::Pending;
                    }

                    // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                    self.as_mut().discard_buffer();

                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(n))?;
                }
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::PendingOverflowed(n) => {
                // Resuming the two-step seek: finish the first (backwards)
                // seek, then issue the second one by the original offset.
                if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                    *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                    return Poll::Pending;
                }

                // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                self.as_mut().discard_buffer();

                self.as_mut()
                    .get_pin_mut()
                    .start_seek(SeekFrom::Current(n))?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::Start(pos) => {
                // Seeking with Start/End doesn't care about our buffer length.
                self.as_mut().get_pin_mut().start_seek(pos)?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            // A seek was already issued on the underlying reader; keep polling it.
            SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?,
        };

        match res {
            Poll::Ready(res) => {
                // The seek finished: buffered data no longer matches the
                // underlying reader's position, so drop it.
                self.discard_buffer();
                Poll::Ready(Ok(res))
            }
            Poll::Pending => {
                // The final seek is in flight; resume by polling it next time.
                *self.as_mut().project().seek_state = SeekState::Pending;
                Poll::Pending
            }
        }
    }
}
260
261impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
262    fn poll_write(
263        self: Pin<&mut Self>,
264        cx: &mut Context<'_>,
265        buf: &[u8],
266    ) -> Poll<io::Result<usize>> {
267        self.get_pin_mut().poll_write(cx, buf)
268    }
269
270    fn poll_write_vectored(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273        bufs: &[IoSlice<'_>],
274    ) -> Poll<io::Result<usize>> {
275        self.get_pin_mut().poll_write_vectored(cx, bufs)
276    }
277
278    fn is_write_vectored(&self) -> bool {
279        self.get_ref().is_write_vectored()
280    }
281
282    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
283        self.get_pin_mut().poll_flush(cx)
284    }
285
286    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
287        self.get_pin_mut().poll_shutdown(cx)
288    }
289}
290
291impl<R: fmt::Debug> fmt::Debug for BufReader<R> {
292    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293        f.debug_struct("BufReader")
294            .field("reader", &self.inner)
295            .field(
296                "buffer",
297                &format_args!("{}/{}", self.cap - self.pos, self.buf.len()),
298            )
299            .finish()
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    // Checks `BufReader<()>` against `crate::is_unpin`, which is defined
    // elsewhere in the crate (presumably a compile-time `Unpin` bound check —
    // confirm against its definition). `()` can be used as the type argument
    // because the struct itself places no trait bound on `R`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufReader<()>>();
    }
}
311}