
Conversation

@jwm4 (Contributor) commented Oct 3, 2025

What does this PR do?

  • The watsonx.ai provider now uses the LiteLLM mixin instead of IBM's library, which does not seem to be working (see #3165, "watsonx.ai inference provider is not working for me", for context).
  • The watsonx.ai provider now lists all available models by calling the watsonx.ai server instead of relying on a hard-coded list of known models (that list gets out of date quickly).
  • An edge case in llama_stack/core/routers/inference.py is addressed that was causing my manual tests to fail.
  • Fixes b64_encode_openai_embeddings_response, which was enumerating over a dictionary and then referencing its elements with .field instead of ["field"]. That method is called by the LiteLLM mixin for embedding models, so the fix is needed to get the watsonx.ai embedding models working (a sketch of the pattern appears after this list).
  • A unit test along the lines of the one in #3348 (chore: update the groq inference impl to use openai-python for openai-compat functions) is added. A more comprehensive plan for automatically testing end-to-end functionality of inference providers would be a good idea, but is out of scope for this PR.
  • Updates to the watsonx distribution. Some are in response to the switch to LiteLLM (e.g., updating the Python packages needed). Others fix things that were already broken and that I found along the way (e.g., a reference to a watsonx-specific doc template that doesn't seem to exist).
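
For illustration, here is a minimal sketch of the dictionary-access fix described in the b64_encode_openai_embeddings_response bullet above. This is not the actual llama-stack code; the field names and response shape are assumptions.

import base64
import struct


def b64_encode_openai_embeddings_response(embeddings: list[dict]) -> list[dict]:
    """Sketch only: index each embedding dict with ["field"] rather than
    attribute access (.field), which raises AttributeError on plain dicts."""
    encoded = []
    for item in embeddings:
        floats = item["embedding"]  # was effectively item.embedding
        packed = struct.pack(f"{len(floats)}f", *floats)
        encoded.append(
            {
                "object": "embedding",
                "embedding": base64.b64encode(packed).decode("utf-8"),
                "index": item["index"],  # likewise ["index"] instead of .index
            }
        )
    return encoded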

Closes #3165

It is also related to a line item in #3387 but doesn't really address that goal (because it uses the LiteLLM mixin, not the OpenAI one). I tried the OpenAI mixin and it doesn't work with watsonx.ai, presumably because the watsonx.ai service is not OpenAI compatible. It works with LiteLLM because LiteLLM has a provider implementation for watsonx.ai.

Test Plan

The test script below alternates between the OpenAI and watsonx providers. The idea is that the OpenAI provider shows how each call should work, and the watsonx provider output shows that the same call also works with watsonx. Note that the result from the MCP test is not as good (the Llama 3.3 70B model does not choose tools as wisely as gpt-4o), but it still works and provides a valid response. For more details on the setup and the MCP server used for testing, see the AI Alliance sample notebook that these examples are drawn from.

#!/usr/bin/env python3

import json
from llama_stack_client import LlamaStackClient
import http.client


def print_response(response):
    """Print response in a nicely formatted way"""
    print(f"ID: {response.id}")
    print(f"Status: {response.status}")
    print(f"Model: {response.model}")
    print(f"Created at: {response.created_at}")
    print(f"Output items: {len(response.output)}")
    
    for i, output_item in enumerate(response.output):
        if len(response.output) > 1:
            print(f"\n--- Output Item {i+1} ---")
        print(f"Output type: {output_item.type}")
        
        if output_item.type in ("text", "message"):
            print(f"Response content: {output_item.content[0].text}")
        elif output_item.type == "file_search_call":
            print(f"  Tool Call ID: {output_item.id}")
            print(f"  Tool Status: {output_item.status}")
            # 'queries' is a list, so we join it for clean printing
            print(f"  Queries: {', '.join(output_item.queries)}")
            # Display results if they exist, otherwise note they are empty
            print(f"  Results: {output_item.results if output_item.results else 'None'}")
        elif output_item.type == "mcp_list_tools":
            print_mcp_list_tools(output_item)
        elif output_item.type == "mcp_call":
            print_mcp_call(output_item)
        else:
            print(f"Response content: {output_item.content}")


def print_mcp_call(mcp_call):
    """Print MCP call in a nicely formatted way"""
    print(f"\n🛠️  MCP Tool Call: {mcp_call.name}")
    print(f"   Server: {mcp_call.server_label}")
    print(f"   ID: {mcp_call.id}")
    print(f"   Arguments: {mcp_call.arguments}")
    
    if mcp_call.error:
        print("Error: {mcp_call.error}")
    elif mcp_call.output:
        print("Output:")
        # Try to format JSON output nicely
        try:
            parsed_output = json.loads(mcp_call.output)
            print(json.dumps(parsed_output, indent=4))
        except (json.JSONDecodeError, TypeError):
            # If not valid JSON, print as-is
            print(f"   {mcp_call.output}")
    else:
        print("   ⏳ No output yet")


def print_mcp_list_tools(mcp_list_tools):
    """Print MCP list tools in a nicely formatted way"""
    print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}")
    print(f"   ID: {mcp_list_tools.id}")
    print(f"   Available Tools: {len(mcp_list_tools.tools)}")
    print("=" * 80)
    
    for i, tool in enumerate(mcp_list_tools.tools, 1):
        print(f"\n{i}. {tool.name}")
        print(f"   Description: {tool.description}")
        
        # Parse and display input schema
        schema = tool.input_schema
        if schema and 'properties' in schema:
            properties = schema['properties']
            required = schema.get('required', [])
            
            print("   Parameters:")
            for param_name, param_info in properties.items():
                param_type = param_info.get('type', 'unknown')
                param_desc = param_info.get('description', 'No description')
                required_marker = " (required)" if param_name in required else " (optional)"
                print(f"     • {param_name} ({param_type}){required_marker}")
                if param_desc:
                    print(f"       {param_desc}")
        
        if i < len(mcp_list_tools.tools):
            print("-" * 40)


def main():
    """Main function to run all the tests"""
    
    # Configuration
    LLAMA_STACK_URL = "http://localhost:8321/"
    LLAMA_STACK_MODEL_IDS = [
        "openai/gpt-3.5-turbo",
        "openai/gpt-4o",
        "llama-openai-compat/Llama-3.3-70B-Instruct",
        "watsonx/meta-llama/llama-3-3-70b-instruct"
    ]
    
    # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml.
    OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1]
    WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1]
    NPS_MCP_URL = "http://localhost:3005/sse/"
    
    print("=== Llama Stack Testing Script ===")
    print(f"Using OpenAI model: {OPENAI_MODEL_ID}")
    print(f"Using WatsonX model: {WATSONX_MODEL_ID}")
    print(f"MCP URL: {NPS_MCP_URL}")
    print()
    
    # Initialize client
    print("Initializing LlamaStackClient...")
    client = LlamaStackClient(base_url=LLAMA_STACK_URL)
    
    # Test 1: List models
    print("\n=== Test 1: List Models ===")
    try:
        models = client.models.list()
        print(f"Found {len(models)} models")
    except Exception as e:
        print(f"Error listing models: {e}")
        raise e
    
    # Test 2: Basic chat completion with OpenAI
    print("\n=== Test 2: Basic Chat Completion (OpenAI) ===")
    try:
        chat_completion_response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}]
        )
        
        print("OpenAI Response:")
        print(chat_completion_response.choices[0].message.content)
    except Exception as e:
        print(f"Error with OpenAI chat completion: {e}")
        raise e
    
    # Test 3: Basic chat completion with WatsonX
    print("\n=== Test 3: Basic Chat Completion (WatsonX) ===")
    try:
        chat_completion_response_wxai = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
        )
        
        print("WatsonX Response:")
        print(chat_completion_response_wxai.choices[0].message.content)
    except Exception as e:
        print(f"Error with WatsonX chat completion: {e}")
        raise e
    
    # Test 4: Tool calling with OpenAI
    print("\n=== Test 4: Tool Calling (OpenAI) ===")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    
    messages = [
        {"role": "user", "content": "What's the weather like in Boston, MA?"}
    ]
    
    try:
        print("--- Initial API Call ---")
        response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("OpenAI tool calling response received")
    except Exception as e:
        print(f"Error with OpenAI tool calling: {e}")
        raise e
    
    # Test 5: Tool calling with WatsonX
    print("\n=== Test 5: Tool Calling (WatsonX) ===")
    try:
        wxai_response = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("WatsonX tool calling response received")
    except Exception as e:
        print(f"Error with WatsonX tool calling: {e}")
        raise e
    
    # Test 6: Streaming with WatsonX
    print("\n=== Test 6: Streaming Response (WatsonX) ===")
    try:
        chat_completion_response_wxai_stream = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
            stream=True
        )
        print("Model response: ", end="")
        for chunk in chat_completion_response_wxai_stream:
            # Each 'chunk' is a ChatCompletionChunk object.
            # We want the content from the 'delta' attribute.
            if hasattr(chunk, 'choices') and chunk.choices is not None:
                content = chunk.choices[0].delta.content
                # The first few chunks might have None content, so we check for it.
                if content is not None:
                    print(content, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with streaming: {e}")
        raise e
    
    # Test 7: MCP with OpenAI
    print("\n=== Test 7: MCP Integration (OpenAI) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=OPENAI_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (OpenAI): {e}")
        raise e
    
    # Test 8: MCP with WatsonX
    print("\n=== Test 8: MCP Integration (WatsonX) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="What is the capital of France?"
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (WatsonX): {e}")
        raise e
    
    # Test 9: MCP with Llama 3.3
    print("\n=== Test 9: MCP Integration (Llama 3.3) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (Llama 3.3): {e}")
        raise e
    
    # Test 10: Embeddings
    print("\n=== Test 10: Embeddings ===")
    try:
        conn = http.client.HTTPConnection("localhost:8321")
        payload = json.dumps({
            "model": "watsonx/ibm/granite-embedding-278m-multilingual",
            "input": "Hello, world!",
        })
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        conn.request("POST", "/v1/openai/v1/embeddings", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data.decode("utf-8"))
    except Exception as e:
        print(f"Error with Embeddings: {e}")
        raise e

    print("\n=== Testing Complete ===")


if __name__ == "__main__":
    main()

The meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 3, 2025.
@leseb (Collaborator) left a comment

Nits - please run pre-commits locally, thanks!

@leseb (Collaborator) commented Oct 6, 2025

@github-actions run precommit

github-actions bot commented Oct 6, 2025

⏳ Running pre-commit hooks on PR #3674...

🤖 Applied by @github-actions bot via pre-commit workflow
github-actions bot commented Oct 6, 2025

✅ Pre-commit hooks completed successfully!

🔧 Changes have been committed and pushed to the PR branch.


  # Add metrics to the chunk
- if self.telemetry and chunk.usage:
+ if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
Collaborator:

this looks unrelated

Contributor:

I think there is a bug around this. I have seen a few PRs which introduced the same logic: #3392, #3422

Contributor Author (@jwm4):

FWIW, my tests for watsonx.ai wouldn't succeed without this fix (as noted in the PR description). I am fine with letting some other PR put the fix in, but I think this PR should probably wait until that one is in if that's the plan.

Collaborator:

does watsonx return usage information? we need to fix/adapt in the adapter, not in the core. putting this in the core obscures the issue.

Contributor Author (@jwm4):

This line was already handling the case where there was no usage information by checking chunk.usage. So I would argue this change is just doing what the line was already doing but in a more robust way. FWIW, the watsonx.ai API does return a usage block, e.g.:

	"usage": {
		"completion_tokens": 54,
		"prompt_tokens": 79,
		"total_tokens": 133
	},

Notably, though, it doesn't put the usage information on every chunk; it only puts it on the final chunk in the stream. I can see that when I call the streaming REST API directly. However, I don't see it when I call LiteLLM directly with stream=True. I do see it when I call LiteLLM without streaming, FWIW. So I think LiteLLM might be dropping the usage information from the last chunk. Here is how I tested this in a notebook:

import litellm, asyncio
import nest_asyncio
nest_asyncio.apply()

async def get_litellm_response():
    return await litellm.acompletion(
        model="watsonx/meta-llama/llama-3-3-70b-instruct",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True
    )

async def print_litellm_response():
    response = await get_litellm_response()
    async for chunk in response:
        print(chunk)

asyncio.run(print_litellm_response())

With all that said, even if LiteLLM were correctly including this on the last chunk, we'd still have the issue that it is missing from all of the other chunks (unless LiteLLM put in an explicit None for this field on every other chunk). So I still think we should adopt this change here and let the line handle both a missing attribute AND an explicit None.
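
For reference, here is a hedged sketch of the defensive pattern being argued for (the function name is made up for illustration; the actual change is the one-line hasattr check shown in the diff above):

def should_emit_usage_metrics(telemetry_enabled: bool, chunk) -> bool:
    # Treat a chunk that omits the `usage` attribute the same as a chunk whose
    # `usage` is explicitly None, since providers/LiteLLM may do either on
    # intermediate streaming chunks.
    return bool(telemetry_enabled and getattr(chunk, "usage", None))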

Signed-off-by: Bill Murdock <[email protected]>
@mattf (Collaborator) left a comment

lgtm

            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        logger.info(f"params to litellm (openai compat): {params}")
Collaborator:

avoid this, it'll not only produce a lot of output but it'll also log user data

Contributor Author (@jwm4):

Good catch, thanks! I put that in there thinking I wasn't pushing this file anyway, and then forgot to take it out when we wound up moving the b64_encode_openai_embeddings_response method into the class. Fixed now.

@mattf (Collaborator) commented Oct 6, 2025

as a bonus, if you can record inference responses for watsonx we'll be able to make additional changes w/o needing a key

Signed-off-by: Bill Murdock <[email protected]>
        )
        return await self._get_openai_client().completions.create(**params)  # type: ignore

    async def check_model_availability(self, model):
        return True
Collaborator:

why is this always true?

Contributor Author (@jwm4):

I thought that was the way to signal that we don't have a static list, but I looked more closely at how the OpenAI mixin handles this issue, and I reimplemented it to mirror what that does.
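
For context, a hypothetical sketch of the dynamic-availability pattern being described; the real implementation mirrors the OpenAI mixin, and the client/attribute names below are assumptions, not the actual adapter code:

async def check_model_availability(self, model: str) -> bool:
    # Sketch only: compare the requested model against the provider's live
    # model list instead of returning True unconditionally.
    available_ids = {m.id async for m in self._get_openai_client().models.list()}
    return model in available_ids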

@mattf self-requested a review on October 6, 2025 at 20:27.
@jwm4 (Contributor Author) commented Oct 6, 2025

Matt writes:

as a bonus, if you can record inference responses for watsonx we'll be able to make additional changes w/o needing a key

How do I do that?

jwm4 added 2 commits on October 6, 2025 at 16:43:
Signed-off-by: Bill Murdock <[email protected]>
Signed-off-by: Bill Murdock <[email protected]>
Labels: CLA Signed (managed by the Meta Open Source bot)
Projects: None yet
Development: Successfully merging this pull request may close these issues: watsonx.ai inference provider is not working for me
4 participants