19
19
using Tensorflow ;
20
20
using static Microsoft . ML . TensorFlow . TensorFlowUtils ;
21
21
using static Tensorflow . Binding ;
22
+ using Utils = Microsoft . ML . Internal . Utilities . Utils ;
22
23
23
24
[ assembly: LoadableClass ( DnnRetrainTransformer . Summary , typeof ( IDataTransform ) , typeof ( DnnRetrainTransformer ) ,
24
25
typeof ( DnnRetrainEstimator . Options ) , typeof ( SignatureDataTransform ) , DnnRetrainTransformer . UserName , DnnRetrainTransformer . ShortName ) ]
@@ -607,15 +608,15 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status
607
608
new ObjectDisposedException ( nameof ( graph ) ) ;
608
609
609
610
var cstatus = status == null ? new Status ( ) : status ;
610
- var n = c_api . TF_GraphGetTensorNumDims ( graph , output , cstatus ) ;
611
+ var n = c_api . TF_GraphGetTensorNumDims ( graph , output , cstatus . Handle ) ;
611
612
612
613
cstatus . Check ( ) ;
613
614
614
615
if ( n == - 1 )
615
616
return new TensorShape ( new int [ 0 ] ) ;
616
617
617
618
var dims = new long [ n ] ;
618
- c_api . TF_GraphGetTensorShape ( graph , output , dims , dims . Length , cstatus ) ;
619
+ c_api . TF_GraphGetTensorShape ( graph , output , dims , dims . Length , cstatus . Handle ) ;
619
620
cstatus . Check ( ) ;
620
621
return new TensorShape ( dims . Select ( x => ( int ) x ) . ToArray ( ) ) ;
621
622
}
@@ -1040,49 +1041,11 @@ public Tensor GetBufferedBatchTensor()
1040
1041
}
1041
1042
else
1042
1043
{
1043
- var tensor = CastDataAndReturnAsTensor ( _bufferedData ) ;
1044
+ var tensor = TensorFlowUtils . CastDataAndReturnAsTensor ( _bufferedData , _tfShape ) ;
1044
1045
_position = 0 ;
1045
1046
return tensor ;
1046
1047
}
1047
1048
}
1048
-
1049
- private Tensor CastDataAndReturnAsTensor ( T [ ] data )
1050
- {
1051
- if ( typeof ( T ) == typeof ( sbyte ) )
1052
- return new Tensor ( ( sbyte [ ] ) ( object ) data , _dims , TF_DataType . TF_INT8 ) ;
1053
- else if ( typeof ( T ) == typeof ( long ) )
1054
- return new Tensor ( ( long [ ] ) ( object ) data , _dims , TF_DataType . TF_INT64 ) ;
1055
- else if ( typeof ( T ) == typeof ( Int32 ) )
1056
- return new Tensor ( ( Int32 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT32 ) ;
1057
- else if ( typeof ( T ) == typeof ( Int16 ) )
1058
- return new Tensor ( ( Int16 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT16 ) ;
1059
- else if ( typeof ( T ) == typeof ( byte ) )
1060
- return new Tensor ( ( byte [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT8 ) ;
1061
- else if ( typeof ( T ) == typeof ( ulong ) )
1062
- return new Tensor ( ( ulong [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT64 ) ;
1063
- else if ( typeof ( T ) == typeof ( UInt32 ) )
1064
- return new Tensor ( ( UInt32 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT32 ) ;
1065
- else if ( typeof ( T ) == typeof ( UInt16 ) )
1066
- return new Tensor ( ( UInt16 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT16 ) ;
1067
- else if ( typeof ( T ) == typeof ( bool ) )
1068
- return new Tensor ( ( bool [ ] ) ( object ) data , _dims , TF_DataType . TF_BOOL ) ;
1069
- else if ( typeof ( T ) == typeof ( float ) )
1070
- return new Tensor ( ( float [ ] ) ( object ) data , _dims , TF_DataType . TF_FLOAT ) ;
1071
- else if ( typeof ( T ) == typeof ( float ) )
1072
- return new Tensor ( ( double [ ] ) ( object ) data , _dims , TF_DataType . TF_DOUBLE ) ;
1073
- else if ( typeof ( T ) == typeof ( ReadOnlyMemory < char > ) )
1074
- {
1075
- byte [ ] [ ] bytes = new byte [ _bufferedData . Length ] [ ] ;
1076
- for ( int i = 0 ; i < bytes . Length ; i ++ )
1077
- {
1078
- bytes [ i ] = Encoding . UTF8 . GetBytes ( ( ( ReadOnlyMemory < char > ) ( object ) data [ i ] ) . ToArray ( ) ) ;
1079
- }
1080
-
1081
- return new Tensor ( bytes , _tfShape . dims . Select ( x => ( long ) x ) . ToArray ( ) ) ;
1082
- }
1083
-
1084
- return new Tensor ( new NDArray ( data , _tfShape ) ) ;
1085
- }
1086
1049
}
1087
1050
1088
1051
private class TensorValueGetterVec < T > : ITensorValueGetter
@@ -1126,45 +1089,7 @@ public Tensor GetTensor()
1126
1089
// This is done to reduce memory allocation every time tensor is created.
1127
1090
_denseData = new T [ _vBuffer . Length ] ;
1128
1091
_vBuffer . CopyTo ( _denseData ) ;
1129
- return CastDataAndReturnAsTensor ( _denseData ) ;
1130
- }
1131
-
1132
- private Tensor CastDataAndReturnAsTensor ( T [ ] data )
1133
- {
1134
- if ( typeof ( T ) == typeof ( sbyte ) )
1135
- return new Tensor ( ( sbyte [ ] ) ( object ) data , _dims , TF_DataType . TF_INT8 ) ;
1136
- else if ( typeof ( T ) == typeof ( long ) )
1137
- return new Tensor ( ( long [ ] ) ( object ) data , _dims , TF_DataType . TF_INT64 ) ;
1138
- else if ( typeof ( T ) == typeof ( Int32 ) )
1139
- return new Tensor ( ( Int32 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT32 ) ;
1140
- else if ( typeof ( T ) == typeof ( Int16 ) )
1141
- return new Tensor ( ( Int16 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT16 ) ;
1142
- else if ( typeof ( T ) == typeof ( byte ) )
1143
- return new Tensor ( ( byte [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT8 ) ;
1144
- else if ( typeof ( T ) == typeof ( ulong ) )
1145
- return new Tensor ( ( ulong [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT64 ) ;
1146
- else if ( typeof ( T ) == typeof ( UInt32 ) )
1147
- return new Tensor ( ( UInt32 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT32 ) ;
1148
- else if ( typeof ( T ) == typeof ( UInt16 ) )
1149
- return new Tensor ( ( UInt16 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT16 ) ;
1150
- else if ( typeof ( T ) == typeof ( bool ) )
1151
- return new Tensor ( ( bool [ ] ) ( object ) data , _dims , TF_DataType . TF_BOOL ) ;
1152
- else if ( typeof ( T ) == typeof ( float ) )
1153
- return new Tensor ( ( float [ ] ) ( object ) data , _dims , TF_DataType . TF_FLOAT ) ;
1154
- else if ( typeof ( T ) == typeof ( double ) )
1155
- return new Tensor ( ( double [ ] ) ( object ) data , _dims , TF_DataType . TF_DOUBLE ) ;
1156
- else if ( typeof ( T ) == typeof ( ReadOnlyMemory < char > ) )
1157
- {
1158
- byte [ ] [ ] bytes = new byte [ _vBuffer . Length ] [ ] ;
1159
- for ( int i = 0 ; i < bytes . Length ; i ++ )
1160
- {
1161
- bytes [ i ] = Encoding . UTF8 . GetBytes ( ( ( ReadOnlyMemory < char > ) ( object ) data [ i ] ) . ToArray ( ) ) ;
1162
- }
1163
-
1164
- return new Tensor ( bytes , _tfShape . dims . Select ( x => ( long ) x ) . ToArray ( ) ) ;
1165
- }
1166
-
1167
- return new Tensor ( new NDArray ( data , _tfShape ) ) ;
1092
+ return TensorFlowUtils . CastDataAndReturnAsTensor ( _denseData , _tfShape ) ;
1168
1093
}
1169
1094
1170
1095
public void BufferTrainingData ( )
@@ -1177,7 +1102,7 @@ public void BufferTrainingData()
1177
1102
public Tensor GetBufferedBatchTensor ( )
1178
1103
{
1179
1104
_position = 0 ;
1180
- var tensor = CastDataAndReturnAsTensor ( _bufferedData ) ;
1105
+ var tensor = TensorFlowUtils . CastDataAndReturnAsTensor ( _bufferedData , _tfShape ) ;
1181
1106
_bufferedData = new T [ _bufferedDataSize ] ;
1182
1107
return tensor ;
1183
1108
}
0 commit comments