@@ -38,11 +38,29 @@ struct DynViewCaster {
38
38
type_name : & ' static str ,
39
39
40
40
/// Type-erased `ViewCaster::<Db, DbView>::vtable_cast`.
41
- cast : ErasedDatabaseDownCaster ,
41
+ cast : ErasedDatabaseDownCasterSig ,
42
42
}
43
43
44
- type ErasedDatabaseDownCaster = unsafe fn ( & dyn Database ) -> * const ( ) ;
45
- pub type DatabaseDownCaster < DbView > = unsafe fn ( & dyn Database ) -> & DbView ;
44
+ type ErasedDatabaseDownCasterSig = unsafe fn ( & dyn Database ) -> * const ( ) ;
45
+ type DatabaseDownCasterSig < DbView > = unsafe fn ( & dyn Database ) -> & DbView ;
46
+
47
+ pub struct DatabaseDownCaster < DbView : ?Sized > ( TypeId , DatabaseDownCasterSig < DbView > ) ;
48
+
49
+ impl < DbView : ?Sized + Any > DatabaseDownCaster < DbView > {
50
+ pub fn downcast < ' db > ( & self , db : & ' db dyn Database ) -> & ' db DbView {
51
+ assert_eq ! (
52
+ self . 0 ,
53
+ db. type_id( ) ,
54
+ "Database type does not match the expected type for this `Views` instance"
55
+ ) ;
56
+ // SAFETY: We've asserted that the database is correct.
57
+ unsafe { ( self . 1 ) ( db) }
58
+ }
59
+
60
+ pub unsafe fn downcast_unchecked < ' db > ( & self , db : & ' db dyn Database ) -> & ' db DbView {
61
+ unsafe { ( self . 1 ) ( db) }
62
+ }
63
+ }
46
64
47
65
impl Views {
48
66
pub ( crate ) fn new < Db : Database > ( ) -> Self {
@@ -53,9 +71,10 @@ impl Views {
53
71
target_type_id : TypeId :: of :: < dyn Database > ( ) ,
54
72
type_name : std:: any:: type_name :: < dyn Database > ( ) ,
55
73
cast : unsafe {
56
- std:: mem:: transmute :: < DatabaseDownCaster < dyn Database > , ErasedDatabaseDownCaster > (
57
- |db| db,
58
- )
74
+ std:: mem:: transmute :: <
75
+ DatabaseDownCasterSig < dyn Database > ,
76
+ ErasedDatabaseDownCasterSig ,
77
+ > ( |db| db)
59
78
} ,
60
79
} ) ;
61
80
Self {
@@ -65,7 +84,7 @@ impl Views {
65
84
}
66
85
67
86
/// Add a new downcaster from `dyn Database` to `dyn DbView`.
68
- pub fn add < DbView : ?Sized + Any > ( & self , func : DatabaseDownCaster < DbView > ) {
87
+ pub fn add < DbView : ?Sized + Any > ( & self , func : DatabaseDownCasterSig < DbView > ) {
69
88
let target_type_id = TypeId :: of :: < DbView > ( ) ;
70
89
if self
71
90
. view_casters
@@ -78,7 +97,9 @@ impl Views {
78
97
target_type_id,
79
98
type_name : std:: any:: type_name :: < DbView > ( ) ,
80
99
cast : unsafe {
81
- std:: mem:: transmute :: < DatabaseDownCaster < DbView > , ErasedDatabaseDownCaster > ( func)
100
+ std:: mem:: transmute :: < DatabaseDownCasterSig < DbView > , ErasedDatabaseDownCasterSig > (
101
+ func,
102
+ )
82
103
} ,
83
104
} ) ;
84
105
}
@@ -92,11 +113,11 @@ impl Views {
92
113
let view_type_id = TypeId :: of :: < DbView > ( ) ;
93
114
for ( _idx, view) in self . view_casters . iter ( ) {
94
115
if view. target_type_id == view_type_id {
95
- return unsafe {
96
- std:: mem:: transmute :: < ErasedDatabaseDownCaster , DatabaseDownCaster < DbView > > (
116
+ return DatabaseDownCaster ( self . source_type_id , unsafe {
117
+ std:: mem:: transmute :: < ErasedDatabaseDownCasterSig , DatabaseDownCasterSig < DbView > > (
97
118
view. cast ,
98
119
)
99
- } ;
120
+ } ) ;
100
121
}
101
122
}
102
123
@@ -105,14 +126,6 @@ impl Views {
105
126
std:: any:: type_name:: <DbView >( ) ,
106
127
) ;
107
128
}
108
-
109
- pub fn assert_database ( & self , db : & dyn Database ) {
110
- assert_eq ! (
111
- self . source_type_id,
112
- db. type_id( ) ,
113
- "Database type does not match the expected type for this `Views` instance"
114
- ) ;
115
- }
116
129
}
117
130
118
131
impl std:: fmt:: Debug for Views {
0 commit comments