Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/rustac/rustac.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DuckdbClient:
Args:
extension_directory: A non-standard extension directory to use.
extensions: A list of extensions to LOAD on client initialization.
install_extensions: Whether to install the spatial and icu extensions on client initialization.
install_extensions: Whether to install the required extensions on client initialization.
use_hive_partitioning: Whether to use hive partitioning for geoparquet queries.
"""

Expand Down
14 changes: 9 additions & 5 deletions src/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use pyo3_arrow::PyTable;
use stac_duckdb::Client;
use std::{path::PathBuf, sync::Mutex};

const REQUIRED_EXTENSIONS: [&str; 3] = ["spatial", "icu", "parquet"];

#[pyclass(frozen)]
pub struct DuckdbClient(Mutex<Client>);

Expand All @@ -34,14 +36,16 @@ impl DuckdbClient {
)?;
}
if install_extensions {
connection.execute("INSTALL spatial", [])?;
connection.execute("INSTALL icu", [])?;
for extension in REQUIRED_EXTENSIONS {
connection.execute(&format!("INSTALL {extension}"), [])?;
}
}
for extension in extensions {
connection.execute(&format!("LOAD '{}'", extension), [])?;
connection.execute(&format!("LOAD '{extension}'"), [])?;
}
for extension in REQUIRED_EXTENSIONS {
connection.execute(&format!("LOAD {extension}"), [])?;
}
connection.execute("LOAD spatial", [])?;
connection.execute("LOAD icu", [])?;
let mut client = Client::from(connection);
client.use_hive_partitioning = use_hive_partitioning;
Ok(DuckdbClient(Mutex::new(client)))
Expand Down