From 2c5802898b812b2629cd35ee78d1f6f06ed91a46 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Thu, 6 Mar 2025 20:19:11 -0500 Subject: [PATCH] [Vertex AI] Parameterize integration tests for Vertex and Dev API --- FirebaseVertexAI/Sources/VertexAI.swift | 44 ++----- .../GenerateContentIntegrationTests.swift | 118 ++++++++++++++++++ .../Tests/Integration/IntegrationTests.swift | 21 ---- .../VertexAITestApp.xcodeproj/project.pbxproj | 4 + .../Tests/Unit/VertexComponentTests.swift | 41 ++++-- 5 files changed, 163 insertions(+), 65 deletions(-) create mode 100644 FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index b3404f4333a..16fc8be0561 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -25,35 +25,18 @@ import Foundation public class VertexAI { // MARK: - Public APIs - /// The default `VertexAI` instance. - /// - /// - Parameter location: The region identifier, defaulting to `us-central1`; see [Vertex AI - /// regions - /// ](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions) - /// for a list of supported regions. - /// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`. - public static func vertexAI(location: String = "us-central1") -> VertexAI { - guard let app = FirebaseApp.app() else { - fatalError("No instance of the default Firebase app was found.") - } - let vertexInstance = vertexAI(app: app, location: location) - assert(vertexInstance.apiConfig.service == .vertexAI) - assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd) - assert(vertexInstance.apiConfig.version == .v1beta) - - return vertexInstance - } - - /// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`. + /// Creates an instance of `VertexAI`. /// /// - Parameters: - /// - app: The custom `FirebaseApp` used for initialization. + /// - app: A custom `FirebaseApp` used for initialization; if not specified, uses the default + /// ``FirebaseApp``. /// - location: The region identifier, defaulting to `us-central1`; see /// [Vertex AI locations] /// (https://firebase.google.com/docs/vertex-ai/locations?platform=ios#available-locations) /// for a list of supported locations. /// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`. - public static func vertexAI(app: FirebaseApp, location: String = "us-central1") -> VertexAI { + public static func vertexAI(app: FirebaseApp? = nil, + location: String = "us-central1") -> VertexAI { let vertexInstance = vertexAI(app: app, location: location, apiConfig: defaultVertexAIAPIConfig) assert(vertexInstance.apiConfig.service == .vertexAI) assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd) @@ -160,25 +143,12 @@ public class VertexAI { let location: String? static let defaultVertexAIAPIConfig = APIConfig(service: .vertexAI, version: .v1beta) - static let defaultDeveloperAPIConfig = APIConfig( - service: .developer(endpoint: .generativeLanguage), - version: .v1beta - ) - static func developerAPI(apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI { - guard let app = FirebaseApp.app() else { + static func vertexAI(app: FirebaseApp?, location: String?, apiConfig: APIConfig) -> VertexAI { + guard let app = app ?? FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } - return developerAPI(app: app, apiConfig: apiConfig) - } - - static func developerAPI(app: FirebaseApp, - apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI { - return vertexAI(app: app, location: nil, apiConfig: apiConfig) - } - - static func vertexAI(app: FirebaseApp, location: String?, apiConfig: APIConfig) -> VertexAI { os_unfair_lock_lock(&instancesLock) // Unlock before the function returns. diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift new file mode 100644 index 00000000000..bef4349fc90 --- /dev/null +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift @@ -0,0 +1,118 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import FirebaseAuth +import FirebaseCore +import FirebaseStorage +import FirebaseVertexAI +import Testing +import VertexAITestApp + +@testable import struct FirebaseVertexAI.APIConfig + +@Suite(.serialized) +struct GenerateContentIntegrationTests { + static let vertexV1Config = APIConfig(service: .vertexAI, version: .v1) + static let vertexV1BetaConfig = APIConfig(service: .vertexAI, version: .v1beta) + static let developerV1BetaConfig = APIConfig( + service: .developer(endpoint: .generativeLanguage), + version: .v1beta + ) + + // Set temperature, topP and topK to lowest allowed values to make responses more deterministic. + static let generationConfig = GenerationConfig( + temperature: 0.0, + topP: 0.0, + topK: 1, + responseMIMEType: "text/plain" + ) + static let systemInstruction = ModelContent( + role: "system", + parts: "You are a friendly and helpful assistant." + ) + static let safetySettings = [ + SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove), + SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove), + SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove), + SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove), + SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove), + ] + // Candidates and total token counts may differ slightly between runs due to whitespace tokens. + let tokenCountAccuracy = 1 + + let storage: Storage + let userID1: String + + init() async throws { + let authResult = try await Auth.auth().signIn( + withEmail: Credentials.emailAddress1, + password: Credentials.emailPassword1 + ) + userID1 = authResult.user.uid + + storage = Storage.storage() + } + + @Test(arguments: [vertexV1Config, vertexV1BetaConfig, developerV1BetaConfig]) + func generateContent(_ apiConfig: APIConfig) async throws { + let model = GenerateContentIntegrationTests.model(apiConfig: apiConfig) + let prompt = "Where is Google headquarters located? Answer with the city name only." + + let response = try await model.generateContent(prompt) + + let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) + #expect(text == "Mountain View") + + let usageMetadata = try #require(response.usageMetadata) + #expect(usageMetadata.promptTokenCount == 21) + #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy)) + #expect(usageMetadata.totalTokenCount.isEqual(to: 24, accuracy: tokenCountAccuracy)) + #expect(usageMetadata.promptTokensDetails.count == 1) + let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first) + #expect(promptTokensDetails.modality == .text) + #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount) + #expect(usageMetadata.candidatesTokensDetails.count == 1) + let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first) + #expect(candidatesTokensDetails.modality == .text) + #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount) + } + + static func model(apiConfig: APIConfig) -> GenerativeModel { + return instance(apiConfig: apiConfig).generativeModel( + modelName: "gemini-2.0-flash", + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: [], + toolConfig: .init(functionCallingConfig: .none()), + systemInstruction: systemInstruction + ) + } + + // TODO(andrewheard): Move this helper to a file in the Utilities folder. + static func instance(apiConfig: APIConfig) -> VertexAI { + switch apiConfig.service { + case .vertexAI: + return VertexAI.vertexAI(app: nil, location: "us-central1", apiConfig: apiConfig) + case .developer: + return VertexAI.vertexAI(app: nil, location: nil, apiConfig: apiConfig) + } + } +} + +// TODO(andrewheard): Move this extension to a file in the Utilities folder. +extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable { + func isEqual(to other: Self, accuracy: Self.Stride) -> Bool { + return distance(to: other).magnitude < accuracy.magnitude + } +} diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift index b5bfc94b93b..4cd60cf3e76 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift @@ -69,27 +69,6 @@ final class IntegrationTests: XCTestCase { // MARK: - Generate Content - func testGenerateContent() async throws { - let prompt = "Where is Google headquarters located? Answer with the city name only." - - let response = try await model.generateContent(prompt) - - let text = try XCTUnwrap(response.text).trimmingCharacters(in: .whitespacesAndNewlines) - XCTAssertEqual(text, "Mountain View") - let usageMetadata = try XCTUnwrap(response.usageMetadata) - XCTAssertEqual(usageMetadata.promptTokenCount, 21) - XCTAssertEqual(usageMetadata.candidatesTokenCount, 3, accuracy: tokenCountAccuracy) - XCTAssertEqual(usageMetadata.totalTokenCount, 24, accuracy: tokenCountAccuracy) - XCTAssertEqual(usageMetadata.promptTokensDetails.count, 1) - let promptTokensDetails = try XCTUnwrap(usageMetadata.promptTokensDetails.first) - XCTAssertEqual(promptTokensDetails.modality, .text) - XCTAssertEqual(promptTokensDetails.tokenCount, usageMetadata.promptTokenCount) - XCTAssertEqual(usageMetadata.candidatesTokensDetails.count, 1) - let candidatesTokensDetails = try XCTUnwrap(usageMetadata.candidatesTokensDetails.first) - XCTAssertEqual(candidatesTokensDetails.modality, .text) - XCTAssertEqual(candidatesTokensDetails.tokenCount, usageMetadata.candidatesTokenCount) - } - func testGenerateContentStream() async throws { let expectedText = """ 1. Mercury diff --git a/FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj b/FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj index b2b2b4f643f..ea5c7a45531 100644 --- a/FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj +++ b/FirebaseVertexAI/Tests/TestApp/VertexAITestApp.xcodeproj/project.pbxproj @@ -22,6 +22,7 @@ 8692F29E2CC9477800539E8F /* FirebaseVertexAI in Frameworks */ = {isa = PBXBuildFile; productRef = 8692F29D2CC9477800539E8F /* FirebaseVertexAI */; }; 8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */; }; 8698D7482CD4332B00ABA833 /* TestAppCheckProviderFactory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */; }; + 86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -49,6 +50,7 @@ 868A7C552CCC271300E449DD /* TestApp.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = TestApp.entitlements; sourceTree = ""; }; 8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FirebaseAppTestUtils.swift; sourceTree = ""; }; 8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestAppCheckProviderFactory.swift; sourceTree = ""; }; + 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerateContentIntegrationTests.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -126,6 +128,7 @@ children = ( 868A7C4D2CCC1F4700E449DD /* Credentials.swift */, 8661386D2CC943DE00F4B78E /* IntegrationTests.swift */, + 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */, 864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */, 862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */, ); @@ -273,6 +276,7 @@ 868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */, 864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */, 862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */, + 86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */, 8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; diff --git a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift index 2d7c1ec567f..857b9e024cf 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift @@ -51,8 +51,22 @@ class VertexComponentTests: XCTestCase { XCTAssertNotNil(NSClassFromString("FIRVertexAIComponent")) } - /// Tests that a vertex instance can be created properly using the default Firebase pp. + /// Tests that a vertex instance can be created properly using the default Firebase app. func testVertexInstanceCreation_defaultApp() throws { + let vertex = VertexAI.vertexAI() + + XCTAssertNotNil(vertex) + XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID) + XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey) + XCTAssertEqual(vertex.location, "us-central1") + XCTAssertEqual(vertex.apiConfig.service, .vertexAI) + XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd) + XCTAssertEqual(vertex.apiConfig.version, .v1beta) + } + + /// Tests that a vertex instance can be created properly using the default Firebase app and custom + /// location. + func testVertexInstanceCreation_defaultApp_customLocation() throws { let vertex = VertexAI.vertexAI(location: location) XCTAssertNotNil(vertex) @@ -121,8 +135,16 @@ class VertexComponentTests: XCTestCase { } func testSameAppAndDifferentAPI_newInstanceCreated() throws { - let vertex1 = VertexAI.vertexAI(app: VertexComponentTests.app) - let vertex2 = VertexAI.developerAPI(app: VertexComponentTests.app) + let vertex1 = VertexAI.vertexAI( + app: VertexComponentTests.app, + location: location, + apiConfig: APIConfig(service: .vertexAI, version: .v1beta) + ) + let vertex2 = VertexAI.vertexAI( + app: VertexComponentTests.app, + location: location, + apiConfig: APIConfig(service: .vertexAI, version: .v1) + ) // Ensure they are different instances. XCTAssert(vertex1 !== vertex2) @@ -168,7 +190,8 @@ class VertexComponentTests: XCTestCase { func testModelResourceName_developerAPI_generativeLanguage() throws { let app = try XCTUnwrap(VertexComponentTests.app) - let vertex = VertexAI.developerAPI(app: app) + let apiConfig = APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta) + let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig) let model = "test-model-name" let modelResourceName = vertex.modelResourceName(modelName: model) @@ -182,7 +205,7 @@ class VertexComponentTests: XCTestCase { service: .developer(endpoint: .firebaseVertexAIStaging), version: .v1beta ) - let vertex = VertexAI.developerAPI(app: app, apiConfig: apiConfig) + let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig) let model = "test-model-name" let projectID = vertex.firebaseInfo.projectID @@ -208,7 +231,11 @@ class VertexComponentTests: XCTestCase { func testGenerativeModel_developerAPI() async throws { let app = try XCTUnwrap(VertexComponentTests.app) - let vertex = VertexAI.developerAPI(app: app) + let apiConfig = APIConfig( + service: .developer(endpoint: .firebaseVertexAIStaging), + version: .v1beta + ) + let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig) let modelResourceName = vertex.modelResourceName(modelName: modelName) let generativeModel = vertex.generativeModel( @@ -218,6 +245,6 @@ class VertexComponentTests: XCTestCase { XCTAssertEqual(generativeModel.modelResourceName, modelResourceName) XCTAssertEqual(generativeModel.systemInstruction, systemInstruction) - XCTAssertEqual(generativeModel.apiConfig, VertexAI.defaultDeveloperAPIConfig) + XCTAssertEqual(generativeModel.apiConfig, apiConfig) } }