diff --git a/README.md b/README.md index ba05d843..d7537387 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow | `connectionString` | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | | `logPath` | Folder to store logs. | | `disabledTools` | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | When set to true, only allows read and metadata operation types, disabling create/update/delete operations. | +| `readOnly` | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | | `indexCheck` | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | | `telemetry` | When set to disabled, disables telemetry collection. | @@ -318,10 +318,11 @@ Operation types: - `delete` - Tools that delete resources, such as delete document, drop collection, etc. - `read` - Tools that read resources, such as find, aggregate, list clusters, etc. - `metadata` - Tools that read metadata, such as list databases, list collections, collection schema, etc. +- `connect` - Tools that allow you to connect or switch the connection to a MongoDB instance. If this is disabled, you will need to provide a connection string through the config when starting the server. #### Read-Only Mode -The `readOnly` configuration option allows you to restrict the MCP server to only use tools with "read" and "metadata" operation types. When enabled, all tools that have "create", "update" or "delete" operation types will not be registered with the server. +The `readOnly` configuration option allows you to restrict the MCP server to only use tools with "read", "connect", and "metadata" operation types. When enabled, all tools that have "create", "update" or "delete" operation types will not be registered with the server. This is useful for scenarios where you want to provide access to MongoDB data for analysis without allowing any modifications to the data or infrastructure. diff --git a/src/server.ts b/src/server.ts index 31a99ded..c32dc367 100644 --- a/src/server.ts +++ b/src/server.ts @@ -12,6 +12,7 @@ import { type ServerCommand } from "./telemetry/types.js"; import { CallToolRequestSchema, CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import assert from "assert"; import { detectContainerEnv } from "./common/container.js"; +import { ToolBase } from "./tools/tool.js"; export interface ServerOptions { session: Session; @@ -22,9 +23,10 @@ export interface ServerOptions { export class Server { public readonly session: Session; - private readonly mcpServer: McpServer; + public readonly mcpServer: McpServer; private readonly telemetry: Telemetry; public readonly userConfig: UserConfig; + public readonly tools: ToolBase[] = []; private readonly startTime: number; constructor({ session, mcpServer, userConfig, telemetry }: ServerOptions) { @@ -141,8 +143,11 @@ export class Server { } private registerTools() { - for (const tool of [...AtlasTools, ...MongoDbTools]) { - new tool(this.session, this.userConfig, this.telemetry).register(this.mcpServer); + for (const toolConstructor of [...AtlasTools, ...MongoDbTools]) { + const tool = new toolConstructor(this.session, this.userConfig, this.telemetry); + if (tool.register(this)) { + this.tools.push(tool); + } } } diff --git a/src/tools/atlas/atlasTool.ts b/src/tools/atlas/atlasTool.ts index 2b93a5ec..eb7c2f1f 100644 --- a/src/tools/atlas/atlasTool.ts +++ b/src/tools/atlas/atlasTool.ts @@ -6,7 +6,7 @@ import { z } from "zod"; import { ApiClientError } from "../../common/atlas/apiClientError.js"; export abstract class AtlasToolBase extends ToolBase { - protected category: ToolCategory = "atlas"; + public category: ToolCategory = "atlas"; protected verifyAllowed(): boolean { if (!this.config.apiClientId || !this.config.apiClientSecret) { @@ -29,7 +29,7 @@ export abstract class AtlasToolBase extends ToolBase { type: "text", text: `Unable to authenticate with MongoDB Atlas, API error: ${error.message} -Hint: Your API credentials may be invalid, expired or lack permissions. +Hint: Your API credentials may be invalid, expired or lack permissions. Please check your Atlas API credentials and ensure they have the appropriate permissions. For more information on setting up API keys, visit: https://www.mongodb.com/docs/atlas/configure-api-access/`, }, @@ -44,7 +44,7 @@ For more information on setting up API keys, visit: https://www.mongodb.com/docs { type: "text", text: `Received a Forbidden API Error: ${error.message} - + You don't have sufficient permissions to perform this action in MongoDB Atlas Please ensure your API key has the necessary roles assigned. For more information on Atlas API access roles, visit: https://www.mongodb.com/docs/atlas/api/service-accounts-overview/`, diff --git a/src/tools/atlas/metadata/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts similarity index 98% rename from src/tools/atlas/metadata/connectCluster.ts rename to src/tools/atlas/connect/connectCluster.ts index a65913a6..31113e82 100644 --- a/src/tools/atlas/metadata/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -13,9 +13,9 @@ function sleep(ms: number): Promise { } export class ConnectClusterTool extends AtlasToolBase { - protected name = "atlas-connect-cluster"; + public name = "atlas-connect-cluster"; protected description = "Connect to MongoDB Atlas cluster"; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "connect"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), clusterName: z.string().describe("Atlas cluster name"), diff --git a/src/tools/atlas/create/createAccessList.ts b/src/tools/atlas/create/createAccessList.ts index 1c38279a..4941b1e8 100644 --- a/src/tools/atlas/create/createAccessList.ts +++ b/src/tools/atlas/create/createAccessList.ts @@ -6,9 +6,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; const DEFAULT_COMMENT = "Added by Atlas MCP"; export class CreateAccessListTool extends AtlasToolBase { - protected name = "atlas-create-access-list"; + public name = "atlas-create-access-list"; protected description = "Allow Ip/CIDR ranges to access your MongoDB Atlas clusters."; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), ipAddresses: z diff --git a/src/tools/atlas/create/createDBUser.ts b/src/tools/atlas/create/createDBUser.ts index a8266a0a..fef9d513 100644 --- a/src/tools/atlas/create/createDBUser.ts +++ b/src/tools/atlas/create/createDBUser.ts @@ -6,9 +6,9 @@ import { CloudDatabaseUser, DatabaseUserRole } from "../../../common/atlas/opena import { generateSecurePassword } from "../../../common/atlas/generatePassword.js"; export class CreateDBUserTool extends AtlasToolBase { - protected name = "atlas-create-db-user"; + public name = "atlas-create-db-user"; protected description = "Create an MongoDB Atlas database user"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), username: z.string().describe("Username for the new user"), diff --git a/src/tools/atlas/create/createFreeCluster.ts b/src/tools/atlas/create/createFreeCluster.ts index 2d93ae80..ed04409b 100644 --- a/src/tools/atlas/create/createFreeCluster.ts +++ b/src/tools/atlas/create/createFreeCluster.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { ClusterDescription20240805 } from "../../../common/atlas/openapi.js"; export class CreateFreeClusterTool extends AtlasToolBase { - protected name = "atlas-create-free-cluster"; + public name = "atlas-create-free-cluster"; protected description = "Create a free MongoDB Atlas cluster"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID to create the cluster in"), name: z.string().describe("Name of the cluster"), diff --git a/src/tools/atlas/create/createProject.ts b/src/tools/atlas/create/createProject.ts index cdf71b9c..29bff3f6 100644 --- a/src/tools/atlas/create/createProject.ts +++ b/src/tools/atlas/create/createProject.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { Group } from "../../../common/atlas/openapi.js"; export class CreateProjectTool extends AtlasToolBase { - protected name = "atlas-create-project"; + public name = "atlas-create-project"; protected description = "Create a MongoDB Atlas project"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectName: z.string().optional().describe("Name for the new project"), organizationId: z.string().optional().describe("Organization ID for the new project"), diff --git a/src/tools/atlas/read/inspectAccessList.ts b/src/tools/atlas/read/inspectAccessList.ts index 94c85228..13e027c9 100644 --- a/src/tools/atlas/read/inspectAccessList.ts +++ b/src/tools/atlas/read/inspectAccessList.ts @@ -4,9 +4,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class InspectAccessListTool extends AtlasToolBase { - protected name = "atlas-inspect-access-list"; + public name = "atlas-inspect-access-list"; protected description = "Inspect Ip/CIDR ranges with access to your MongoDB Atlas clusters."; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), }; diff --git a/src/tools/atlas/read/inspectCluster.ts b/src/tools/atlas/read/inspectCluster.ts index c73c1b76..a4209fd5 100644 --- a/src/tools/atlas/read/inspectCluster.ts +++ b/src/tools/atlas/read/inspectCluster.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { Cluster, inspectCluster } from "../../../common/atlas/cluster.js"; export class InspectClusterTool extends AtlasToolBase { - protected name = "atlas-inspect-cluster"; + public name = "atlas-inspect-cluster"; protected description = "Inspect MongoDB Atlas cluster"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), clusterName: z.string().describe("Atlas cluster name"), diff --git a/src/tools/atlas/read/listAlerts.ts b/src/tools/atlas/read/listAlerts.ts index bbbf6f14..dcf56a63 100644 --- a/src/tools/atlas/read/listAlerts.ts +++ b/src/tools/atlas/read/listAlerts.ts @@ -4,9 +4,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class ListAlertsTool extends AtlasToolBase { - protected name = "atlas-list-alerts"; + public name = "atlas-list-alerts"; protected description = "List MongoDB Atlas alerts"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to list alerts for"), }; diff --git a/src/tools/atlas/read/listClusters.ts b/src/tools/atlas/read/listClusters.ts index a8af8828..99c26fe6 100644 --- a/src/tools/atlas/read/listClusters.ts +++ b/src/tools/atlas/read/listClusters.ts @@ -11,9 +11,9 @@ import { import { formatCluster, formatFlexCluster } from "../../../common/atlas/cluster.js"; export class ListClustersTool extends AtlasToolBase { - protected name = "atlas-list-clusters"; + public name = "atlas-list-clusters"; protected description = "List MongoDB Atlas clusters"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to filter clusters").optional(), }; diff --git a/src/tools/atlas/read/listDBUsers.ts b/src/tools/atlas/read/listDBUsers.ts index 7650cbf0..57344d65 100644 --- a/src/tools/atlas/read/listDBUsers.ts +++ b/src/tools/atlas/read/listDBUsers.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { DatabaseUserRole, UserScope } from "../../../common/atlas/openapi.js"; export class ListDBUsersTool extends AtlasToolBase { - protected name = "atlas-list-db-users"; + public name = "atlas-list-db-users"; protected description = "List MongoDB Atlas database users"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to filter DB users"), }; diff --git a/src/tools/atlas/read/listOrgs.ts b/src/tools/atlas/read/listOrgs.ts index c55738d7..66b4c968 100644 --- a/src/tools/atlas/read/listOrgs.ts +++ b/src/tools/atlas/read/listOrgs.ts @@ -3,9 +3,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { OperationType } from "../../tool.js"; export class ListOrganizationsTool extends AtlasToolBase { - protected name = "atlas-list-orgs"; + public name = "atlas-list-orgs"; protected description = "List MongoDB Atlas organizations"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = {}; protected async execute(): Promise { diff --git a/src/tools/atlas/read/listProjects.ts b/src/tools/atlas/read/listProjects.ts index 1a9ab523..e8fc0249 100644 --- a/src/tools/atlas/read/listProjects.ts +++ b/src/tools/atlas/read/listProjects.ts @@ -5,9 +5,9 @@ import { z } from "zod"; import { ToolArgs } from "../../tool.js"; export class ListProjectsTool extends AtlasToolBase { - protected name = "atlas-list-projects"; + public name = "atlas-list-projects"; protected description = "List MongoDB Atlas projects"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { orgId: z.string().describe("Atlas organization ID to filter projects").optional(), }; diff --git a/src/tools/atlas/tools.ts b/src/tools/atlas/tools.ts index 9c27740d..c43b88ef 100644 --- a/src/tools/atlas/tools.ts +++ b/src/tools/atlas/tools.ts @@ -8,7 +8,7 @@ import { ListDBUsersTool } from "./read/listDBUsers.js"; import { CreateDBUserTool } from "./create/createDBUser.js"; import { CreateProjectTool } from "./create/createProject.js"; import { ListOrganizationsTool } from "./read/listOrgs.js"; -import { ConnectClusterTool } from "./metadata/connectCluster.js"; +import { ConnectClusterTool } from "./connect/connectCluster.js"; import { ListAlertsTool } from "./read/listAlerts.js"; export const AtlasTools = [ diff --git a/src/tools/mongodb/metadata/connect.ts b/src/tools/mongodb/connect/connect.ts similarity index 90% rename from src/tools/mongodb/metadata/connect.ts rename to src/tools/mongodb/connect/connect.ts index 57822001..e8de9333 100644 --- a/src/tools/mongodb/metadata/connect.ts +++ b/src/tools/mongodb/connect/connect.ts @@ -2,11 +2,11 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import assert from "assert"; import { UserConfig } from "../../../config.js"; import { Telemetry } from "../../../telemetry/telemetry.js"; import { Session } from "../../../session.js"; +import { Server } from "../../../server.js"; const disconnectedSchema = z .object({ @@ -33,7 +33,7 @@ const connectedDescription = const disconnectedDescription = "Connect to a MongoDB instance"; export class ConnectTool extends MongoDBToolBase { - protected name: typeof connectedName | typeof disconnectedName = disconnectedName; + public name: typeof connectedName | typeof disconnectedName = disconnectedName; protected description: typeof connectedDescription | typeof disconnectedDescription = disconnectedDescription; // Here the default is empty just to trigger registration, but we're going to override it with the correct @@ -42,7 +42,7 @@ export class ConnectTool extends MongoDBToolBase { connectionString: z.string().optional(), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "connect"; constructor(session: Session, config: UserConfig, telemetry: Telemetry) { super(session, config, telemetry); @@ -72,10 +72,13 @@ export class ConnectTool extends MongoDBToolBase { }; } - public register(server: McpServer): void { - super.register(server); + public register(server: Server): boolean { + if (super.register(server)) { + this.updateMetadata(); + return true; + } - this.updateMetadata(); + return false; } private updateMetadata(): void { diff --git a/src/tools/mongodb/create/createCollection.ts b/src/tools/mongodb/create/createCollection.ts index 27eaa9f5..0b1c65a7 100644 --- a/src/tools/mongodb/create/createCollection.ts +++ b/src/tools/mongodb/create/createCollection.ts @@ -3,12 +3,12 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { OperationType, ToolArgs } from "../../tool.js"; export class CreateCollectionTool extends MongoDBToolBase { - protected name = "create-collection"; + public name = "create-collection"; protected description = "Creates a new collection in a database. If the database doesn't exist, it will be created automatically."; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ collection, database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/create/createIndex.ts b/src/tools/mongodb/create/createIndex.ts index beffaf86..8e393f04 100644 --- a/src/tools/mongodb/create/createIndex.ts +++ b/src/tools/mongodb/create/createIndex.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { IndexDirection } from "mongodb"; export class CreateIndexTool extends MongoDBToolBase { - protected name = "create-index"; + public name = "create-index"; protected description = "Create an index for a collection"; protected argsShape = { ...DbOperationArgs, @@ -13,7 +13,7 @@ export class CreateIndexTool extends MongoDBToolBase { name: z.string().optional().describe("The name of the index"), }; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ database, diff --git a/src/tools/mongodb/create/insertMany.ts b/src/tools/mongodb/create/insertMany.ts index f28d79d5..4744e344 100644 --- a/src/tools/mongodb/create/insertMany.ts +++ b/src/tools/mongodb/create/insertMany.ts @@ -4,7 +4,7 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class InsertManyTool extends MongoDBToolBase { - protected name = "insert-many"; + public name = "insert-many"; protected description = "Insert an array of documents into a MongoDB collection"; protected argsShape = { ...DbOperationArgs, @@ -14,7 +14,7 @@ export class InsertManyTool extends MongoDBToolBase { "The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()" ), }; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ database, diff --git a/src/tools/mongodb/delete/deleteMany.ts b/src/tools/mongodb/delete/deleteMany.ts index 0257d167..aa135512 100644 --- a/src/tools/mongodb/delete/deleteMany.ts +++ b/src/tools/mongodb/delete/deleteMany.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; export class DeleteManyTool extends MongoDBToolBase { - protected name = "delete-many"; + public name = "delete-many"; protected description = "Removes all documents that match the filter from a MongoDB collection"; protected argsShape = { ...DbOperationArgs, @@ -16,7 +16,7 @@ export class DeleteManyTool extends MongoDBToolBase { "The query filter, specifying the deletion criteria. Matches the syntax of the filter argument of db.collection.deleteMany()" ), }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database, diff --git a/src/tools/mongodb/delete/dropCollection.ts b/src/tools/mongodb/delete/dropCollection.ts index ac914f75..f555df04 100644 --- a/src/tools/mongodb/delete/dropCollection.ts +++ b/src/tools/mongodb/delete/dropCollection.ts @@ -3,13 +3,13 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class DropCollectionTool extends MongoDBToolBase { - protected name = "drop-collection"; + public name = "drop-collection"; protected description = "Removes a collection or view from the database. The method also removes any indexes associated with the dropped collection."; protected argsShape = { ...DbOperationArgs, }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/delete/dropDatabase.ts b/src/tools/mongodb/delete/dropDatabase.ts index b10862b2..01967265 100644 --- a/src/tools/mongodb/delete/dropDatabase.ts +++ b/src/tools/mongodb/delete/dropDatabase.ts @@ -3,12 +3,12 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class DropDatabaseTool extends MongoDBToolBase { - protected name = "drop-database"; + public name = "drop-database"; protected description = "Removes the specified database, deleting the associated data files"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/collectionSchema.ts b/src/tools/mongodb/metadata/collectionSchema.ts index f0145323..693b8f91 100644 --- a/src/tools/mongodb/metadata/collectionSchema.ts +++ b/src/tools/mongodb/metadata/collectionSchema.ts @@ -4,11 +4,11 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { getSimplifiedSchema } from "mongodb-schema"; export class CollectionSchemaTool extends MongoDBToolBase { - protected name = "collection-schema"; + public name = "collection-schema"; protected description = "Describe the schema for a collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/collectionStorageSize.ts b/src/tools/mongodb/metadata/collectionStorageSize.ts index 127e7172..7a37499a 100644 --- a/src/tools/mongodb/metadata/collectionStorageSize.ts +++ b/src/tools/mongodb/metadata/collectionStorageSize.ts @@ -3,11 +3,11 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class CollectionStorageSizeTool extends MongoDBToolBase { - protected name = "collection-storage-size"; + public name = "collection-storage-size"; protected description = "Gets the size of the collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/dbStats.ts b/src/tools/mongodb/metadata/dbStats.ts index a8c0ea0d..ee819c55 100644 --- a/src/tools/mongodb/metadata/dbStats.ts +++ b/src/tools/mongodb/metadata/dbStats.ts @@ -4,13 +4,13 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { EJSON } from "bson"; export class DbStatsTool extends MongoDBToolBase { - protected name = "db-stats"; + public name = "db-stats"; protected description = "Returns statistics that reflect the use state of a single database"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index 1068a008..a686d9cc 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -8,7 +8,7 @@ import { FindArgs } from "../read/find.js"; import { CountArgs } from "../read/count.js"; export class ExplainTool extends MongoDBToolBase { - protected name = "explain"; + public name = "explain"; protected description = "Returns statistics describing the execution of the winning plan chosen by the query optimizer for the evaluated method"; @@ -34,7 +34,7 @@ export class ExplainTool extends MongoDBToolBase { .describe("The method and its arguments to run"), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; static readonly defaultVerbosity = ExplainVerbosity.queryPlanner; diff --git a/src/tools/mongodb/metadata/listCollections.ts b/src/tools/mongodb/metadata/listCollections.ts index 193d0465..9611d541 100644 --- a/src/tools/mongodb/metadata/listCollections.ts +++ b/src/tools/mongodb/metadata/listCollections.ts @@ -3,13 +3,13 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class ListCollectionsTool extends MongoDBToolBase { - protected name = "list-collections"; + public name = "list-collections"; protected description = "List all collections for a given database"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/listDatabases.ts b/src/tools/mongodb/metadata/listDatabases.ts index fe324f07..400f275b 100644 --- a/src/tools/mongodb/metadata/listDatabases.ts +++ b/src/tools/mongodb/metadata/listDatabases.ts @@ -4,10 +4,10 @@ import * as bson from "bson"; import { OperationType } from "../../tool.js"; export class ListDatabasesTool extends MongoDBToolBase { - protected name = "list-databases"; + public name = "list-databases"; protected description = "List all databases for a MongoDB connection"; protected argsShape = {}; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute(): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/logs.ts b/src/tools/mongodb/metadata/logs.ts index 9056aa59..899738fd 100644 --- a/src/tools/mongodb/metadata/logs.ts +++ b/src/tools/mongodb/metadata/logs.ts @@ -4,7 +4,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { z } from "zod"; export class LogsTool extends MongoDBToolBase { - protected name = "mongodb-logs"; + public name = "mongodb-logs"; protected description = "Returns the most recent logged mongod events"; protected argsShape = { type: z @@ -24,7 +24,7 @@ export class LogsTool extends MongoDBToolBase { .describe("The maximum number of log entries to return."), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ type, limit }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index fe996a38..2e5c68c7 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -4,6 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; import logger, { LogId } from "../../logger.js"; +import { Server } from "../../server.js"; export const DbOperationArgs = { database: z.string().describe("Database name"), @@ -11,7 +12,8 @@ export const DbOperationArgs = { }; export abstract class MongoDBToolBase extends ToolBase { - protected category: ToolCategory = "mongodb"; + private server?: Server; + public category: ToolCategory = "mongodb"; protected async ensureConnected(): Promise { if (!this.session.serviceProvider) { @@ -43,11 +45,28 @@ export abstract class MongoDBToolBase extends ToolBase { return this.session.serviceProvider; } + public register(server: Server): boolean { + this.server = server; + return super.register(server); + } + protected handleError( error: unknown, args: ToolArgs ): Promise | CallToolResult { if (error instanceof MongoDBError) { + const connectTools = this.server?.tools + .filter((t) => t.operationType === "connect") + .sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools + + // Find the first Atlas connect tool if available and suggest to the LLM to use it. + // Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one. + const atlasConnectTool = connectTools?.find((t) => t.category === "atlas"); + const llmConnectHint = atlasConnectTool + ? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.` + : "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string."; + + const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", "); switch (error.code) { case ErrorCodes.NotConnectedToMongoDB: return { @@ -58,7 +77,9 @@ export abstract class MongoDBToolBase extends ToolBase { }, { type: "text", - text: "Please use the 'connect' or 'switch-connection' tool to connect to a MongoDB instance.", + text: connectToolsNames + ? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}` + : "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.", }, ], isError: true, @@ -68,7 +89,13 @@ export abstract class MongoDBToolBase extends ToolBase { content: [ { type: "text", - text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance. Alternatively, use the 'switch-connection' tool to connect to a different instance.", + text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.", + }, + { + type: "text", + text: connectTools + ? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}` + : "Please update the configuration to use a valid connection string and restart the server.", }, ], isError: true, diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index aa21fc5d..f9868dba 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -10,13 +10,13 @@ export const AggregateArgs = { }; export class AggregateTool extends MongoDBToolBase { - protected name = "aggregate"; + public name = "aggregate"; protected description = "Run an aggregation against a MongoDB collection"; protected argsShape = { ...DbOperationArgs, ...AggregateArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, diff --git a/src/tools/mongodb/read/collectionIndexes.ts b/src/tools/mongodb/read/collectionIndexes.ts index cc0a141b..ef3fa75d 100644 --- a/src/tools/mongodb/read/collectionIndexes.ts +++ b/src/tools/mongodb/read/collectionIndexes.ts @@ -3,10 +3,10 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class CollectionIndexesTool extends MongoDBToolBase { - protected name = "collection-indexes"; + public name = "collection-indexes"; protected description = "Describe the indexes for a collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/read/count.ts b/src/tools/mongodb/read/count.ts index 0ed3a192..df3664b5 100644 --- a/src/tools/mongodb/read/count.ts +++ b/src/tools/mongodb/read/count.ts @@ -14,7 +14,7 @@ export const CountArgs = { }; export class CountTool extends MongoDBToolBase { - protected name = "count"; + public name = "count"; protected description = "Gets the number of documents in a MongoDB collection using db.collection.count() and query as an optional filter parameter"; protected argsShape = { @@ -22,7 +22,7 @@ export class CountTool extends MongoDBToolBase { ...CountArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, collection, query }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 97c90e08..02c337ed 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -23,13 +23,13 @@ export const FindArgs = { }; export class FindTool extends MongoDBToolBase { - protected name = "find"; + public name = "find"; protected description = "Run a find query against a MongoDB collection"; protected argsShape = { ...DbOperationArgs, ...FindArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index d64d53ea..c74fdf29 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -1,4 +1,4 @@ -import { ConnectTool } from "./metadata/connect.js"; +import { ConnectTool } from "./connect/connect.js"; import { ListCollectionsTool } from "./metadata/listCollections.js"; import { CollectionIndexesTool } from "./read/collectionIndexes.js"; import { ListDatabasesTool } from "./metadata/listDatabases.js"; diff --git a/src/tools/mongodb/update/renameCollection.ts b/src/tools/mongodb/update/renameCollection.ts index d3b07c15..e5bffbdb 100644 --- a/src/tools/mongodb/update/renameCollection.ts +++ b/src/tools/mongodb/update/renameCollection.ts @@ -4,14 +4,14 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class RenameCollectionTool extends MongoDBToolBase { - protected name = "rename-collection"; + public name = "rename-collection"; protected description = "Renames a collection in a MongoDB database"; protected argsShape = { ...DbOperationArgs, newName: z.string().describe("The new name for the collection"), dropTarget: z.boolean().optional().default(false).describe("If true, drops the target collection if it exists"), }; - protected operationType: OperationType = "update"; + public operationType: OperationType = "update"; protected async execute({ database, diff --git a/src/tools/mongodb/update/updateMany.ts b/src/tools/mongodb/update/updateMany.ts index 7392135b..b31a843e 100644 --- a/src/tools/mongodb/update/updateMany.ts +++ b/src/tools/mongodb/update/updateMany.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; export class UpdateManyTool extends MongoDBToolBase { - protected name = "update-many"; + public name = "update-many"; protected description = "Updates all documents that match the specified filter for a collection"; protected argsShape = { ...DbOperationArgs, @@ -23,7 +23,7 @@ export class UpdateManyTool extends MongoDBToolBase { .optional() .describe("Controls whether to insert a new document if no documents match the filter"), }; - protected operationType: OperationType = "update"; + public operationType: OperationType = "update"; protected async execute({ database, diff --git a/src/tools/tool.ts b/src/tools/tool.ts index b7cce354..551374d6 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -1,15 +1,16 @@ import { z, type ZodRawShape, type ZodNever, AnyZodObject } from "zod"; -import type { McpServer, RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { CallToolResult, ToolAnnotations } from "@modelcontextprotocol/sdk/types.js"; import { Session } from "../session.js"; import logger, { LogId } from "../logger.js"; import { Telemetry } from "../telemetry/telemetry.js"; import { type ToolEvent } from "../telemetry/types.js"; import { UserConfig } from "../config.js"; +import { Server } from "../server.js"; export type ToolArgs = z.objectOutputType; -export type OperationType = "metadata" | "read" | "create" | "delete" | "update"; +export type OperationType = "metadata" | "read" | "create" | "delete" | "update" | "connect"; export type ToolCategory = "mongodb" | "atlas"; export type TelemetryToolMetadata = { projectId?: string; @@ -17,11 +18,11 @@ export type TelemetryToolMetadata = { }; export abstract class ToolBase { - protected abstract name: string; + public abstract name: string; - protected abstract category: ToolCategory; + public abstract category: ToolCategory; - protected abstract operationType: OperationType; + public abstract operationType: OperationType; protected abstract description: string; @@ -36,6 +37,7 @@ export abstract class ToolBase { switch (this.operationType) { case "read": case "metadata": + case "connect": annotations.readOnlyHint = true; annotations.destructiveHint = false; break; @@ -63,9 +65,9 @@ export abstract class ToolBase { protected readonly telemetry: Telemetry ) {} - public register(server: McpServer): void { + public register(server: Server): boolean { if (!this.verifyAllowed()) { - return; + return false; } const callback: ToolCallback = async (...args) => { @@ -84,14 +86,15 @@ export abstract class ToolBase { } }; - server.tool(this.name, this.description, this.argsShape, this.annotations, callback); + server.mcpServer.tool(this.name, this.description, this.argsShape, this.annotations, callback); // This is very similar to RegisteredTool.update, but without the bugs around the name. // In the upstream update method, the name is captured in the closure and not updated when // the tool name changes. This means that you only get one name update before things end up // in a broken state. + // See https://github.com/modelcontextprotocol/typescript-sdk/issues/414 for more details. this.update = (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }) => { - const tools = server["_registeredTools"] as { [toolName: string]: RegisteredTool }; + const tools = server.mcpServer["_registeredTools"] as { [toolName: string]: RegisteredTool }; const existingTool = tools[this.name]; if (!existingTool) { @@ -118,8 +121,10 @@ export abstract class ToolBase { existingTool.inputSchema = updates.inputSchema; } - server.sendToolListChanged(); + server.mcpServer.sendToolListChanged(); }; + + return true; } protected update?: (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }) => void; diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index 62bd422c..b5f34bdf 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -1,5 +1,5 @@ import { Session } from "../../../../src/session.js"; -import { expectDefined } from "../../helpers.js"; +import { expectDefined, getResponseElements } from "../../helpers.js"; import { describeWithAtlas, withProject, randomId } from "./atlasHelpers.js"; import { ClusterDescription20240805 } from "../../../../src/common/atlas/openapi.js"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; @@ -205,6 +205,23 @@ describeWithAtlas("clusters", (integration) => { await sleep(500); } }); + + describe("when not connected", () => { + it("prompts for atlas-connect-cluster when querying mongodb", async () => { + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { database: "some-db", collection: "some-collection" }, + }); + const elements = getResponseElements(response.content); + expect(elements).toHaveLength(2); + expect(elements[0]?.text).toContain( + "You need to connect to a MongoDB instance before you can access its data." + ); + expect(elements[1]?.text).toContain( + 'Please use one of the following tools: "atlas-connect-cluster", "connect" to connect to a MongoDB instance' + ); + }); + }); }); }); }); diff --git a/tests/integration/tools/mongodb/metadata/connect.test.ts b/tests/integration/tools/mongodb/connect/connect.test.ts similarity index 82% rename from tests/integration/tools/mongodb/metadata/connect.test.ts rename to tests/integration/tools/mongodb/connect/connect.test.ts index 47e91d13..857b5747 100644 --- a/tests/integration/tools/mongodb/metadata/connect.test.ts +++ b/tests/integration/tools/mongodb/connect/connect.test.ts @@ -1,9 +1,15 @@ import { describeWithMongoDB } from "../mongodbHelpers.js"; -import { getResponseContent, validateThrowsForInvalidArguments, validateToolMetadata } from "../../../helpers.js"; +import { + getResponseContent, + getResponseElements, + validateThrowsForInvalidArguments, + validateToolMetadata, +} from "../../../helpers.js"; import { config } from "../../../../../src/config.js"; +import { defaultTestConfig, setupIntegrationTest } from "../../../helpers.js"; describeWithMongoDB( - "switchConnection tool", + "SwitchConnection tool", (integration) => { beforeEach(() => { integration.mcpServer().userConfig.connectionString = integration.connectionString(); @@ -77,6 +83,7 @@ describeWithMongoDB( connectionString: mdbIntegration.connectionString(), }) ); + describeWithMongoDB( "Connect tool", (integration) => { @@ -126,3 +133,26 @@ describeWithMongoDB( }, () => config ); + +describe("Connect tool when disabled", () => { + const integration = setupIntegrationTest(() => ({ + ...defaultTestConfig, + disabledTools: ["connect"], + })); + + it("is not suggested when querying MongoDB disconnected", async () => { + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { database: "some-db", collection: "some-collection" }, + }); + + const elements = getResponseElements(response); + expect(elements).toHaveLength(2); + expect(elements[0]?.text).toContain( + "You need to connect to a MongoDB instance before you can access its data." + ); + expect(elements[1]?.text).toContain( + "There are no tools available to connect. Please update the configuration to include a connection string and restart the server." + ); + }); +});