diff --git a/requirements-dev.txt b/requirements-dev.txt
index c2de4e7..0a0a1cf 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,3 +6,4 @@ docformatter
 interrogate
 numpy
 pandas
+pydantic
diff --git a/skllm/llm/base.py b/skllm/llm/base.py
index 18b7edf..3a29cf5 100644
--- a/skllm/llm/base.py
+++ b/skllm/llm/base.py
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, TypeVar, Type
+from pydantic import BaseModel
+T = TypeVar('T', bound=BaseModel)
 
 
 class BaseTextCompletionMixin(ABC):
     @abstractmethod
@@ -13,6 +15,11 @@ def _convert_completion_to_str(self, completion: Any):
         """Converts a completion object to a string"""
         pass
 
+    @abstractmethod
+    def _get_parsed_completion(self, output_model: Type[T], **kwargs) -> T:
+        """Gets a chat completion parsed into the specified Pydantic model"""
+        pass
+
 
 class BaseClassifierMixin(BaseTextCompletionMixin):
     @abstractmethod
diff --git a/skllm/llm/gpt/clients/openai/completion.py b/skllm/llm/gpt/clients/openai/completion.py
index 779c0f2..d9c80c9 100644
--- a/skllm/llm/gpt/clients/openai/completion.py
+++ b/skllm/llm/gpt/clients/openai/completion.py
@@ -1,4 +1,5 @@
-import openai
+from typing import TypeVar, Type
+from pydantic import BaseModel
 from openai import OpenAI
 from skllm.llm.gpt.clients.openai.credentials import (
     set_azure_credentials,
@@ -6,6 +7,8 @@
 )
 from skllm.utils import retry
 
+T = TypeVar('T', bound=BaseModel)
+
 
 @retry(max_retries=3)
 def get_chat_completion(
@@ -50,3 +53,51 @@
         temperature=0.0, messages=messages, **model_dict
     )
     return completion
+
+
+@retry(max_retries=3)
+def get_parsed_completion(
+    messages: dict,
+    output_model: Type[T],
+    key: str,
+    org: str,
+    model: str = "gpt-3.5-turbo",
+    api="openai",
+) -> T:
+    """Gets a chat completion parsed into the specified Pydantic model.
+
+    Parameters
+    ----------
+    messages : dict
+        input messages to use.
+    output_model : Type[T]
+        Pydantic model class to parse the response into.
+    key : str
+        The OPEN AI key to use.
+    org : str
+        The OPEN AI organization ID to use.
+    model : str, optional
+        The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
+    api : str
+        The API to use. Must be one of "openai" or "azure". Defaults to "openai".
+
+    Returns
+    -------
+    parsed_model : T
+        Instance of the specified Pydantic model
+    """
+    if api in ("openai", "custom_url"):
+        client = set_credentials(key, org)
+    elif api == "azure":
+        client = set_azure_credentials(key, org)
+    else:
+        raise ValueError("Invalid API")
+
+    completion = client.beta.chat.completions.parse(
+        model=model,
+        messages=messages,
+        response_format=output_model,
+        temperature=0.0
+    )
+
+    return completion.choices[0].message.parsed
diff --git a/tests/test_structured_outputs.py b/tests/test_structured_outputs.py
new file mode 100644
index 0000000..f78d3dc
--- /dev/null
+++ b/tests/test_structured_outputs.py
@@ -0,0 +1,87 @@
+
+import unittest
+from pydantic import BaseModel
+from skllm.llm.gpt.clients.openai.completion import get_parsed_completion
+import unittest
+from unittest.mock import patch
+from types import SimpleNamespace
+import skllm.llm.gpt.clients.openai.completion as completion_mod
+
+class DummyCompletions:
+    def __init__(self, model_cls):
+        self._model_cls = model_cls
+
+    def parse(self, *, model, messages, response_format, temperature):
+        # response_format is the Pydantic model class (TestEvent)
+        fake = self._model_cls(
+            event_name="science fair",
+            date="Friday",
+            attendees=["Alice", "Bob"],
+        )
+        return SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(parsed=fake))]
+        )
+
+    def create(self, *, temperature, messages, **kwargs):
+        # if you ever test get_chat_completion
+        return {"id": "dummy", "choices": []}
+
+class DummyClient:
+    def __init__(self, model_cls):
+        self.chat = SimpleNamespace(completions=DummyCompletions(model_cls))
+        self.beta = SimpleNamespace(chat=SimpleNamespace(completions=DummyCompletions(model_cls)))
+
+class OpenAITestCase(unittest.TestCase):
+    def setUp(self):
+        self.patcher1 = patch.object(
+            completion_mod,
+            "set_credentials",
+            lambda key, org: DummyClient(TestEvent)
+        )
+        self.patcher2 = patch.object(
+            completion_mod,
+            "set_azure_credentials",
+            lambda key, org: DummyClient(TestEvent)
+        )
+        self.patcher1.start()
+        self.patcher2.start()
+
+    def tearDown(self):
+        self.patcher1.stop()
+        self.patcher2.stop()
+
+
+class TestEvent(BaseModel):
+    event_name: str
+    date: str
+    attendees: list[str]
+
+class TestOpenAIStructuredOutput(OpenAITestCase):
+    def test_openai_structured_output(self):
+        """Test that structured outputs are properly parsed into Pydantic models."""
+        messages = [
+            {"role": "system", "content": "Extract event information in JSON format"},
+            {"role": "user", "content": "Alice and Bob are attending the science fair on Friday"}
+        ]
+
+        # Test successful parsing
+        result = get_parsed_completion(
+            messages=messages,
+            output_model=TestEvent,
+            key="dummy_value",  # dummy; set_credentials is patched in OpenAITestCase.setUp
+            org="dummy_value",  # dummy; the patched client ignores the organization
+            model="gpt-4o-mini"
+        )
+
+        # Validate the result structure
+        self.assertIsInstance(result, TestEvent)
+        self.assertIsInstance(result.event_name, str)
+        self.assertGreater(len(result.event_name), 0)
+        self.assertIsInstance(result.date, str)
+        self.assertGreater(len(result.date), 0)
+        self.assertIsInstance(result.attendees, list)
+        self.assertGreaterEqual(len(result.attendees), 2)  # Should have at least Alice and Bob
+        self.assertTrue(all(isinstance(name, str) for name in result.attendees))
+
+if __name__ == '__main__':
+    unittest.main()