Skip to content

Commit 1e64f1e

Browse files
authored
fix: fail if an on_blob function does not read all the content (#139)
closes #127
1 parent 519f8c5 commit 1e64f1e

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

src/get.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use anyhow::{anyhow, bail, ensure, Result};
1313
use bytes::BytesMut;
1414
use futures::Future;
1515
use postcard::experimental::max_size::MaxSize;
16-
use tokio::io::{AsyncRead, ReadBuf};
16+
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
1717
use tracing::debug;
1818

1919
use crate::bao_slice_decoder::AsyncSliceDecoder;
@@ -207,7 +207,12 @@ where
207207
"downloaded more than {total_blobs_size}"
208208
);
209209
remaining_size -= size;
210-
let blob_reader = on_blob(blob.hash, blob_reader, blob.name).await?;
210+
let mut blob_reader =
211+
on_blob(blob.hash, blob_reader, blob.name).await?;
212+
213+
if blob_reader.read_exact(&mut [0u8; 1]).await.is_ok() {
214+
bail!("`on_blob` callback did not fully read the blob content")
215+
}
211216
reader = blob_reader.into_inner();
212217
}
213218
}

src/lib.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ mod tests {
2626
use rand::RngCore;
2727
use testdir::testdir;
2828
use tokio::fs;
29-
use tokio::io::{self, AsyncReadExt};
29+
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
3030

3131
use crate::protocol::AuthToken;
3232
use crate::provider::{create_collection, Event, Provider};
@@ -329,4 +329,53 @@ mod tests {
329329
// Unwrap the JoinHandle, then the result of the Provider
330330
supervisor.await.unwrap().unwrap();
331331
}
332+
333+
#[tokio::test]
334+
async fn test_blob_reader_partial() -> Result<()> {
335+
// Prepare a Provider transferring a file.
336+
let dir = testdir!();
337+
let src0 = dir.join("src0");
338+
let src1 = dir.join("src1");
339+
{
340+
let content = vec![1u8; 1000];
341+
let mut f = tokio::fs::File::create(&src0).await?;
342+
for _ in 0..10 {
343+
f.write_all(&content).await?;
344+
}
345+
}
346+
fs::write(&src1, "hello world").await?;
347+
let (db, hash) = create_collection(vec![src0.into(), src1.into()]).await?;
348+
let provider = Provider::builder(db)
349+
.bind_addr("127.0.0.1:0".parse().unwrap())
350+
.spawn()?;
351+
let auth_token = provider.auth_token();
352+
let provider_addr = provider.listen_addr();
353+
354+
let timeout = tokio::time::timeout(
355+
std::time::Duration::from_secs(10),
356+
get::run(
357+
hash,
358+
auth_token,
359+
get::Options {
360+
addr: provider_addr,
361+
peer_id: None,
362+
},
363+
|| async move { Ok(()) },
364+
|_collection| async move { Ok(()) },
365+
|_hash, stream, _name| async move {
366+
// evil: do nothing with the stream!
367+
Ok(stream)
368+
},
369+
),
370+
)
371+
.await;
372+
provider.shutdown();
373+
374+
let err = timeout.expect(
375+
"`get` function is hanging, make sure we are handling misbehaving `on_blob` functions",
376+
);
377+
378+
err.expect_err("expected an error when passing in a misbehaving `on_blob` function");
379+
Ok(())
380+
}
332381
}

0 commit comments

Comments
 (0)