Skip to content

Commit d852209

Browse files
authored
Added cli to convert legacy fine tuned model to v2. (#1241)
1 parent 71c908c commit d852209

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
AQUA_TROUBLESHOOTING_LINK = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/troubleshooting-tips.md"
4646
MODEL_FILE_DESCRIPTION_VERSION = "1.0"
4747
MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"
48+
# Value written to the 'fine_tune_model_version' tag to mark a fine-tuned model as v2.
AQUA_FINE_TUNE_MODEL_VERSION = "v2"
4849

4950
TRAINING_METRICS_FINAL = "training_metrics_final"
5051
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/model/model.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from ads.aqua.config.container_config import AquaContainerConfig
4545
from ads.aqua.constants import (
46+
AQUA_FINE_TUNE_MODEL_VERSION,
4647
AQUA_MODEL_ARTIFACT_CONFIG,
4748
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
4849
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
@@ -645,6 +646,89 @@ def edit_registered_model(
645646
else:
646647
raise AquaRuntimeError("Only registered unverified models can be edited.")
647648

649+
def convert_fine_tune(
    self, model_id: str, delete_model: Optional[bool] = False
) -> DataScienceModel:
    """Converts legacy fine tuned model to fine tuned model v2.

    1. 'fine_tune_model_version' tag will be added as 'v2' to new fine tuned model.
    2. 'model_file_description' json will only contain fine tuned artifacts for new fine tuned model.

    The legacy model is only deleted (when requested) after the new model
    has been created, so a failed conversion never removes the source model.

    Parameters
    ----------
    model_id: str
        The legacy fine tuned model OCID.
    delete_model: bool
        Flag whether to delete the legacy model or not. Defaults to False.

    Returns
    -------
    DataScienceModel:
        The instance of DataScienceModel.

    Raises
    ------
    AquaValueError
        If the model is not tagged as an AQUA fine-tuned model, is already
        tagged with the v2 version, has no 'model_file_description', or its
        'model_file_description' contains no entries under 'models'.
    """
    legacy_fine_tuned_model = DataScienceModel.from_id(model_id)
    legacy_tags = legacy_fine_tuned_model.freeform_tags or {}

    if (
        Tags.AQUA_TAG not in legacy_tags
        or Tags.AQUA_FINE_TUNED_MODEL_TAG not in legacy_tags
    ):
        raise AquaValueError(
            f"Model '{model_id}' is not eligible for conversion. Only legacy AQUA fine-tuned models "
            f"without the 'fine_tune_model_version={AQUA_FINE_TUNE_MODEL_VERSION}' tag are supported."
        )

    # A tag that is present but empty/None is treated the same as a missing
    # tag, so '.lower()' is never called on a non-string value.
    current_version = legacy_tags.get(Tags.AQUA_FINE_TUNE_MODEL_VERSION) or UNKNOWN
    if str(current_version).lower() == AQUA_FINE_TUNE_MODEL_VERSION:
        raise AquaValueError(
            f"Model '{model_id}' is already a fine-tuned model in version '{AQUA_FINE_TUNE_MODEL_VERSION}'. "
            "No conversion is necessary."
        )

    # Read the description once so the object that gets trimmed below is
    # exactly the object handed to the new model.
    model_file_description = legacy_fine_tuned_model.model_file_description
    if not model_file_description:
        raise AquaValueError(
            f"Model '{model_id}' is missing required metadata and cannot be converted. "
            "This may indicate the model was not created properly or is not a supported legacy AQUA fine-tuned model."
        )

    # Fail with a descriptive error instead of a raw KeyError/IndexError
    # when the description carries no artifact entries to trim.
    model_entries = model_file_description.get("models")
    if not isinstance(model_entries, list) or not model_entries:
        raise AquaValueError(
            f"Model '{model_id}' has no artifact entries in its model file description and cannot be converted. "
            "This may indicate the model was not created properly or is not a supported legacy AQUA fine-tuned model."
        )

    # add 'fine_tune_model_version' tag as 'v2'
    fine_tune_model_v2_tags = {
        **legacy_tags,
        Tags.AQUA_FINE_TUNE_MODEL_VERSION: AQUA_FINE_TUNE_MODEL_VERSION,
    }

    # remove base model artifacts in 'model_file_description' json file
    # base model artifacts are placed as the first entry in 'models' list
    # NOTE: this trims the legacy model's in-memory description in place.
    model_entries.pop(0)

    fine_tune_model_v2 = (
        DataScienceModel()
        .with_compartment_id(legacy_fine_tuned_model.compartment_id)
        .with_project_id(legacy_fine_tuned_model.project_id)
        .with_model_file_description(json_dict=model_file_description)
        .with_display_name(legacy_fine_tuned_model.display_name)
        .with_description(legacy_fine_tuned_model.description)
        .with_freeform_tags(**fine_tune_model_v2_tags)
        .with_defined_tags(**(legacy_fine_tuned_model.defined_tags or {}))
        .with_custom_metadata_list(legacy_fine_tuned_model.custom_metadata_list)
        .with_defined_metadata_list(legacy_fine_tuned_model.defined_metadata_list)
        .with_provenance_metadata(legacy_fine_tuned_model.provenance_metadata)
        .create(model_by_reference=True)
    )

    logger.info(
        f"Successfully created version '{AQUA_FINE_TUNE_MODEL_VERSION}' fine-tuned model: '{fine_tune_model_v2.id}' "
        f"based on legacy model '{model_id}'. This new model is now ready for deployment."
    )

    # Delete the legacy model only after the replacement exists.
    if delete_model:
        legacy_fine_tuned_model.delete()

    return fine_tune_model_v2
731+
648732
def _fetch_metric_from_metadata(
649733
self,
650734
custom_metadata_list: ModelCustomMetadata,

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,125 @@ def test_get_model_fine_tuned(
825825
"evaluation_container": "odsc-llm-evaluate",
826826
}
827827

828+
@patch.object(DataScienceModel, "create")
@patch.object(DataScienceModel, "from_id")
def test_convert_fine_tune(self, mock_from_id, mock_create):
    """Checks eligibility validation and base-artifact removal of convert_fine_tune."""

    def build_entry(namespace, bucket_name, prefix, object_name, version, size):
        # One 'models' entry of a model file description.
        return {
            "namespace": namespace,
            "bucketName": bucket_name,
            "prefix": prefix,
            "objects": [
                {
                    "name": object_name,
                    "version": version,
                    "sizeInBytes": size,
                }
            ],
        }

    # Custom metadata carried by the legacy fine-tuned model.
    custom_metadata = ModelCustomMetadata()
    for metadata_key, metadata_value in (
        ("artifact_location", "oci://bucket@namespace/prefix/"),
        ("fine_tune_source", "test_fine_tuned_source_id"),
        ("fine_tune_source_name", "test_fine_tuned_source_name"),
        ("deployment-container", "odsc-vllm-serving"),
        ("evaluation-container", "odsc-llm-evaluate"),
        ("finetune-container", "odsc-llm-fine-tuning"),
    ):
        custom_metadata.add(key=metadata_key, value=metadata_value)

    taxonomy_metadata = ModelTaxonomyMetadata()
    taxonomy_metadata["Hyperparameters"].value = {
        "training_data": "test_training_data",
        "val_set_size": "test_val_set_size",
    }

    legacy_model = MagicMock()
    legacy_model.id = "test_id"
    legacy_model.compartment_id = "test_model_compartment_id"
    legacy_model.project_id = "test_project_id"
    legacy_model.display_name = "test_display_name"
    legacy_model.description = "test_description"
    legacy_model.model_version_set_id = "test_model_version_set_id"
    legacy_model.model_version_set_name = "test_model_version_set_name"
    legacy_model.time_created = "2024-01-19T17:57:39.158000+00:00"
    legacy_model.lifecycle_state = "ACTIVE"
    # The 'OCI_AQUA' tag is left out on purpose for the first call.
    legacy_model.freeform_tags = {
        "license": "test_license",
        "organization": "test_organization",
        "task": "test_task",
        "aqua_fine_tuned_model": "test_finetuned_model",
    }
    legacy_model.custom_metadata_list = custom_metadata
    legacy_model.defined_metadata_list = taxonomy_metadata
    legacy_model.provenance_metadata = ModelProvenanceMetadata(
        training_id="test_training_job_run_id"
    )
    # First entry is the base model, second entry is the fine-tuned artifacts.
    legacy_model.model_file_description = {
        "version": "1.0",
        "type": "modelOSSReferenceDescription",
        "models": [
            build_entry(
                "test_namespace_one",
                "test_bucket_name_one",
                "test_prefix_one",
                "artifact/.gitattributes",
                "123",
                1519,
            ),
            build_entry(
                "test_namespace_two",
                "test_bucket_name_two",
                "test_prefix_two",
                "/README.md",
                "b52c2608-009f-4774-8325-60ec226ae003",
                5189,
            ),
        ],
    }

    mock_from_id.return_value = legacy_model

    # missing 'OCI_AQUA' tag
    with pytest.raises(
        AquaValueError,
        match="Model 'mock_model_id' is not eligible for conversion. Only legacy AQUA fine-tuned models without the 'fine_tune_model_version=v2' tag are supported.",
    ):
        self.app.convert_fine_tune(model_id="mock_model_id")

    # add 'OCI_AQUA' tag
    legacy_model.freeform_tags["OCI_AQUA"] = "ACTIVE"

    self.app.convert_fine_tune(model_id="mock_model_id")

    mock_create.assert_called_with(model_by_reference=True)

    # Only the fine-tuned entry must remain after conversion.
    remaining_entries = legacy_model.model_file_description["models"]
    assert remaining_entries == [
        build_entry(
            "test_namespace_two",
            "test_bucket_name_two",
            "test_prefix_two",
            "/README.md",
            "b52c2608-009f-4774-8325-60ec226ae003",
            5189,
        )
    ]
946+
828947
@pytest.mark.parametrize(
829948
("artifact_location_set", "download_from_hf", "cleanup_model_cache"),
830949
[

0 commit comments

Comments
 (0)