Skip to content

GoodnewsRanker and more unit tests #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ wheels/

# .env files
.env

# data
data
78 changes: 78 additions & 0 deletions src/mcp_goodnews/goodnews_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import json
import os
from typing import Any

from cohere import AsyncClientV2
from cohere.types import ChatMessages, ChatResponse

from mcp_goodnews.newsapi import Article

# Prompt templates. Both are rendered with ``str.format``.

# System message; expects the ``num_articles_to_return`` placeholder.
DEFAULT_GOODNEWS_SYSTEM_PROMPT = (
    "Given the list of articles, return {num_articles_to_return} "
    "of the most positive articles."
)

# User message; expects the ``formatted_articles`` placeholder.
DEFAULT_RANK_INSTRUCTION_TEMPLATE = "".join(
    [
        "Please rank the articles provided below according to their positivity ",
        "based on their `title` as well as the `content` fields of an article.",
        "\n\n<articles>\n\n",
        "{formatted_articles}</articles>",
    ]
)

# Ranker defaults.
DEFAULT_MODEL_NAME = "command-r-plus-08-2024"
DEFAULT_NUM_ARTICLES_TO_RETURN = 3


class GoodnewsRanker:
    """Rank news articles by positivity using a Cohere chat model.

    The ranker renders a system prompt and a user instruction (which embeds
    the JSON-serialized articles), sends both to the Cohere chat endpoint and
    returns the text of the first content item of the reply.

    Args:
        model_name: Name of the Cohere chat model to query.
        num_articles_to_return: Number of articles the system prompt asks the
            model to return.
        system_prompt_template: ``str.format`` template for the system
            message; rendered with ``num_articles_to_return``.
        rank_instruction_template: ``str.format`` template for the user
            message; rendered with ``formatted_articles``.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        # Use the module constant rather than a duplicated literal so the
        # default has a single source of truth.
        num_articles_to_return: int = DEFAULT_NUM_ARTICLES_TO_RETURN,
        system_prompt_template: str = DEFAULT_GOODNEWS_SYSTEM_PROMPT,
        rank_instruction_template: str = DEFAULT_RANK_INSTRUCTION_TEMPLATE,
    ):
        self.model_name = model_name
        self.num_articles_to_return = num_articles_to_return
        self.system_prompt_template = system_prompt_template
        self.rank_instruction_template = rank_instruction_template

    def _get_client(self) -> AsyncClientV2:
        """Get cohere async client.

        NOTE: this requires `COHERE_API_KEY` env variable to be set.
        """
        return AsyncClientV2(
            api_key=os.environ.get("COHERE_API_KEY"),
        )

    def _format_articles(self, articles: list[Article]) -> str:
        """Serialize articles to indented JSON, separated by blank lines.

        Each article is dumped with its field aliases (``by_alias=True``).
        """
        return "\n\n".join(
            json.dumps(a.model_dump(by_alias=True), indent=4) for a in articles
        )

    def _prepare_chat_messages(
        self, articles: list[Article]
    ) -> list[ChatMessages]:
        """Build the system and user messages for the chat request.

        Returns:
            A two-element list: the rendered system prompt followed by the
            rendered ranking instruction containing the formatted articles.
        """
        messages = [
            {
                "role": "system",
                "content": self.system_prompt_template.format(
                    num_articles_to_return=self.num_articles_to_return
                ),
            },
            {
                "role": "user",
                "content": self.rank_instruction_template.format(
                    formatted_articles=self._format_articles(articles)
                ),
            },
        ]
        return messages

    def _postprocess_chat_response(self, response: ChatResponse) -> str | Any:
        """Extract the text of the first content item from a chat response.

        Raises:
            ValueError: If the response message carries no content items.
        """
        content = response.message.content
        if not content:
            # Fail with a descriptive error instead of an opaque
            # IndexError/TypeError from indexing empty or missing content.
            raise ValueError(
                "Cohere chat response did not contain any message content."
            )
        return content[0].text

    async def rank_articles(self, articles: list[Article]) -> str:
        """Uses cohere llms to rank a set of articles.

        Args:
            articles: Articles to embed in the ranking instruction.

        Returns:
            The post-processed text of the model's reply.
        """
        co = self._get_client()
        response: ChatResponse = await co.chat(
            model=self.model_name,
            messages=self._prepare_chat_messages(articles),
        )
        return self._postprocess_chat_response(response)
20 changes: 0 additions & 20 deletions src/mcp_goodnews/llm.py

This file was deleted.

9 changes: 3 additions & 6 deletions src/mcp_goodnews/newsapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@


class ArticleSource(BaseModel):
    """Source of an article, identified by ``id`` and ``name``."""

    model_config = ConfigDict(populate_by_name=True)

    # Stored as `id_` on the model; carries the alias `id`.
    id_: str = Field(alias="id")
    name: str


class Article(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
)
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

source: ArticleSource
author: str
Expand All @@ -28,9 +27,7 @@ class Article(BaseModel):


class NewsAPIResponse(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
)
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

status: str
total_results: int
Expand Down
Empty file.
29 changes: 29 additions & 0 deletions tests/test_goodnews_ranker/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from mcp_goodnews.newsapi import Article, ArticleSource


@pytest.fixture()
def example_articles() -> list[Article]:
    """Two placeholder articles whose field values differ only by index."""

    def _fake_article(index: int) -> Article:
        # Every field follows the "fake <field> <index>" pattern.
        return Article(
            source=ArticleSource(id=str(index), name=f"source {index}"),
            author=f"fake author {index}",
            title=f"fake title {index}",
            description=f"fake description {index}",
            url=f"fake url {index}",
            url_to_image=f"fake url to image {index}",
            published_at=f"fake published at {index}",
            content=f"fake content {index}",
        )

    return [_fake_article(index) for index in (1, 2)]
109 changes: 109 additions & 0 deletions tests/test_goodnews_ranker/test_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from cohere.types import (
AssistantMessageResponse,
ChatResponse,
TextAssistantMessageResponseContentItem,
)

from mcp_goodnews.goodnews_ranker import GoodnewsRanker
from mcp_goodnews.newsapi import Article


def test_goodnews_ranker_init() -> None:
    """Constructor arguments are stored on the ranker unchanged."""
    ranker = GoodnewsRanker(
        model_name="fake model name",
        num_articles_to_return=5,
        system_prompt_template="fake template {num_articles_to_return}",
        rank_instruction_template="fake rank template {formatted_articles}",
    )

    # plain attributes
    assert ranker.model_name == "fake model name"
    assert ranker.num_articles_to_return == 5

    # templates are kept as-is and remain formattable
    rendered_system_prompt = ranker.system_prompt_template.format(
        num_articles_to_return="alice"
    )
    rendered_rank_instruction = ranker.rank_instruction_template.format(
        formatted_articles="fake article"
    )
    assert rendered_system_prompt == "fake template alice"
    assert rendered_rank_instruction == "fake rank template fake article"


@patch("mcp_goodnews.goodnews_ranker.AsyncClientV2")
def test_goodnews_get_client(
    mock_async_client_v2: MagicMock,
) -> None:
    """`_get_client` builds the client with the key read from the env."""
    fake_env = {"COHERE_API_KEY": "fake-key"}
    ranker = GoodnewsRanker()

    # act
    with patch.dict("os.environ", fake_env):
        ranker._get_client()

    # assert
    mock_async_client_v2.assert_called_once_with(api_key="fake-key")


def test_goodnews_format_articles(example_articles: list[Article]) -> None:
    """Articles are dumped as indented, alias-keyed JSON joined by blank lines."""
    expected_chunks = [
        json.dumps(article.model_dump(by_alias=True), indent=4)
        for article in example_articles
    ]
    expected = "\n\n".join(expected_chunks)

    # act
    formatted = GoodnewsRanker()._format_articles(example_articles)

    assert formatted == expected


def test_goodnews_prepare_chat_messages(
    example_articles: list[Article],
) -> None:
    """A system message is followed by a user message holding the articles."""
    ranker = GoodnewsRanker()
    formatted_articles = ranker._format_articles(example_articles)

    # act
    messages = ranker._prepare_chat_messages(example_articles)

    # assert
    assert len(messages) == 2  # system prompt and user prompt
    system_message, user_message = messages
    assert system_message["role"] == "system"
    assert user_message["role"] == "user"
    assert formatted_articles in user_message["content"]


@pytest.mark.asyncio
@patch.object(GoodnewsRanker, "_postprocess_chat_response")
@patch.object(GoodnewsRanker, "_get_client")
async def test_rank_articles(
    mock_get_client: MagicMock,
    mock_postprocess_chat_response: MagicMock,
    example_articles: list[Article],
) -> None:
    """`rank_articles` calls the chat API and post-processes its response."""
    # arrange: a canned chat response returned by a mocked async client
    reply_item = TextAssistantMessageResponseContentItem(text="mock response")
    fake_chat_response = ChatResponse(
        id="1",
        finish_reason="COMPLETE",
        prompt=None,
        message=AssistantMessageResponse(content=[reply_item]),
    )
    fake_client = AsyncMock()
    fake_client.chat.return_value = fake_chat_response
    mock_get_client.return_value = fake_client

    ranker = GoodnewsRanker()
    expected_messages = ranker._prepare_chat_messages(example_articles)

    # act
    await ranker.rank_articles(example_articles)

    # assert
    fake_client.chat.assert_called_once_with(
        model=ranker.model_name,
        messages=expected_messages,
    )
    mock_postprocess_chat_response.assert_called_once_with(fake_chat_response)
7 changes: 0 additions & 7 deletions tests/test_hello.py

This file was deleted.

21 changes: 21 additions & 0 deletions tests/test_newsapi_models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,24 @@ def response_json() -> dict[str, Any]:
response_json = json.load(f)
response_json = cast(dict[str, Any], response_json)
return response_json # type: ignore[no-any-return]


@pytest.fixture
def example_source_dict() -> dict[str, Any]:
    """Raw article-source payload keyed by `id` and `name`."""
    source_payload: dict[str, Any] = {
        "id": "fake_source_id",
        "name": "fake_name",
    }
    return source_payload


@pytest.fixture
def example_article_dict(
    example_source_dict: dict[str, Any]
) -> dict[str, Any]:
    """Raw article payload with camelCase keys, embedding the source payload."""
    article_payload: dict[str, Any] = {"source": example_source_dict}
    article_payload.update(
        {
            "author": "fake author",
            "title": "fake title",
            "description": "fake description",
            "url": "fake url",
            "urlToImage": "fake url to image",
            "publishedAt": "fake published at",
            "content": "fake content",
        }
    )
    return article_payload
19 changes: 18 additions & 1 deletion tests/test_newsapi_models/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from typing import Any

from mcp_goodnews.newsapi import NewsAPIResponse
from mcp_goodnews.newsapi import Article, ArticleSource, NewsAPIResponse


def test_newsapiresponse_from_json(response_json: dict[str, Any]) -> None:
    """A stored response payload validates into a `NewsAPIResponse`."""
    parsed = NewsAPIResponse.model_validate(response_json)

    assert parsed.status == "ok"
    assert parsed.total_results == 10
    # every article in the payload is expected to come from the same source
    for article in parsed.articles:
        assert article.source.id_ == "bbcnews"


def test_article_source_serialization(
    example_source_dict: dict[str, Any]
) -> None:
    """Validating then dumping by alias reproduces the source payload."""
    round_tripped = ArticleSource.model_validate(
        example_source_dict
    ).model_dump(by_alias=True)

    assert round_tripped == example_source_dict


def test_article_serialization(example_article_dict: dict[str, Any]) -> None:
    """Validating then dumping by alias reproduces the article payload."""
    round_tripped = Article.model_validate(example_article_dict).model_dump(
        by_alias=True
    )

    assert round_tripped == example_article_dict