From a18b92d4565b09ded859ffaac1c6fcd925488e61 Mon Sep 17 00:00:00 2001 From: Paige Bailey Date: Sun, 8 Sep 2024 21:40:11 -0700 Subject: [PATCH] Updated to latest recommended model (Gemini 1.5 Flash). --- skllm/llm/vertex/mixin.py | 2 +- skllm/models/vertex/classification/tunable.py | 4 ++-- skllm/models/vertex/classification/zero_shot.py | 8 ++++---- skllm/models/vertex/text2text/tunable.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/skllm/llm/vertex/mixin.py b/skllm/llm/vertex/mixin.py index ed90940..78d7080 100644 --- a/skllm/llm/vertex/mixin.py +++ b/skllm/llm/vertex/mixin.py @@ -70,7 +70,7 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: class VertexTunableMixin(BaseTunableMixin): - _supported_tunable_models = ["text-bison@002"] + _supported_tunable_models = ["gemini-1.5-flash"] def _set_hyperparameters(self, base_model: str, n_update_steps: int, **kwargs): self.verify_model_is_supported(base_model) diff --git a/skllm/models/vertex/classification/tunable.py b/skllm/models/vertex/classification/tunable.py index c14d88f..451b1b7 100644 --- a/skllm/models/vertex/classification/tunable.py +++ b/skllm/models/vertex/classification/tunable.py @@ -19,7 +19,7 @@ class _TunableClassifier( class VertexClassifier(_TunableClassifier, _SingleLabelMixin): def __init__( self, - base_model: str = "text-bison@002", + base_model: str = "gemini-1.5-flash", n_update_steps: int = 1, default_label: str = "Random", ): @@ -29,7 +29,7 @@ def __init__( Parameters ---------- base_model : str, optional - base model to use, by default "text-bison@002" + base model to use, by default "gemini-1.5-flash" n_update_steps : int, optional number of epochs, by default 1 default_label : str, optional diff --git a/skllm/models/vertex/classification/zero_shot.py b/skllm/models/vertex/classification/zero_shot.py index ea84d27..49e5f74 100644 --- a/skllm/models/vertex/classification/zero_shot.py +++ b/skllm/models/vertex/classification/zero_shot.py @@ -12,7 +12,7 @@ class ZeroShotVertexClassifier( ): def __init__( self, - model: str = "text-bison@002", + model: str = "gemini-1.5-flash", default_label: str = "Random", prompt_template: Optional[str] = None, **kwargs, @@ -23,7 +23,7 @@ def __init__( Parameters ---------- model : str, optional - model to use, by default "text-bison@002" + model to use, by default "gemini-1.5-flash" default_label : str, optional default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random" prompt_template : Optional[str], optional @@ -42,7 +42,7 @@ class MultiLabelZeroShotVertexClassifier( ): def __init__( self, - model: str = "text-bison@002", + model: str = "gemini-1.5-flash", default_label: str = "Random", prompt_template: Optional[str] = None, max_labels: Optional[int] = 5, @@ -54,7 +54,7 @@ def __init__( Parameters ---------- model : str, optional - model to use, by default "text-bison@002" + model to use, by default "gemini-1.5-flash" default_label : str, optional default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random" prompt_template : Optional[str], optional diff --git a/skllm/models/vertex/text2text/tunable.py b/skllm/models/vertex/text2text/tunable.py index 9abcb55..c266c49 100644 --- a/skllm/models/vertex/text2text/tunable.py +++ b/skllm/models/vertex/text2text/tunable.py @@ -12,7 +12,7 @@ class TunableVertexText2Text( ): def __init__( self, - base_model: str = "text-bison@002", + base_model: str = "gemini-1.5-flash", n_update_steps: int = 1, ): """ @@ -21,7 +21,7 @@ def __init__( Parameters ---------- base_model : str, optional - base model to use, by default "text-bison@002" + base model to use, by default "gemini-1.5-flash" n_update_steps : int, optional number of epochs, by default 1 """