Skip to content

Refactor low-level column access #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@ macro_rules! panic_on_tskit_error {
};
}

macro_rules! unsafe_tsk_column_access {
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
let x = $crate::tsk_id_t::from($i);
if x < $lo || (x as $crate::tsk_size_t) >= $hi {
None
} else {
debug_assert!(!($owner).$array.is_null());
if !$owner.$array.is_null() {
// SAFETY: array is not null
// and we did our best effort
// on bounds checking
Some(unsafe { *$owner.$array.offset(x as isize) })
} else {
None
}
}
}};
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $output_id_type: ty) => {{
let x = $crate::tsk_id_t::from($i);
if x < $lo || (x as $crate::tsk_size_t) >= $hi {
None
} else {
debug_assert!(!($owner).$array.is_null());
if !$owner.$array.is_null() {
// SAFETY: array is not null
// and we did our best effort
// on bounds checking
unsafe { Some(<$output_id_type>::from(*($owner.$array.offset(x as isize)))) }
} else {
None
}
}
}};
}

macro_rules! unsafe_tsk_column_access_and_map_into {
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
unsafe_tsk_column_access!($i, $lo, $hi, $owner, $array).map(|v| v.into())
}};
}

macro_rules! unsafe_tsk_ragged_column_access {
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $offset_array: ident, $offset_array_len: ident, $output_id_type: ty) => {{
let i = $crate::SizeType::try_from($i).ok()?;
Expand Down
27 changes: 9 additions & 18 deletions src/edge_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::Position;
use crate::{tsk_id_t, TskitError};
use crate::{EdgeId, NodeId};
Expand Down Expand Up @@ -180,14 +181,7 @@ impl EdgeTable {
/// * `Some(parent)` if `u` is valid.
/// * `None` otherwise.
pub fn parent<E: Into<EdgeId> + Copy>(&self, row: E) -> Option<NodeId> {
unsafe_tsk_column_access!(
row.into(),
0,
self.num_rows(),
self.as_ref(),
parent,
NodeId
)
sys::tsk_column_access::<NodeId, _, _, _>(row.into(), self.as_ref().parent, self.num_rows())
}

/// Return the ``child`` value from row ``row`` of the table.
Expand All @@ -197,7 +191,7 @@ impl EdgeTable {
/// * `Some(child)` if `u` is valid.
/// * `None` otherwise.
pub fn child<E: Into<EdgeId> + Copy>(&self, row: E) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), child, NodeId)
sys::tsk_column_access::<NodeId, _, _, _>(row.into(), self.as_ref().child, self.num_rows())
}

/// Return the ``left`` value from row ``row`` of the table.
Expand All @@ -207,14 +201,7 @@ impl EdgeTable {
/// * `Some(position)` if `u` is valid.
/// * `None` otherwise.
pub fn left<E: Into<EdgeId> + Copy>(&self, row: E) -> Option<Position> {
unsafe_tsk_column_access!(
row.into(),
0,
self.num_rows(),
self.as_ref(),
left,
Position
)
sys::tsk_column_access::<Position, _, _, _>(row.into(), self.as_ref().left, self.num_rows())
}

/// Return the ``right`` value from row ``row`` of the table.
Expand All @@ -224,7 +211,11 @@ impl EdgeTable {
/// * `Some(position)` if `u` is valid.
/// * `None` otherwise.
pub fn right<E: Into<EdgeId> + Copy>(&self, row: E) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), right)
sys::tsk_column_access::<Position, _, _, _>(
row.into(),
self.as_ref().right,
self.num_rows(),
)
}

/// Retrieve decoded metadata for a `row`.
Expand Down
7 changes: 6 additions & 1 deletion src/individual_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::IndividualFlags;
use crate::IndividualId;
use crate::Location;
Expand Down Expand Up @@ -171,7 +172,11 @@ impl IndividualTable {
/// * `Some(flags)` if `row` is valid.
/// * `None` otherwise.
pub fn flags<I: Into<IndividualId> + Copy>(&self, row: I) -> Option<IndividualFlags> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), flags)
sys::tsk_column_access::<IndividualFlags, _, _, _>(
row.into(),
self.as_ref().flags,
self.num_rows(),
)
}

/// Return the locations for a given row.
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ mod node_table;
mod population_table;
pub mod prelude;
mod site_table;
mod sys;
mod table_collection;
mod table_iterator;
pub mod table_views;
Expand Down
27 changes: 13 additions & 14 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::Position;
use crate::SizeType;
use crate::Time;
Expand Down Expand Up @@ -199,7 +200,7 @@ impl MigrationTable {
/// * `Some(position)` if `row` is valid.
/// * `None` otherwise.
pub fn left<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), left)
sys::tsk_column_access::<Position, _, _, _>(row.into(), self.as_ref().left, self.num_rows())
}

/// Return the right coordinate for a given row.
Expand All @@ -209,7 +210,11 @@ impl MigrationTable {
/// * `Some(positions)` if `row` is valid.
/// * `None` otherwise.
pub fn right<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), right)
sys::tsk_column_access::<Position, _, _, _>(
row.into(),
self.as_ref().right,
self.num_rows(),
)
}

/// Return the node for a given row.
Expand All @@ -219,7 +224,7 @@ impl MigrationTable {
/// * `Some(node)` if `row` is valid.
/// * `None` otherwise.
pub fn node<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), node, NodeId)
sys::tsk_column_access::<NodeId, _, _, _>(row.into(), self.as_ref().node, self.num_rows())
}

/// Return the source population for a given row.
Expand All @@ -229,13 +234,10 @@ impl MigrationTable {
/// * `Some(population)` if `row` is valid.
/// * `None` otherwise.
pub fn source<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<PopulationId> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<PopulationId, _, _, _>(
row.into(),
0,
self.as_ref().source,
self.num_rows(),
self.as_ref(),
source,
PopulationId
)
}

Expand All @@ -246,13 +248,10 @@ impl MigrationTable {
/// * `Some(population)` if `row` is valid.
/// * `None` otherwise.
pub fn dest<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<PopulationId> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<PopulationId, _, _, _>(
row.into(),
0,
self.as_ref().dest,
self.num_rows(),
self.as_ref(),
dest,
PopulationId
)
}

Expand All @@ -263,7 +262,7 @@ impl MigrationTable {
/// * `Some(time)` if `row` is valid.
/// * `None` otherwise.
pub fn time<M: Into<MigrationId> + Copy>(&self, row: M) -> Option<Time> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), time)
sys::tsk_column_access::<Time, _, _, _>(row.into(), self.as_ref().time, self.num_rows())
}

/// Retrieve decoded metadata for a `row`.
Expand Down
14 changes: 6 additions & 8 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::SizeType;
use crate::Time;
use crate::{tsk_id_t, TskitError};
Expand Down Expand Up @@ -196,7 +197,7 @@ impl MutationTable {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn site<M: Into<MutationId> + Copy>(&self, row: M) -> Option<SiteId> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), site, SiteId)
sys::tsk_column_access::<SiteId, _, _, _>(row.into(), self.as_ref().site, self.num_rows())
}

/// Return the ``node`` value from row ``row`` of the table.
Expand All @@ -206,7 +207,7 @@ impl MutationTable {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn node<M: Into<MutationId> + Copy>(&self, row: M) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), node, NodeId)
sys::tsk_column_access::<NodeId, _, _, _>(row.into(), self.as_ref().node, self.num_rows())
}

/// Return the ``parent`` value from row ``row`` of the table.
Expand All @@ -216,13 +217,10 @@ impl MutationTable {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn parent<M: Into<MutationId> + Copy>(&self, row: M) -> Option<MutationId> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<MutationId, _, _, _>(
row.into(),
0,
self.as_ref().parent,
self.num_rows(),
self.as_ref(),
parent,
MutationId
)
}

Expand All @@ -233,7 +231,7 @@ impl MutationTable {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn time<M: Into<MutationId> + Copy>(&self, row: M) -> Option<Time> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), time, Time)
sys::tsk_column_access::<Time, _, _, _>(row.into(), self.as_ref().time, self.num_rows())
}

/// Get the ``derived_state`` value from row ``row`` of the table.
Expand Down
23 changes: 11 additions & 12 deletions src/node_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::NodeFlags;
use crate::SizeType;
use crate::Time;
Expand Down Expand Up @@ -195,7 +196,7 @@ impl NodeTable {
/// # }
/// ```
pub fn time<N: Into<NodeId> + Copy>(&self, row: N) -> Option<Time> {
unsafe_tsk_column_access!(row.into(), 0, self.num_rows(), self.as_ref(), time, Time)
sys::tsk_column_access::<Time, _, _, _>(row.into(), self.as_ref().time, self.num_rows())
}

/// Return the ``flags`` value from row ``row`` of the table.
Expand All @@ -220,7 +221,11 @@ impl NodeTable {
/// # }
/// ```
pub fn flags<N: Into<NodeId> + Copy>(&self, row: N) -> Option<NodeFlags> {
unsafe_tsk_column_access_and_map_into!(row.into(), 0, self.num_rows(), self.as_ref(), flags)
sys::tsk_column_access::<NodeFlags, _, _, _>(
row.into(),
self.as_ref().flags,
self.num_rows(),
)
}

#[deprecated(since = "0.12.0", note = "use flags_slice_mut instead")]
Expand Down Expand Up @@ -265,13 +270,10 @@ impl NodeTable {
/// * `Some(population)` if `row` is valid.
/// * `None` otherwise.
pub fn population<N: Into<NodeId> + Copy>(&self, row: N) -> Option<PopulationId> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<PopulationId, _, _, _>(
row.into(),
0,
self.as_ref().population,
self.num_rows(),
self.as_ref(),
population,
PopulationId
)
}

Expand Down Expand Up @@ -311,13 +313,10 @@ impl NodeTable {
/// * `Some(individual)` if `row` is valid.
/// * `None` otherwise.
pub fn individual<N: Into<NodeId> + Copy>(&self, row: N) -> Option<IndividualId> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<IndividualId, _, _, _>(
row.into(),
0,
self.as_ref().individual,
self.num_rows(),
self.as_ref(),
individual,
IndividualId
)
}

Expand Down
8 changes: 3 additions & 5 deletions src/site_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ptr::NonNull;

use crate::bindings as ll_bindings;
use crate::metadata;
use crate::sys;
use crate::tsk_id_t;
use crate::Position;
use crate::SiteId;
Expand Down Expand Up @@ -166,13 +167,10 @@ impl SiteTable {
/// * `Some(position)` if `row` is valid.
/// * `None` otherwise.
pub fn position<S: Into<SiteId> + Copy>(&self, row: S) -> Option<Position> {
unsafe_tsk_column_access!(
sys::tsk_column_access::<Position, _, _, _>(
row.into(),
0,
self.as_ref().position,
self.num_rows(),
self.as_ref(),
position,
Position
)
}

Expand Down
32 changes: 32 additions & 0 deletions src/sys.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use crate::bindings;

fn tsk_column_access_detail<R: Into<bindings::tsk_id_t>, L: Into<bindings::tsk_size_t>, T: Copy>(
row: R,
column: *const T,
column_length: L,
) -> Option<T> {
let row = row.into();
let column_length = column_length.into();
if row < 0 || (row as crate::tsk_size_t) >= column_length {
None
} else {
assert!(!column.is_null());
// SAFETY: pointer is not null.
// column_length is assumed to come directly
// from a table.
Some(unsafe { *column.offset(row as isize) })
}
}

pub fn tsk_column_access<
O: From<T>,
R: Into<bindings::tsk_id_t>,
L: Into<bindings::tsk_size_t>,
T: Copy,
>(
row: R,
column: *const T,
column_length: L,
) -> Option<O> {
tsk_column_access_detail(row, column, column_length).map(|v| v.into())
}
Loading