diff --git a/docs/examples/plot_warm_start.py b/docs/examples/plot_warm_start.py new file mode 100644 index 00000000..4b45a216 --- /dev/null +++ b/docs/examples/plot_warm_start.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +""" +=========================== +Fitting additional nodes +=========================== + +The local classifiers per node and per parent node support `warm_start=True`, +which allows to add more estimators to an already fitted model. + +Nodes that were trained previously are skipped, and only new nodes are fitted. + +.. tabs:: + + .. code-tab:: python + :caption: LocalClassifierPerNode + + rf = RandomForestClassifier() + classifier = LocalClassifierPerNode(local_classifier=rf, warm_start=True) + classifier.fit(X, y) + + + .. code-tab:: python + :caption: LocalClassifierPerParentNode + + rf = RandomForestClassifier() + classifier = LocalClassifierPerParentNode(local_classifier=rf, warm_start=True) + classifier.fit(X, y) + +In the code below, there is a working example with the local classifier per parent node. +However, the code can be easily updated by replacing lines 21-22 with the example shown in the tabs above. + +""" +from sklearn.linear_model import LogisticRegression + +from hiclass import LocalClassifierPerParentNode + +# Define data +X_1 = [[1], [2]] +X_2 = [[3], [4], [5]] +Y_1 = [ + ["Animal", "Mammal", "Sheep"], + ["Animal", "Mammal", "Cow"], +] +Y_2 = [ + ["Animal", "Reptile", "Snake"], + ["Animal", "Reptile", "Lizard"], + ["Animal", "Mammal", "Cow"], +] +X_test = [[5], [4], [3], [2], [1]] + +# Use logistic regression classifiers for every parent node +# And warm_start=True to allow training with more data in the future. +lr = LogisticRegression() +classifier = LocalClassifierPerParentNode(local_classifier=lr, warm_start=True) + +# Train local classifier per parent node +classifier.fit(X_1, Y_1) + +# Fit additional data +classifier.fit(X_2, Y_2) + +# Predict +predictions = classifier.predict(X_test) +print(predictions) diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 2bb5580b..524abf95 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -65,6 +65,7 @@ def __init__( edge_list: str = None, replace_classifiers: bool = True, n_jobs: int = 1, + warm_start: bool = False, classifier_abbreviation: str = "", ): """ @@ -87,6 +88,8 @@ def __init__( n_jobs : int, default=1 The number of jobs to run in parallel. Only :code:`fit` is parallelized. If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. + warm_start : bool, default=False + When set to `True`, reuse the solution of the previous call to fit and add more estimators for new nodes, otherwise, just fit a whole new DAG. See :ref:`Fitting additional nodes` for more information. classifier_abbreviation : str, default="" The abbreviation of the local hierarchical classifier to be displayed during logging. """ @@ -95,6 +98,7 @@ def __init__( self.edge_list = edge_list self.replace_classifiers = replace_classifiers self.n_jobs = n_jobs + self.warm_start = warm_start self.classifier_abbreviation = classifier_abbreviation def fit(self, X, y, sample_weight=None): diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 21160c28..03fba13a 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -42,6 +42,7 @@ def __init__( edge_list: str = None, replace_classifiers: bool = True, n_jobs: int = 1, + warm_start: bool = False, ): """ Initialize a local classifier per node. @@ -74,6 +75,8 @@ def __init__( n_jobs : int, default=1 The number of jobs to run in parallel. Only :code:`fit` is parallelized. If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. + warm_start : bool, default=False + When set to `True`, reuse the solution of the previous call to fit and add more estimators for new nodes, otherwise, just fit a whole new DAG. See :ref:`Fitting additional nodes` for more information. """ super().__init__( local_classifier=local_classifier, @@ -81,6 +84,7 @@ def __init__( edge_list=edge_list, replace_classifiers=replace_classifiers, n_jobs=n_jobs, + warm_start=warm_start, classifier_abbreviation="LCPN", ) self.binary_policy = binary_policy diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index fc9c4583..6e4ba085 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -40,6 +40,7 @@ def __init__( edge_list: str = None, replace_classifiers: bool = True, n_jobs: int = 1, + warm_start: bool = False, ): """ Initialize a local classifier per parent node. @@ -61,6 +62,8 @@ def __init__( n_jobs : int, default=1 The number of jobs to run in parallel. Only :code:`fit` is parallelized. If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`. + warm_start : bool, default=False + When set to `True`, reuse the solution of the previous call to fit and add more estimators for new nodes, otherwise, just fit a whole new DAG. See :ref:`Fitting additional nodes` for more information. """ super().__init__( local_classifier=local_classifier, @@ -68,6 +71,7 @@ def __init__( edge_list=edge_list, replace_classifiers=replace_classifiers, n_jobs=n_jobs, + warm_start=warm_start, classifier_abbreviation="LCPPN", )