diff --git a/build/Dependencies.props b/build/Dependencies.props
index 197fa167b2..d2a48f6865 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -8,7 +8,6 @@
     <SystemMemoryVersion>4.5.1</SystemMemoryVersion>
     <SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
     <SystemThreadingTasksDataflowPackageVersion>4.8.0</SystemThreadingTasksDataflowPackageVersion>
-    <SystemComponentModelCompositionVersion>4.5.0</SystemComponentModelCompositionVersion>
   </PropertyGroup>
 
   <!-- Other/Non-Core Product Dependencies -->
diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index da770fd20b..f205684d3a 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -970,27 +970,27 @@ Please note that you need to make your `mapping` operation into a 'pure function
 - It should not have side effects (we may call it arbitrarily at any time, or omit the call)
 
 One important caveat is: if you want your custom transformation to be part of your saved model, you will need to provide a `contractName` for it.
-At loading time, you will need to reconstruct the custom transformer and inject it into MLContext. 
+At loading time, you will need to register the custom transformer with the MLContext. 
 
 Here is a complete example that saves and loads a model with a custom mapping.
 ```csharp
 /// <summary>
-/// One class that contains all custom mappings that we need for our model.
+/// One class that contains the custom mapping functionality that we need for our model.
+/// 
+/// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and
+/// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>.
 /// </summary>
-public class CustomMappings
+[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
+public class CustomMappings : CustomMappingFactory<InputRow, OutputRow>
 {
     // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
     public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
 
-    // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
-    // this property.
-    [Import]
-    public MLContext MLContext { get; set; }
-
-    // We are exporting the custom transformer by the name 'IncomeMapping'.
-    [Export(nameof(IncomeMapping))]
-    public ITransformer MyCustomTransformer 
-        => MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping));
+    // This factory method will be called when loading the model to get the mapping operation.
+    public override Action<InputRow, OutputRow> GetMapping()
+    {
+        return IncomeMapping;
+    }
 }
 ```
 
@@ -1013,8 +1013,9 @@ using (var fs = File.Create(modelPath))
 
 // Now pretend we are in a different process.
 
-// Create a custom composition container for all our custom mapping actions.
-newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
+// Register the assembly that contains 'CustomMappings' with the ComponentCatalog
+// so it can be found when loading the model.
+newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
 
 // Now we can load the model.
 ITransformer loadedModel;
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index 9e92abc840..1443a9f6b0 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -15,7 +15,6 @@
     <PackageReference Include="System.CodeDom" Version="$(SystemCodeDomPackageVersion)" />
     <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
     <PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
-    <PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
   </ItemGroup>
 
   <ItemGroup>
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 1dd10e9f39..0c28b6c099 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -35,6 +35,8 @@ internal ComponentCatalog()
             _entryPointMap = new Dictionary<string, EntryPointInfo>();
             _componentMap = new Dictionary<string, ComponentInfo>();
             _components = new List<ComponentInfo>();
+
+            _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>();
         }
 
         /// <summary>
@@ -404,6 +406,8 @@ internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcMo
         private readonly List<ComponentInfo> _components;
         private readonly Dictionary<string, ComponentInfo> _componentMap;
 
+        private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap;
+
         private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
             out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
         {
@@ -618,6 +622,8 @@ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
 
                         AddClass(info, attr.LoadNames, throwOnError);
                     }
+
+                    LoadExtensions(assembly, throwOnError);
                 }
             }
         }
@@ -980,5 +986,75 @@ private static void ParseArguments(IHostEnvironment env, object args, string set
             if (errorMsg != null)
                 throw Contracts.Except(errorMsg);
         }
+
+        private void LoadExtensions(Assembly assembly, bool throwOnError)
+        {
+            // don't waste time looking through all the types of an assembly
+            // that can't contain extensions
+            if (CanContainExtensions(assembly))
+            {
+                foreach (Type type in assembly.GetTypes())
+                {
+                    if (type.IsClass)
+                    {
+                        foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute)))
+                        {
+                            var key = (AttributeType: attribute.GetType(), attribute.ContractName);
+                            if (_extensionsMap.TryGetValue(key, out var existingType))
+                            {
+                                if (throwOnError)
+                                {
+                                    throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog.");
+                                }
+                            }
+                            else
+                            {
+                                _extensionsMap.Add(key, type);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Gets a value indicating whether <paramref name="assembly"/> can contain extensions.
+        /// </summary>
+        /// <remarks>
+        /// All ML.NET product assemblies won't contain extensions.
+        /// </remarks>
+        private static bool CanContainExtensions(Assembly assembly)
+        {
+            if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal)
+                && HasMLNetPublicKey(assembly))
+            {
+                return false;
+            }
+
+            return true;
+        }
+
+        private static bool HasMLNetPublicKey(Assembly assembly)
+        {
+            return assembly.GetName().GetPublicKey().SequenceEqual(
+                typeof(ComponentCatalog).Assembly.GetName().GetPublicKey());
+        }
+
+        [BestFriend]
+        internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName)
+        {
+            object exportedValue = null;
+            if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType))
+            {
+                exportedValue = Activator.CreateInstance(extensionType);
+            }
+
+            if (exportedValue == null)
+            {
+                throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'.");
+            }
+
+            return exportedValue;
+        }
     }
 }
diff --git a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs
new file mode 100644
index 0000000000..7d8d00a252
--- /dev/null
+++ b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML
+{
+    /// <summary>
+    /// The base attribute type for all attributes used for extensibility purposes.
+    /// </summary>
+    [AttributeUsage(AttributeTargets.Class)]
+    public abstract class ExtensionBaseAttribute : Attribute
+    {
+        public string ContractName { get; }
+
+        [BestFriend]
+        private protected ExtensionBaseAttribute(string contractName)
+        {
+            ContractName = contractName;
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 0f095d02c5..78cd697811 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -3,7 +3,6 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.ComponentModel.Composition.Hosting;
 
 namespace Microsoft.ML
 {
@@ -92,12 +91,6 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
         [Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
             "Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
         IFileHandle CreateTempFile(string suffix = null, string prefix = null);
-
-        /// <summary>
-        /// Get the MEF composition container. This can be used to instantiate user-provided 'parts' when the model
-        /// is being loaded, or the components are otherwise created via dependency injection.
-        /// </summary>
-        CompositionContainer GetCompositionContainer();
     }
 
     /// <summary>
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 89ce4503c4..1127e9b115 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -5,7 +5,6 @@
 using System;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
-using System.ComponentModel.Composition.Hosting;
 using System.IO;
 
 namespace Microsoft.ML.Data
@@ -632,7 +631,5 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
             else if (!removeLastNewLine)
                 writer.WriteLine();
         }
-
-        public virtual CompositionContainer GetCompositionContainer() => new CompositionContainer();
     }
 }
diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
index 0d6b288499..ccd18a42b7 100644
--- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
+++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
@@ -12,7 +12,6 @@
     <ProjectReference Include="..\Microsoft.Data.DataView\Microsoft.Data.DataView.csproj" />
     
     <PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
-    <PackageReference Include="System.ComponentModel.Composition" Version="$(SystemComponentModelCompositionVersion)" />
     <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
   </ItemGroup>
 
diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs
index a466e37aa3..39281b36d6 100644
--- a/src/Microsoft.ML.Data/MLContext.cs
+++ b/src/Microsoft.ML.Data/MLContext.cs
@@ -3,8 +3,6 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
 using Microsoft.ML.Data;
 
 namespace Microsoft.ML
@@ -69,9 +67,9 @@ public sealed class MLContext : IHostEnvironment
         public event EventHandler<LoggingEventArgs> Log;
 
         /// <summary>
-        /// This is a MEF composition container catalog to be used for model loading.
+        /// This is a catalog of components that will be used for model loading.
         /// </summary>
-        public CompositionContainer CompositionContainer { get; set; }
+        public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;
 
         /// <summary>
         /// Create the ML context.
@@ -80,7 +78,7 @@ public sealed class MLContext : IHostEnvironment
         /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
         public MLContext(int? seed = null, int conc = 0)
         {
-            _env = new LocalEnvironment(seed, conc, MakeCompositionContainer);
+            _env = new LocalEnvironment(seed, conc);
             _env.AddListener(ProcessMessage);
 
             BinaryClassification = new BinaryClassificationCatalog(_env);
@@ -94,18 +92,6 @@ public MLContext(int? seed = null, int conc = 0)
             Data = new DataOperationsCatalog(_env);
         }
 
-        private CompositionContainer MakeCompositionContainer()
-        {
-            if (CompositionContainer == null)
-                return null;
-
-            var mlContext = CompositionContainer.GetExportedValueOrDefault<MLContext>();
-            if (mlContext == null)
-                CompositionContainer.ComposeExportedValue<MLContext>(this);
-
-            return CompositionContainer;
-        }
-
         private void ProcessMessage(IMessageSource source, ChannelMessage message)
         {
             var log = Log;
@@ -120,7 +106,6 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
 
         int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
         bool IHostEnvironment.IsCancelled => _env.IsCancelled;
-        ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
         string IExceptionContext.ContextDescription => _env.ContextDescription;
         IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
         TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
@@ -128,6 +113,5 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
         IChannel IChannelProvider.Start(string name) => _env.Start(name);
         IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
         IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
-        CompositionContainer IHostEnvironment.GetCompositionContainer() => _env.GetCompositionContainer();
     }
 }
diff --git a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
index a9e3a4a80e..b17b2cf39f 100644
--- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
+++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
@@ -3,7 +3,6 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.ComponentModel.Composition.Hosting;
 
 namespace Microsoft.ML.Data
 {
@@ -14,8 +13,6 @@ namespace Microsoft.ML.Data
     /// </summary>
     internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
     {
-        private readonly Func<CompositionContainer> _compositionContainerFactory;
-
         private sealed class Channel : ChannelBase
         {
             public readonly Stopwatch Watch;
@@ -49,11 +46,9 @@ protected override void Dispose(bool disposing)
         /// </summary>
         /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
         /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
-        /// <param name="compositionContainerFactory">The function to retrieve the composition container</param>
-        public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null)
+        public LocalEnvironment(int? seed = null, int conc = 0)
             : base(RandomUtils.Create(seed), verbose: false, conc)
         {
-            _compositionContainerFactory = compositionContainerFactory;
         }
 
         /// <summary>
@@ -96,13 +91,6 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
             return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
         }
 
-        public override CompositionContainer GetCompositionContainer()
-        {
-            if (_compositionContainerFactory != null)
-                return _compositionContainerFactory();
-            return base.GetCompositionContainer();
-        }
-
         private sealed class Host : HostBase
         {
             public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose, int? conc)
diff --git a/src/Microsoft.ML.Transforms/CustomMappingFactory.cs b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs
new file mode 100644
index 0000000000..f1778f72f7
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/CustomMappingFactory.cs
@@ -0,0 +1,47 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Data.DataView;
+
+namespace Microsoft.ML.Transforms
+{
+    /// <summary>
+    /// Place this attribute onto a type to cause it to be considered a custom mapping factory.
+    /// </summary>
+    [AttributeUsage(AttributeTargets.Class)]
+    public sealed class CustomMappingFactoryAttributeAttribute : ExtensionBaseAttribute
+    {
+        public CustomMappingFactoryAttributeAttribute(string contractName)
+            : base(contractName)
+        {
+        }
+    }
+
+    internal interface ICustomMappingFactory
+    {
+        ITransformer CreateTransformer(IHostEnvironment env, string contractName);
+    }
+
+    /// <summary>
+    /// The base type for custom mapping factories.
+    /// </summary>
+    /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam>
+    /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
+    public abstract class CustomMappingFactory<TSrc, TDst> : ICustomMappingFactory
+        where TSrc : class, new()
+        where TDst : class, new()
+    {
+        /// <summary>
+        /// Returns the mapping delegate that maps from <typeparamref name="TSrc"/> inputs to <typeparamref name="TDst"/> outputs.
+        /// </summary>
+        public abstract Action<TSrc, TDst> GetMapping();
+
+        ITransformer ICustomMappingFactory.CreateTransformer(IHostEnvironment env, string contractName)
+        {
+            Action<TSrc, TDst> mapAction = GetMapping();
+            return new CustomMappingTransformer<TSrc, TDst>(env, mapAction, contractName);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs
index 2f847797f7..6950d9f46a 100644
--- a/src/Microsoft.ML.Transforms/LambdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs
@@ -68,11 +68,13 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
 
             var contractName = ctx.LoadString();
 
-            var composition = env.GetCompositionContainer();
-            if (composition == null)
-                throw Contracts.Except("Unable to get the MEF composition container");
-            ITransformer transformer = composition.GetExportedValue<ITransformer>(contractName);
-            return transformer;
+            object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);
+            if (!(factoryObject is ICustomMappingFactory mappingFactory))
+            {
+                throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}'.");
+            }
+
+            return mappingFactory.CreateTransformer(env, contractName);
         }
 
         /// <summary>
diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
index 4cf6548d6c..216f2c3179 100644
--- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
+++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs
@@ -38,7 +38,6 @@
 #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll" 
 #r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll" 
 #r "System" 
-#r "System.ComponentModel.Composition" 
 #r "System.Core" 
 #r "System.Xml.Linq" 
 
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
index ce9a9e0385..cf37e1bb70 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
@@ -4,8 +4,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
 using System.IO;
 using System.Linq;
 using Microsoft.Data.DataView;
@@ -13,6 +11,7 @@
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
 using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
 using Microsoft.ML.Transforms.Categorical;
 using Microsoft.ML.Transforms.Normalizers;
 using Microsoft.ML.Transforms.Text;
@@ -491,22 +490,22 @@ public void CustomTransformer()
         }
 
         /// <summary>
-        /// One class that contains all custom mappings that we need for our model.
+        /// One class that contains the custom mapping functionality that we need for our model.
+        /// 
+        /// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and
+        /// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>.
         /// </summary>
-        public class CustomMappings
+        [CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
+        public class CustomMappings : CustomMappingFactory<InputRow, OutputRow>
         {
             // This is the custom mapping. We now separate it into a method, so that we can use it both in training and in loading.
             public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
 
-            // MLContext is needed to create a new transformer. We are using 'Import' to have ML.NET populate
-            // this property.
-            [Import]
-            public MLContext MLContext { get; set; }
-
-            // We are exporting the custom transformer by the name 'IncomeMapping'.
-            [Export(nameof(IncomeMapping))]
-            public ITransformer MyCustomTransformer 
-                => MLContext.Transforms.CustomMappingTransformer<InputRow, OutputRow>(IncomeMapping, nameof(IncomeMapping));
+            // This factory method will be called when loading the model to get the mapping operation.
+            public override Action<InputRow, OutputRow> GetMapping()
+            {
+                return IncomeMapping;
+            }
         }
 
         private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string modelPath)
@@ -530,8 +529,9 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string
             // Now pretend we are in a different process.
             var newContext = new MLContext();
 
-            // Create a custom composition container for all our custom mapping actions.
-            newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(CustomMappings)));
+            // Register the assembly that contains 'CustomMappings' with the ComponentCatalog
+            // so it can be found when loading the model.
+            newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
 
             // Now we can load the model.
             ITransformer loadedModel;
diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
index 9a1c048ebf..3dcfcc991a 100644
--- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
@@ -3,8 +3,6 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.ComponentModel.Composition;
-using System.ComponentModel.Composition.Hosting;
 using System.Linq;
 using Microsoft.Data.DataView;
 using Microsoft.ML.Data;
@@ -32,18 +30,18 @@ public class MyOutput
             public string Together { get; set; }
         }
 
-        public class MyLambda
+        [CustomMappingFactoryAttribute("MyLambda")]
+        public class MyLambda : CustomMappingFactory<MyInput, MyOutput>
         {
-            [Export("MyLambda")]
-            public ITransformer MyTransformer => ML.Transforms.CustomMappingTransformer<MyInput, MyOutput>(MyAction, "MyLambda");
-
-            [Import]
-            public MLContext ML { get; set; }
-
             public static void MyAction(MyInput input, MyOutput output)
             {
                 output.Together = $"{input.Float1} + {string.Join(", ", input.Float4)}";
             }
+
+            public override Action<MyInput, MyOutput> GetMapping()
+            {
+                return MyAction;
+            }
         }
 
         [Fact]
@@ -67,14 +65,14 @@ public void TestCustomTransformer()
             try
             {
                 TestEstimatorCore(customEst, data);
-                Assert.True(false, "Cannot work without MEF injection");
+                Assert.True(false, "Cannot work without RegisterAssembly");
             }
             catch (InvalidOperationException ex)
             {
                 if (!ex.IsMarked())
                     throw;
             }
-            ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda)));
+            ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly);
             TestEstimatorCore(customEst, data);
             transformedData = customEst.Fit(data).Transform(data);