@@ -484,8 +484,10 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
484
484
private protected override void SaveModel ( ModelSaveContext ctx ) => _parent . SaveModel ( ctx ) ;
485
485
486
486
protected override Delegate MakeGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , out Action disposer )
487
+ => throw new NotImplementedException ( "This should never be called!" ) ;
488
+
489
+ private Delegate CreateGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , OnnxRuntimeOutputCacher outputCacher )
487
490
{
488
- disposer = null ;
489
491
Host . AssertValue ( input ) ;
490
492
491
493
var activeOutputColNames = _parent . Outputs . Where ( ( x , i ) => activeOutput ( i ) ) . ToArray ( ) ;
@@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
495
497
var elemRawType = vectorType . ItemType . RawType ;
496
498
var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
497
499
if ( vectorType . ItemType is TextDataViewType )
498
- return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
500
+ return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
499
501
else
500
- return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
502
+ return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
501
503
}
502
504
else
503
505
{
504
506
var type = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . DataViewType . RawType ;
505
507
var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
506
- return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
508
+ return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
509
+ }
510
+ }
511
+
512
+ public override Delegate [ ] CreateGetters ( DataViewRow input , Func < int , bool > activeOutput , out Action disposer )
513
+ {
514
+ Contracts . Assert ( input . Schema == InputSchema ) ;
515
+
516
+ OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher ( ) ;
517
+
518
+ int n = OutputColumns . Value . Length ;
519
+ var result = new Delegate [ n ] ;
520
+ for ( int i = 0 ; i < n ; i ++ )
521
+ {
522
+ if ( ! activeOutput ( i ) )
523
+ continue ;
524
+ result [ i ] = CreateGetter ( input , i , activeOutput , outputCacher ) ;
507
525
}
526
+ disposer = ( ) =>
527
+ {
528
+ outputCacher . Dispose ( ) ;
529
+ } ;
530
+ return result ;
508
531
}
509
532
510
- private class OnnxRuntimeOutputCacher
533
+ private sealed class OnnxRuntimeOutputCacher : IDisposable
511
534
{
512
535
public long Position ;
513
- public Dictionary < string , NamedOnnxValue > Outputs ;
536
+ public Dictionary < string , DisposableNamedOnnxValue > Outputs ;
537
+ public IDisposableReadOnlyCollection < DisposableNamedOnnxValue > OutputOnnxValues ;
538
+
514
539
public OnnxRuntimeOutputCacher ( )
515
540
{
516
541
Position = - 1 ;
517
- Outputs = new Dictionary < string , NamedOnnxValue > ( ) ;
542
+ Outputs = new Dictionary < string , DisposableNamedOnnxValue > ( ) ;
543
+ }
544
+
545
+ private bool _isDisposed ;
546
+
547
+ public void Dispose ( )
548
+ {
549
+ if ( _isDisposed )
550
+ return ;
551
+ OutputOnnxValues ? . Dispose ( ) ;
552
+ _isDisposed = true ;
518
553
}
519
554
}
520
555
@@ -529,21 +564,22 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
529
564
inputNameOnnxValues . Add ( srcNamedOnnxValueGetters [ i ] . GetNamedOnnxValue ( ) ) ;
530
565
}
531
566
532
- var outputNamedOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
533
- Contracts . Assert ( outputNamedOnnxValues . Count > 0 ) ;
567
+ outputCache . OutputOnnxValues ? . Dispose ( ) ;
568
+ outputCache . OutputOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
569
+ Contracts . Assert ( outputCache . OutputOnnxValues . Count > 0 ) ;
534
570
535
- foreach ( var outputNameOnnxValue in outputNamedOnnxValues )
571
+ foreach ( var outputNameOnnxValue in outputCache . OutputOnnxValues )
536
572
{
537
573
outputCache . Outputs [ outputNameOnnxValue . Name ] = outputNameOnnxValue ;
538
574
}
539
575
outputCache . Position = position ;
540
576
}
541
577
}
542
578
543
- private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
579
+ private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
580
+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
544
581
{
545
582
Host . AssertValue ( input ) ;
546
- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
547
583
ValueGetter < VBuffer < T > > valueGetter = ( ref VBuffer < T > dst ) =>
548
584
{
549
585
UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
@@ -558,10 +594,11 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
558
594
return valueGetter ;
559
595
}
560
596
561
- private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
597
+ private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
598
+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
562
599
{
563
600
Host . AssertValue ( input ) ;
564
- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
601
+
565
602
ValueGetter < VBuffer < ReadOnlyMemory < char > > > valueGetter = ( ref VBuffer < ReadOnlyMemory < char > > dst ) =>
566
603
{
567
604
UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
@@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
580
617
return valueGetter ;
581
618
}
582
619
583
- private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
620
+ private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
621
+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
584
622
{
585
623
Host . AssertValue ( input ) ;
586
- var outputCache = new OnnxRuntimeOutputCacher ( ) ;
624
+
587
625
ValueGetter < T > valueGetter = ( ref T dst ) =>
588
626
{
589
- UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
590
- var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
627
+ UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
628
+ var namedOnnxValue = outputCacher . Outputs [ _parent . Outputs [ iinfo ] ] ;
591
629
var trueValue = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => value . AsDictionary < string , float > ( ) ) ;
592
630
var caster = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . Caster ;
593
631
dst = ( T ) caster ( namedOnnxValue ) ;
0 commit comments