diff --git a/build/Dependencies.props b/build/Dependencies.props index 197fa167b2..d2a48f6865 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -8,7 +8,6 @@ <SystemMemoryVersion>4.5.1</SystemMemoryVersion> <SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion> <SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion> - <SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion> </PropertyGroup> <!-- Other/Non-Core Product Dependencies --> diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index da770fd20b..f205684d3a 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function - It should not have side effects (we may call it arbitrarily at any time, or omit the call) One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it. -At loading time, you will need to reconstruct the custom transformer and inject it into MLContext. +At loading time, you will need to register the custom transformer with the MLContext. Here is a complete example that saves and loads a model with a custom mapping. ```csharp /// <summary> -/// One class that contains all custom mappings that we need for our model. +/// One class that contains the custom mapping functionality that we need for our model. +/// +/// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and +/// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>. /// </summary> -public class CustomMappings +[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))] +public class CustomMappings : CustomMappingFactory<InputRow, OutputRow> { // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading. public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000; - // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate - // this property. - [Import] - public MLContext MLContext { get; set; } - - // We are exporting the custom transformer by the name 'IncomeMapping'. - [Export(nameof(IncomeMapping))] - public ITransformer MyCustomTransformer - => MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping)); + // This factory method will be called when loading the model to get the mapping operation. + public override Action<InputRow, OutputRow> GetMapping() + { + return IncomeMapping; + } } ``` @@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath)) // Now pretend we are in a different process. -// Create a custom composition container for all our custom mapping actions. -newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings))); +// Register the assembly that contains 'CustomMappings' with the ComponentCatalog +// so it can be found when loading the model. +newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Now we can load the model. ITransformer loadedModel; diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj index 9e92abc840..1443a9f6b0 100644 --- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj +++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj @@ -15,7 +15,6 @@ <PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" /> <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" /> <PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" /> - <PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" /> </ItemGroup> <ItemGroup> diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 1dd10e9f39..0c28b6c099 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -35,6 +35,8 @@ internal ComponentCatalog() _entryPointMap = new Dictionary<string, EntryPointInfo>(); _componentMap = new Dictionary<string, ComponentInfo>(); _components = new List<ComponentInfo>(); + + _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>(); } /// <summary> @@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo private readonly List<ComponentInfo> _components; private readonly Dictionary<string, ComponentInfo> _componentMap; + private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap; + private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes, out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment) { @@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true) AddClass(info, attr.LoadNames, throwOnError); } + + LoadExtensions(assembly, throwOnError); } } } @@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set if (errorMsg != null) throw Contracts.Except(errorMsg); } + + private void LoadExtensions(Assembly assembly, bool throwOnError) + { + // don't waste time looking through all the types of an assembly + // that can't contain extensions + if (CanContainExtensions(assembly)) + { + foreach (Type type in assembly.GetTypes()) + { + if (type.IsClass) + { + foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute))) + { + var key = (AttributeType: attribute.GetType(), attribute.ContractName); + if (_extensionsMap.TryGetValue(key, out var existingType)) + { + if (throwOnError) + { + throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog."); + } + } + else + { + _extensionsMap.Add(key, type); + } + } + } + } + } + } + + /// <summary> + /// Gets a value indicating whether <paramref name="assembly"/> can contain extensions. + /// </summary> + /// <remarks> + /// All ML.NET product assemblies won't contain extensions. + /// </remarks> + private static bool CanContainExtensions(Assembly assembly) + { + if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal) + && HasMLNetPublicKey(assembly)) + { + return false; + } + + return true; + } + + private static bool HasMLNetPublicKey(Assembly assembly) + { + return assembly.GetName().GetPublicKey().SequenceEqual( + typeof(ComponentCatalog).Assembly.GetName().GetPublicKey()); + } + + [BestFriend] + internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName) + { + object exportedValue = null; + if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType)) + { + exportedValue = Activator.CreateInstance(extensionType); + } + + if (exportedValue == null) + { + throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'."); + } + + return exportedValue; + } } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs new file mode 100644 index 0000000000..7d8d00a252 --- /dev/null +++ b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML +{ + /// <summary> + /// The base attribute type for all attributes used for extensibility purposes. + /// </summary> + [AttributeUsage(AttributeTargets.Class)] + public abstract class ExtensionBaseAttribute : Attribute + { + public string ContractName { get; } + + [BestFriend] + private protected ExtensionBaseAttribute(string contractName) + { + ContractName = contractName; + } + } +} diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index 0f095d02c5..78cd697811 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition.Hosting; namespace Microsoft.ML { @@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider [Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " + "Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")] IFileHandle CreateTempFile(string suffix = null, string prefix = null); - - /// <summary> - /// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model - /// is being loaded, or the components are otherwise created via dependency injection. - /// </summary> - CompositionContainer GetCompositionContainer(); } /// <summary> diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 89ce4503c4..1127e9b115 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.ComponentModel.Composition.Hosting; using System.IO; namespace Microsoft.ML.Data @@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo else if (!removeLastNewLine) writer.WriteLine(); } - - public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer(); } } diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj index 0d6b288499..ccd18a42b7 100644 --- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj +++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj @@ -12,7 +12,6 @@ <ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" /> <PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" /> - <PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" /> <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" /> </ItemGroup> diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index a466e37aa3..39281b36d6 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using Microsoft.ML.Data; namespace Microsoft.ML @@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment public event EventHandler<LoggingEventArgs> Log; /// <summary> - /// This is a MEF composition container catalog to be used for model loading. + /// This is a catalog of components that will be used for model loading. /// </summary> - public CompositionContainer CompositionContainer { get; set; } + public ComponentCatalog ComponentCatalog => _env.ComponentCatalog; /// <summary> /// Create the ML context. @@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param> public MLContext(int? seed = null, int conc = 0) { - _env = new LocalEnvironment(seed, conc, MakeCompositionContainer); + _env = new LocalEnvironment(seed, conc); _env.AddListener(ProcessMessage); BinaryClassification = new BinaryClassificationCatalog(_env); @@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0) Data = new DataOperationsCatalog(_env); } - private CompositionContainer MakeCompositionContainer() - { - if (CompositionContainer == null) - return null; - - var mlContext = CompositionContainer.GetExportedValueOrDefault<MLContext>(); - if (mlContext == null) - CompositionContainer.ComposeExportedValue<MLContext>(this); - - return CompositionContainer; - } - private void ProcessMessage(IMessageSource source, ChannelMessage message) { var log = Log; @@ -120,7 +106,6 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor; bool IHostEnvironment.IsCancelled => _env.IsCancelled; - ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog; string IExceptionContext.ContextDescription => _env.ContextDescription; IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix); TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex); @@ -128,6 +113,5 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) IChannel IChannelProvider.Start(string name) => _env.Start(name); IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name); IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); - CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer(); } } diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index a9e3a4a80e..b17b2cf39f 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition.Hosting; namespace Microsoft.ML.Data { @@ -14,8 +13,6 @@ namespace Microsoft.ML.Data /// </summary> internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment> { - private readonly Func<CompositionContainer> _compositionContainerFactory; - private sealed class Channel : ChannelBase { public readonly Stopwatch Watch; @@ -49,11 +46,9 @@ protected override void Dispose(bool disposing) /// </summary> /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param> /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param> - /// <param name="compositionContainerFactory">The function to retrieve the composition container</param> - public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null) + public LocalEnvironment(int? seed = null, int conc = 0) : base(RandomUtils.Create(seed), verbose: false, conc) { - _compositionContainerFactory = compositionContainerFactory; } /// <summary> @@ -96,13 +91,6 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>()); } - public override CompositionContainer GetCompositionContainer() - { - if (_compositionContainerFactory != null) - return _compositionContainerFactory(); - return base.GetCompositionContainer(); - } - private sealed class Host : HostBase { public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose, int? conc) diff --git a/src/Microsoft.ML.Transforms/CustomMappingFactory.cs b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs new file mode 100644 index 0000000000..f1778f72f7 --- /dev/null +++ b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Data.DataView; + +namespace Microsoft.ML.Transforms +{ + /// <summary> + /// Place this attribute onto a type to cause it to be considered a custom mapping factory. + /// </summary> + [AttributeUsage(AttributeTargets.Class)] + public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute + { + public CustomMappingFactoryAttributeAttribute(string contractName) + : base(contractName) + { + } + } + + internal interface ICustomMappingFactory + { + ITransformer CreateTransformer(IHostEnvironment env, string contractName); + } + + /// <summary> + /// The base type for custom mapping factories. + /// </summary> + /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam> + /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam> + public abstract class CustomMappingFactory<TSrc, TDst> : ICustomMappingFactory + where TSrc : class, new() + where TDst : class, new() + { + /// <summary> + /// Returns the mapping delegate that maps from <typeparamref name="TSrc"/> inputs to <typeparamref name="TDst"/> outputs. + /// </summary> + public abstract Action<TSrc, TDst> GetMapping(); + + ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName) + { + Action<TSrc, TDst> mapAction = GetMapping(); + return new CustomMappingTransformer<TSrc, TDst>(env, mapAction, contractName); + } + } +} diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 2f847797f7..6950d9f46a 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -68,11 +68,13 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx) var contractName = ctx.LoadString(); - var composition = env.GetCompositionContainer(); - if (composition == null) - throw Contracts.Except("Unable to get the MEF composition container"); - ITransformer transformer = composition.GetExportedValue<ITransformer>(contractName); - return transformer; + object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName); + if (!(factoryObject is ICustomMappingFactory mappingFactory)) + { + throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}'."); + } + + return mappingFactory.CreateTransformer(env, contractName); } /// <summary> diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 4cf6548d6c..216f2c3179 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -38,7 +38,6 @@ #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll" #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll" #r "System" -#r "System.ComponentModel.Composition" #r "System.Core" #r "System.Xml.Linq" diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index ce9a9e0385..cf37e1bb70 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Generic; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using System.IO; using System.Linq; using Microsoft.Data.DataView; @@ -13,6 +11,7 @@ using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Normalizers; using Microsoft.ML.Transforms.Text; @@ -491,22 +490,22 @@ public void CustomTransformer() } /// <summary> - /// One class that contains all custom mappings that we need for our model. + /// One class that contains the custom mapping functionality that we need for our model. + /// + /// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and + /// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>. /// </summary> - public class CustomMappings + [CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))] + public class CustomMappings : CustomMappingFactory<InputRow, OutputRow> { // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading. public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000; - // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate - // this property. - [Import] - public MLContext MLContext { get; set; } - - // We are exporting the custom transformer by the name 'IncomeMapping'. - [Export(nameof(IncomeMapping))] - public ITransformer MyCustomTransformer - => MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping)); + // This factory method will be called when loading the model to get the mapping operation. + public override Action<InputRow, OutputRow> GetMapping() + { + return IncomeMapping; + } } private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string modelPath) @@ -530,8 +529,9 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Now pretend we are in a different process. var newContext = new MLContext(); - // Create a custom composition container for all our custom mapping actions. - newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings))); + // Register the assembly that contains 'CustomMappings' with the ComponentCatalog + // so it can be found when loading the model. + newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Now we can load the model. ITransformer loadedModel; diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs index 9a1c048ebf..3dcfcc991a 100644 --- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs @@ -3,8 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.ComponentModel.Composition; -using System.ComponentModel.Composition.Hosting; using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; @@ -32,18 +30,18 @@ public class MyOutput public string Together { get; set; } } - public class MyLambda + [CustomMappingFactoryAttribute("MyLambda")] + public class MyLambda : CustomMappingFactory<MyInput, MyOutput> { - [Export("MyLambda")] - public ITransformer MyTransformer => ML.Transforms.CustomMappingTransformer<MyInput, MyOutput>(MyAction, "MyLambda"); - - [Import] - public MLContext ML { get; set; } - public static void MyAction(MyInput input, MyOutput output) { output.Together = $"{input.Float1} + {string.Join(", ", input.Float4)}"; } + + public override Action<MyInput, MyOutput> GetMapping() + { + return MyAction; + } } [Fact] @@ -67,14 +65,14 @@ public void TestCustomTransformer() try { TestEstimatorCore(customEst, data); - Assert.True(false, "Cannot work without MEF injection"); + Assert.True(false, "Cannot work without RegisterAssembly"); } catch (InvalidOperationException ex) { if (!ex.IsMarked()) throw; } - ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda))); + ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); TestEstimatorCore(customEst, data); transformedData = customEst.Fit(data).Transform(data);