Skip to content

Commit 8f455b0

Browse files
committed
Embed typecheck into the downcaster
1 parent 111e3a4 commit 8f455b0

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

src/database.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ impl dyn Database {
123123
#[track_caller]
124124
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
125125
let views = self.zalsa().views();
126-
views.assert_database(self);
127-
// SAFETY: We've asserted that the database is correct.
128-
unsafe { views.downcaster_for()(self) }
126+
views.downcaster_for().downcast(self)
129127
}
130128
}

src/function.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ where
212212
input: Id,
213213
revision: Revision,
214214
) -> MaybeChangedAfter {
215-
db.zalsa().views().assert_database(db);
216-
// SAFETY: We've asserted that the database is correct.
217-
let db = unsafe { (self.view_caster)(db) };
215+
let db = self.view_caster.downcast(db);
218216
self.maybe_changed_after(db, input, revision)
219217
}
220218

@@ -274,9 +272,7 @@ where
274272
db: &'db dyn Database,
275273
key_index: Id,
276274
) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) {
277-
db.zalsa().views().assert_database(db);
278-
// SAFETY: We've asserted that the database is correct.
279-
let db = unsafe { (self.view_caster)(db) };
275+
let db = self.view_caster.downcast(db);
280276
self.accumulated_map(db, key_index)
281277
}
282278
}

src/views.rs

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,29 @@ struct DynViewCaster {
3838
type_name: &'static str,
3939

4040
/// Type-erased `ViewCaster::<Db, DbView>::vtable_cast`.
41-
cast: ErasedDatabaseDownCaster,
41+
cast: ErasedDatabaseDownCasterSig,
4242
}
4343

44-
type ErasedDatabaseDownCaster = unsafe fn(&dyn Database) -> *const ();
45-
pub type DatabaseDownCaster<DbView> = unsafe fn(&dyn Database) -> &DbView;
44+
type ErasedDatabaseDownCasterSig = unsafe fn(&dyn Database) -> *const ();
45+
type DatabaseDownCasterSig<DbView> = unsafe fn(&dyn Database) -> &DbView;
46+
47+
pub struct DatabaseDownCaster<DbView: ?Sized>(TypeId, DatabaseDownCasterSig<DbView>);
48+
49+
impl<DbView: ?Sized + Any> DatabaseDownCaster<DbView> {
50+
pub fn downcast<'db>(&self, db: &'db dyn Database) -> &'db DbView {
51+
assert_eq!(
52+
self.0,
53+
db.type_id(),
54+
"Database type does not match the expected type for this `Views` instance"
55+
);
56+
// SAFETY: We've asserted that the database is correct.
57+
unsafe { (self.1)(db) }
58+
}
59+
60+
pub unsafe fn downcast_unchecked<'db>(&self, db: &'db dyn Database) -> &'db DbView {
61+
unsafe { (self.1)(db) }
62+
}
63+
}
4664

4765
impl Views {
4866
pub(crate) fn new<Db: Database>() -> Self {
@@ -53,9 +71,10 @@ impl Views {
5371
target_type_id: TypeId::of::<dyn Database>(),
5472
type_name: std::any::type_name::<dyn Database>(),
5573
cast: unsafe {
56-
std::mem::transmute::<DatabaseDownCaster<dyn Database>, ErasedDatabaseDownCaster>(
57-
|db| db,
58-
)
74+
std::mem::transmute::<
75+
DatabaseDownCasterSig<dyn Database>,
76+
ErasedDatabaseDownCasterSig,
77+
>(|db| db)
5978
},
6079
});
6180
Self {
@@ -65,7 +84,7 @@ impl Views {
6584
}
6685

6786
/// Add a new downcaster from `dyn Database` to `dyn DbView`.
68-
pub fn add<DbView: ?Sized + Any>(&self, func: DatabaseDownCaster<DbView>) {
87+
pub fn add<DbView: ?Sized + Any>(&self, func: DatabaseDownCasterSig<DbView>) {
6988
let target_type_id = TypeId::of::<DbView>();
7089
if self
7190
.view_casters
@@ -78,7 +97,9 @@ impl Views {
7897
target_type_id,
7998
type_name: std::any::type_name::<DbView>(),
8099
cast: unsafe {
81-
std::mem::transmute::<DatabaseDownCaster<DbView>, ErasedDatabaseDownCaster>(func)
100+
std::mem::transmute::<DatabaseDownCasterSig<DbView>, ErasedDatabaseDownCasterSig>(
101+
func,
102+
)
82103
},
83104
});
84105
}
@@ -92,11 +113,11 @@ impl Views {
92113
let view_type_id = TypeId::of::<DbView>();
93114
for (_idx, view) in self.view_casters.iter() {
94115
if view.target_type_id == view_type_id {
95-
return unsafe {
96-
std::mem::transmute::<ErasedDatabaseDownCaster, DatabaseDownCaster<DbView>>(
116+
return DatabaseDownCaster(self.source_type_id, unsafe {
117+
std::mem::transmute::<ErasedDatabaseDownCasterSig, DatabaseDownCasterSig<DbView>>(
97118
view.cast,
98119
)
99-
};
120+
});
100121
}
101122
}
102123

@@ -105,14 +126,6 @@ impl Views {
105126
std::any::type_name::<DbView>(),
106127
);
107128
}
108-
109-
pub fn assert_database(&self, db: &dyn Database) {
110-
assert_eq!(
111-
self.source_type_id,
112-
db.type_id(),
113-
"Database type does not match the expected type for this `Views` instance"
114-
);
115-
}
116129
}
117130

118131
impl std::fmt::Debug for Views {

0 commit comments

Comments
 (0)