diff --git a/doc/source/categorical.rst b/doc/source/categorical.rst
index 721e032b8bb92..ff37fbbb4aa24 100644
--- a/doc/source/categorical.rst
+++ b/doc/source/categorical.rst
@@ -1145,7 +1145,8 @@ dtype in apply
 
 Pandas currently does not preserve the dtype in apply functions: If you apply along rows you get
 a `Series` of ``object`` `dtype` (same as getting a row -> getting one element will return a
-basic type) and applying along columns will also convert to object.
+basic type) and applying along columns will also convert to object. ``NaN`` values are unaffected.
+You can use ``fillna`` to handle missing values before applying a function.
 
 .. ipython:: python
 
diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst
index 6095865fde87c..933d6a486ad07 100644
--- a/doc/source/whatsnew/v0.24.0.rst
+++ b/doc/source/whatsnew/v0.24.0.rst
@@ -1273,6 +1273,7 @@ Categorical
 - Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
 - Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
 - Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)
+- Bug in :meth:`Categorical.apply` where ``NaN`` values could be handled unpredictably. They now remain unchanged (:issue:`24241`)
 
 Datetimelike
 ^^^^^^^^^^^^
diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index 6ccb8dc5d2725..9a8b345cea1b3 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -1166,7 +1166,7 @@ def map(self, mapper):
         Maps the categories to new categories. If the mapping correspondence is
         one-to-one the result is a :class:`~pandas.Categorical` which has the
         same order property as the original, otherwise a :class:`~pandas.Index`
-        is returned.
+        is returned. NaN values are unaffected.
 
         If a `dict` or :class:`~pandas.Series` is used any unmapped category is
         mapped to `NaN`. Note that if this happens an :class:`~pandas.Index`
@@ -1234,6 +1234,11 @@ def map(self, mapper):
                                    categories=new_categories,
                                    ordered=self.ordered)
         except ValueError:
+            # NA values are represented in self._codes with -1
+            # np.take causes NA values to take final element in new_categories
+            if np.any(self._codes == -1):
+                new_categories = new_categories.insert(len(new_categories),
+                                                       np.nan)
             return np.take(new_categories, self._codes)
 
     __eq__ = _cat_compare_op('__eq__')
diff --git a/pandas/tests/indexes/test_category.py b/pandas/tests/indexes/test_category.py
index bb537f30821e4..d9dfeadd10b84 100644
--- a/pandas/tests/indexes/test_category.py
+++ b/pandas/tests/indexes/test_category.py
@@ -311,6 +311,29 @@ def test_map_with_categorical_series(self):
         exp = pd.Index(["odd", "even", "odd", np.nan])
         tm.assert_index_equal(a.map(c), exp)
 
+    @pytest.mark.parametrize(
+        (
+            'data',
+            'f'
+        ),
+        (
+            ([1, 1, np.nan], pd.isna),
+            ([1, 2, np.nan], pd.isna),
+            ([1, 1, np.nan], {1: False}),
+            ([1, 2, np.nan], {1: False, 2: False}),
+            ([1, 1, np.nan], pd.Series([False, False])),
+            ([1, 2, np.nan], pd.Series([False, False, False]))
+        ))
+    def test_map_with_nan(self, data, f):  # GH 24241
+        values = pd.Categorical(data)
+        result = values.map(f)
+        if data[1] == 1:
+            expected = pd.Categorical([False, False, np.nan])
+            tm.assert_categorical_equal(result, expected)
+        else:
+            expected = pd.Index([False, False, np.nan])
+            tm.assert_index_equal(result, expected)
+
     @pytest.mark.parametrize('klass', [list, tuple, np.array, pd.Series])
     def test_where(self, klass):
         i = self.create_index()