Skip to content

Commit 65a24a8

Browse files
committed
Make download_to_path_with_backend async
1 parent c11f252 commit 65a24a8

File tree

4 files changed

+105
-96
lines changed

4 files changed

+105
-96
lines changed

download/src/lib.rs

+65-86
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use std::path::Path;
55

66
use anyhow::Context;
77
pub use anyhow::Result;
8-
use tokio::{runtime::Handle, task};
98
use url::Url;
109

1110
mod errors;
@@ -50,7 +49,7 @@ async fn download_with_backend(
5049

5150
type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>;
5251

53-
pub fn download_to_path_with_backend(
52+
pub async fn download_to_path_with_backend(
5453
backend: Backend,
5554
url: &Url,
5655
path: &Path,
@@ -62,100 +61,80 @@ pub fn download_to_path_with_backend(
6261
use std::fs::OpenOptions;
6362
use std::io::{Read, Seek, SeekFrom, Write};
6463

65-
|| -> Result<()> {
66-
let (file, resume_from) = if resume_from_partial {
67-
let possible_partial = OpenOptions::new().read(true).open(path);
68-
69-
let downloaded_so_far = if let Ok(mut partial) = possible_partial {
70-
if let Some(cb) = callback {
71-
cb(Event::ResumingPartialDownload)?;
72-
73-
let mut buf = vec![0; 32768];
74-
let mut downloaded_so_far = 0;
75-
loop {
76-
let n = partial.read(&mut buf)?;
77-
downloaded_so_far += n as u64;
78-
if n == 0 {
79-
break;
64+
(|| {
65+
async move {
66+
let (file, resume_from) = if resume_from_partial {
67+
// TODO: blocking call
68+
let possible_partial = OpenOptions::new().read(true).open(path);
69+
70+
let downloaded_so_far = if let Ok(mut partial) = possible_partial {
71+
if let Some(cb) = callback {
72+
cb(Event::ResumingPartialDownload)?;
73+
74+
let mut buf = vec![0; 32768];
75+
let mut downloaded_so_far = 0;
76+
loop {
77+
let n = partial.read(&mut buf)?;
78+
downloaded_so_far += n as u64;
79+
if n == 0 {
80+
break;
81+
}
82+
cb(Event::DownloadDataReceived(&buf[..n]))?;
8083
}
81-
cb(Event::DownloadDataReceived(&buf[..n]))?;
82-
}
8384

84-
downloaded_so_far
85+
downloaded_so_far
86+
} else {
87+
let file_info = partial.metadata()?;
88+
file_info.len()
89+
}
8590
} else {
86-
let file_info = partial.metadata()?;
87-
file_info.len()
88-
}
89-
} else {
90-
0
91-
};
92-
93-
let mut possible_partial = OpenOptions::new()
94-
.write(true)
95-
.create(true)
96-
.open(path)
97-
.context("error opening file for download")?;
91+
0
92+
};
9893

99-
possible_partial.seek(SeekFrom::End(0))?;
100-
101-
(possible_partial, downloaded_so_far)
102-
} else {
103-
(
104-
OpenOptions::new()
94+
let mut possible_partial = OpenOptions::new()
10595
.write(true)
10696
.create(true)
10797
.open(path)
108-
.context("error creating file for download")?,
109-
0,
110-
)
111-
};
98+
.context("error opening file for download")?;
11299

113-
let file = RefCell::new(file);
114-
115-
match Handle::try_current() {
116-
Ok(current) => {
117-
// hide the asyncness for now.
118-
task::block_in_place(|| {
119-
current.block_on(download_with_backend(backend, url, resume_from, &|event| {
120-
if let Event::DownloadDataReceived(data) = event {
121-
file.borrow_mut()
122-
.write_all(data)
123-
.context("unable to write download to disk")?;
124-
}
125-
match callback {
126-
Some(cb) => cb(event),
127-
None => Ok(()),
128-
}
129-
}))
130-
})
131-
}
132-
Err(_) => {
133-
// Make a runtime to hide the asyncness.
134-
tokio::runtime::Runtime::new()?.block_on(download_with_backend(
135-
backend,
136-
url,
137-
resume_from,
138-
&|event| {
139-
if let Event::DownloadDataReceived(data) = event {
140-
file.borrow_mut()
141-
.write_all(data)
142-
.context("unable to write download to disk")?;
143-
}
144-
match callback {
145-
Some(cb) => cb(event),
146-
None => Ok(()),
147-
}
148-
},
149-
))
150-
}
151-
}?;
100+
possible_partial.seek(SeekFrom::End(0))?;
152101

153-
file.borrow_mut()
154-
.sync_data()
155-
.context("unable to sync download to disk")?;
102+
(possible_partial, downloaded_so_far)
103+
} else {
104+
(
105+
OpenOptions::new()
106+
.write(true)
107+
.create(true)
108+
.open(path)
109+
.context("error creating file for download")?,
110+
0,
111+
)
112+
};
156113

157-
Ok(())
158-
}()
114+
let file = RefCell::new(file);
115+
116+
// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
117+
download_with_backend(backend, url, resume_from, &|event| {
118+
if let Event::DownloadDataReceived(data) = event {
119+
file.borrow_mut()
120+
.write_all(data)
121+
.context("unable to write download to disk")?;
122+
}
123+
match callback {
124+
Some(cb) => cb(event),
125+
None => Ok(()),
126+
}
127+
})
128+
.await?;
129+
130+
file.borrow_mut()
131+
.sync_data()
132+
.context("unable to sync download to disk")?;
133+
134+
Ok::<(), anyhow::Error>(())
135+
}
136+
})()
137+
.await
159138
.map_err(|e| {
160139
// TODO: We currently clear up the cached download on any error, should we restrict it to a subset?
161140
if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") {

download/tests/download-curl-resume.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use download::*;
1010
mod support;
1111
use crate::support::{serve_file, tmp_dir, write_file};
1212

13-
#[test]
14-
fn partially_downloaded_file_gets_resumed_from_byte_offset() {
13+
#[tokio::test]
14+
async fn partially_downloaded_file_gets_resumed_from_byte_offset() {
1515
let tmpdir = tmp_dir();
1616
let from_path = tmpdir.path().join("download-source");
1717
write_file(&from_path, "xxx45");
@@ -21,13 +21,14 @@ fn partially_downloaded_file_gets_resumed_from_byte_offset() {
2121

2222
let from_url = Url::from_file_path(&from_path).unwrap();
2323
download_to_path_with_backend(Backend::Curl, &from_url, &target_path, true, None)
24+
.await
2425
.expect("Test download failed");
2526

2627
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
2728
}
2829

29-
#[test]
30-
fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
30+
#[tokio::test]
31+
async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
3132
let tmpdir = tmp_dir();
3233
let target_path = tmpdir.path().join("downloaded");
3334
write_file(&target_path, "123");
@@ -66,6 +67,7 @@ fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
6667
Ok(())
6768
}),
6869
)
70+
.await
6971
.expect("Test download failed");
7072

7173
assert!(callback_partial.into_inner());

download/tests/download-reqwest-resume.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use download::*;
1010
mod support;
1111
use crate::support::{serve_file, tmp_dir, write_file};
1212

13-
#[test]
14-
fn resume_partial_from_file_url() {
13+
#[tokio::test]
14+
async fn resume_partial_from_file_url() {
1515
let tmpdir = tmp_dir();
1616
let from_path = tmpdir.path().join("download-source");
1717
write_file(&from_path, "xxx45");
@@ -27,13 +27,14 @@ fn resume_partial_from_file_url() {
2727
true,
2828
None,
2929
)
30+
.await
3031
.expect("Test download failed");
3132

3233
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
3334
}
3435

35-
#[test]
36-
fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
36+
#[tokio::test]
37+
async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
3738
let tmpdir = tmp_dir();
3839
let target_path = tmpdir.path().join("downloaded");
3940
write_file(&target_path, "123");
@@ -72,6 +73,7 @@ fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
7273
Ok(())
7374
}),
7475
)
76+
.await
7577
.expect("Test download failed");
7678

7779
assert!(callback_partial.into_inner());

src/utils/utils.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::env;
22
use std::fs::{self, File};
3+
use std::future::Future;
34
use std::io::{self, BufReader, Write};
45
use std::path::{Path, PathBuf};
56

@@ -8,6 +9,8 @@ use home::env as home;
89
use retry::delay::{jitter, Fibonacci};
910
use retry::{retry, OperationResult};
1011
use sha2::Sha256;
12+
use tokio::runtime::Handle;
13+
use tokio::task;
1114
use url::Url;
1215

1316
use crate::currentprocess::{cwdsource::CurrentDirSource, varsource::VarSource};
@@ -238,14 +241,37 @@ fn download_file_(
238241
(Backend::Reqwest(tls_backend), Notification::UsingReqwest)
239242
};
240243
notify_handler(notification);
241-
let res =
242-
download_to_path_with_backend(backend, url, path, resume_from_partial, Some(callback));
244+
let res = run_future(download_to_path_with_backend(
245+
backend,
246+
url,
247+
path,
248+
resume_from_partial,
249+
Some(callback),
250+
));
243251

244252
notify_handler(Notification::DownloadFinished);
245253

246254
res
247255
}
248256

257+
/// Temporary thunk to support asyncifying from underneath.
258+
pub(crate) fn run_future<F, R, E>(f: F) -> Result<R, E>
259+
where
260+
F: Future<Output = Result<R, E>>,
261+
E: std::convert::From<std::io::Error>,
262+
{
263+
match Handle::try_current() {
264+
Ok(current) => {
265+
// hide the asyncness for now.
266+
task::block_in_place(|| current.block_on(f))
267+
}
268+
Err(_) => {
269+
// Make a runtime to hide the asyncness.
270+
tokio::runtime::Runtime::new()?.block_on(f)
271+
}
272+
}
273+
}
274+
249275
pub(crate) fn parse_url(url: &str) -> Result<Url> {
250276
Url::parse(url).with_context(|| format!("failed to parse url: {url}"))
251277
}

0 commit comments

Comments
 (0)