Skip to content

[AQUA] Refactor evaluation service config to remove redundant information. #1105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 5 additions & 181 deletions ads/aqua/config/evaluation/evaluation_service_config.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,23 @@
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from copy import deepcopy
from typing import Any, Dict, List, Optional

from pydantic import Field

from ads.aqua.config.utils.serializer import Serializable


class ModelParamsOverrides(Serializable):
    """Defines overrides for model parameters, including exclusions and additional inclusions.

    Consumed by ``InferenceModelParamsConfig.get_merged_model_params``: keys listed
    in ``exclude`` are popped from the merged parameter dict, and entries in
    ``include`` are merged over it.
    """

    # Parameter names to remove from the merged parameter set.
    exclude: Optional[List[str]] = Field(default_factory=list)
    # Extra key/value pairs added to (or replacing entries in) the merged parameter set.
    include: Optional[Dict[str, Any]] = Field(default_factory=dict)

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class ModelParamsVersion(Serializable):
    """Handles version-specific model parameter overrides.

    One instance per container version string (see ``ModelParamsContainer.versions``).
    """

    # Include/exclude adjustments applied on top of the container defaults
    # for this particular version.
    overrides: Optional[ModelParamsOverrides] = Field(
        default_factory=ModelParamsOverrides
    )

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class ModelParamsContainer(Serializable):
    """Represents a container's model configuration, including tasks, defaults, and versions."""

    # Container name; matched case-insensitively against lookup requests.
    name: Optional[str] = None
    # Container-level default model parameters, layered over the global defaults.
    default: Optional[Dict[str, Any]] = Field(default_factory=dict)
    # Version string -> version-specific overrides for this container.
    versions: Optional[Dict[str, ModelParamsVersion]] = Field(default_factory=dict)

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class InferenceParams(Serializable):
    """Contains inference-related parameters with defaults.

    Declares no fixed fields: with ``extra = "allow"`` every key/value pair in the
    source payload is accepted and preserved as a model attribute.
    """

    class Config:
        # Accept arbitrary fields — the parameter set is open-ended by design.
        extra = "allow"


class InferenceContainer(Serializable):
    """Represents the inference parameters specific to a container."""

    # Container name; matched case-insensitively against lookup requests.
    name: Optional[str] = None
    # Container-specific parameter overrides merged over the global defaults.
    params: Optional[Dict[str, Any]] = Field(default_factory=dict)

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class ReportParams(Serializable):
    """Handles the report-related parameters."""

    # Default report parameters as an open key/value mapping.
    default: Optional[Dict[str, Any]] = Field(default_factory=dict)

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class InferenceParamsConfig(Serializable):
    """Combines default inference parameters with container-specific configurations."""

    # Baseline inference parameters applied to every container.
    default: Optional[InferenceParams] = Field(default_factory=InferenceParams)
    # Per-container overrides; the first case-insensitive name match wins.
    containers: Optional[List[InferenceContainer]] = Field(default_factory=list)

    def get_merged_params(self, container_name: str) -> InferenceParams:
        """
        Merges default inference params with those specific to the given container.

        Parameters
        ----------
        container_name (str): The name of the container.

        Returns
        -------
        InferenceParams: The merged inference parameters.
        """
        merged_params = self.default.to_dict()
        for container in self.containers:
            # `name` defaults to None; skip unnamed entries instead of
            # raising AttributeError on `.lower()`.
            if container.name and container.name.lower() == container_name.lower():
                merged_params.update(container.params or {})
                break
        return InferenceParams(**merged_params)

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class InferenceModelParamsConfig(Serializable):
    """Encapsulates the model parameters for different containers."""

    # Global default model parameters shared by every container.
    default: Optional[Dict[str, Any]] = Field(default_factory=dict)
    # Per-container model parameter configurations (defaults + versioned overrides).
    containers: Optional[List[ModelParamsContainer]] = Field(default_factory=list)

    def get_merged_model_params(
        self,
        container_name: str,
        version: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Gets the model parameters for a given container, version,
        merged with the defaults.

        Merge order: global defaults < container defaults < version overrides
        (``include`` entries are added, ``exclude`` keys removed).

        Parameters
        ----------
        container_name (str): The name of the container.
        version (Optional[str]): The specific version of the container.

        Returns
        -------
        Dict[str, Any]: The merged model parameters.
        """
        # Deep-copy so callers can mutate the result without corrupting the config.
        params = deepcopy(self.default)

        for container in self.containers:
            # `name` defaults to None; skip unnamed entries instead of
            # raising AttributeError on `.lower()`.
            if not container.name or container.name.lower() != container_name.lower():
                continue

            params.update(container.default or {})

            if version and version in container.versions:
                version_overrides = container.versions[version].overrides
                if version_overrides:
                    if version_overrides.include:
                        params.update(version_overrides.include)
                    if version_overrides.exclude:
                        for key in version_overrides.exclude:
                            params.pop(key, None)
            break

        return params

    class Config:
        # Silently drop unrecognized fields when parsing external config payloads.
        extra = "ignore"


class ShapeFilterConfig(Serializable):
    """Represents the filtering options for a specific shape."""

    # Evaluation container identifiers compatible with this shape.
    evaluation_container: Optional[List[str]] = Field(default_factory=list)
    # Evaluation target identifiers compatible with this shape.
    evaluation_target: Optional[List[str]] = Field(default_factory=list)

    class Config:
        # Keep unknown fields so newer service-config payloads round-trip
        # without data loss. (Source contained two conflicting `extra`
        # assignments — a diff-rendering artifact; the post-change value
        # "allow" is kept.)
        extra = "allow"


class ShapeConfig(Serializable):
Expand All @@ -178,7 +44,7 @@ class MetricConfig(Serializable):
tags: Optional[List[str]] = Field(default_factory=list)

class Config:
extra = "ignore"
extra = "allow"


class ModelParamsConfig(Serializable):
Expand Down Expand Up @@ -223,7 +89,7 @@ def search_shapes(
]

class Config:
extra = "ignore"
extra = "allow"
protected_namespaces = ()


Expand All @@ -235,49 +101,7 @@ class EvaluationServiceConfig(Serializable):

version: Optional[str] = "1.0"
kind: Optional[str] = "evaluation_service_config"
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
inference_params: Optional[InferenceParamsConfig] = Field(
default_factory=InferenceParamsConfig
)
inference_model_params: Optional[InferenceModelParamsConfig] = Field(
default_factory=InferenceModelParamsConfig
)
ui_config: Optional[UIConfig] = Field(default_factory=UIConfig)

def get_merged_inference_params(self, container_name: str) -> InferenceParams:
"""
Merges default inference params with those specific to the given container.

Params
------
container_name (str): The name of the container.

Returns
-------
InferenceParams: The merged inference parameters.
"""
return self.inference_params.get_merged_params(container_name=container_name)

def get_merged_inference_model_params(
self,
container_name: str,
version: Optional[str] = None,
) -> Dict[str, Any]:
"""
Gets the model parameters for a given container, version, and task, merged with the defaults.

Parameters
----------
container_name (str): The name of the container.
version (Optional[str]): The specific version of the container.

Returns
-------
Dict[str, Any]: The merged model parameters.
"""
return self.inference_model_params.get_merged_model_params(
container_name=container_name, version=version
)

class Config:
extra = "ignore"
extra = "allow"

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,111 +1,5 @@
{
"inference_model_params": {
"containers": [
{
"default": {
"add_generation_prompt": false
},
"name": "odsc-vllm-serving",
"versions": {
"0.5.1": {
"overrides": {
"exclude": [
"max_tokens",
"frequency_penalty"
],
"include": {
"some_other_param": "some_other_param_value"
}
}
},
"0.5.3.post1": {
"overrides": {
"exclude": [
"add_generation_prompt"
],
"include": {}
}
}
}
},
{
"default": {
"add_generation_prompt": false
},
"name": "odsc-tgi-serving",
"versions": {
"2.0.1.4": {
"overrides": {
"exclude": [
"max_tokens",
"frequency_penalty"
],
"include": {
"some_other_param": "some_other_param_value"
}
}
}
}
},
{
"default": {
"add_generation_prompt": false
},
"name": "odsc-llama-cpp-serving",
"versions": {
"0.2.78.0": {
"overrides": {
"exclude": [],
"include": {}
}
}
}
}
],
"default": {
"add_generation_prompt": false,
"frequency_penalty": 0.0,
"max_tokens": 500,
"model": "odsc-llm",
"presence_penalty": 0.0,
"some_default_param": "some_default_param",
"stop": [],
"temperature": 0.7,
"top_k": 50,
"top_p": 0.9
}
},
"inference_params": {
"containers": [
{
"name": "odsc-vllm-serving",
"params": {}
},
{
"name": "odsc-tgi-serving",
"params": {}
},
{
"name": "odsc-llama-cpp-serving",
"params": {
"inference_delay": 1,
"inference_max_threads": 1
}
}
],
"default": {
"inference_backoff_factor": 3,
"inference_delay": 0,
"inference_max_threads": 10,
"inference_retries": 3,
"inference_rps": 25,
"inference_timeout": 120
}
},
"kind": "evaluation_service_config",
"report_params": {
"default": {}
},
"ui_config": {
"metrics": [
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
{
"inference_model_params": {
"containers": [],
"default": {}
},
"inference_params": {
"containers": [],
"default": {}
},
"kind": "evaluation_service_config",
"report_params": {
"default": {}
},
"ui_config": {
"metrics": [],
"model_params": {
Expand Down
Loading
Loading