diff --git a/README.md b/README.md index 91959a7d..e1a74264 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,26 @@ To use debug mode with the Claude Desktop configuration, update your config as f > Note: As with the standard configuration, the `-t stdio` argument is required to override the default SSE mode in the Docker image. +### AAD Configuration +If you wish to use AAD auth to authenticate to a grafana instance. + + **If using the binary:** + + ```json + { + "mcpServers": { + "grafana": { + "command": "mcp-grafana", + "args": [], + "env": { + "GRAFANA_URL": "https://.grafana.azure.com/", + "USE_AAD_AUTH": "true", + } + } + } + } + ``` + ### TLS Configuration If your Grafana instance is behind mTLS or requires custom TLS certificates, you can configure the MCP server to use custom certificates. The server supports the following TLS configuration options: diff --git a/go.mod b/go.mod index 110a77a9..4869ee63 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.2 require ( connectrpc.com/connect v1.18.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 github.com/go-openapi/runtime v0.28.0 github.com/go-openapi/strfmt v0.23.0 github.com/google/uuid v1.6.0 @@ -23,6 +24,9 @@ require ( ) require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/BurntSushi/toml v1.4.0 // indirect github.com/apache/arrow-go/v18 v18.2.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect @@ -50,6 +54,7 @@ require ( github.com/go-openapi/validate v0.24.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -73,6 +78,7 @@ require ( github.com/jszwedko/go-datemath v0.1.1-0.20230526204004-640a500621d6 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/magefile/mage v1.15.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattetti/filebuffer v1.0.1 // indirect @@ -93,6 +99,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.2 // indirect @@ -122,6 +129,7 @@ require ( go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/proto/otlp v1.5.0 // indirect + golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.40.0 // indirect diff --git a/go.sum b/go.sum index c4c89501..acc15354 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,17 @@ connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= connectrpc.com/connect v1.18.1/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= @@ -34,6 +46,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= @@ -81,6 +95,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= @@ -150,6 +166,8 @@ github.com/jszwedko/go-datemath v0.1.1-0.20230526204004-640a500621d6/go.mod h1:W github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= @@ -221,6 +239,8 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -236,6 +256,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/prometheus v0.304.1 h1:e4kpJMb2Vh/PcR6LInake+ofcvFYHT+bCfmBvOkaZbY= github.com/prometheus/prometheus v0.304.1/go.mod h1:ioGx2SGKTY+fLnJSQCdTHqARVldGNS8OlIe3kvp98so= +github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= +github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -326,6 +348,8 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -356,6 +380,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/mcpgrafana.go b/mcpgrafana.go index 18ab2e66..46c31740 100644 --- a/mcpgrafana.go +++ b/mcpgrafana.go @@ -10,7 +10,10 @@ import ( "net/url" "os" "strings" + "sync" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/go-openapi/strfmt" "github.com/grafana/grafana-openapi-client-go/client" "github.com/grafana/incident-go" @@ -23,9 +26,13 @@ const ( grafanaURLEnvVar = "GRAFANA_URL" grafanaAPIEnvVar = "GRAFANA_API_KEY" + useAADEnvVar = "USE_AAD_AUTH" grafanaURLHeader = "X-Grafana-URL" grafanaAPIKeyHeader = "X-Grafana-API-Key" + + // Default AAD Resources + defaultGrafanaAADResource = "ce34e7e5-485f-4d76-964f-b3d2b16d1e4f" ) func urlAndAPIKeyFromEnv() (string, string) { @@ -40,6 +47,28 @@ func urlAndAPIKeyFromHeaders(req *http.Request) (string, string) { return u, apiKey } +// isAADEnabled checks if AAD authentication is enabled via environment variable +func isAADEnabled() bool { + val := os.Getenv(useAADEnvVar) + return strings.ToLower(val) == "true" || val == "1" +} + +// createAADConfigFromEnv creates a DefaultAzureCredential instance +// if AAD authentication is enabled via environment variable. +func createAADConfigFromEnv() *azidentity.DefaultAzureCredential { + if !isAADEnabled() { + return nil + } + // Create a new DefaultAzureCredential instance + // This will use the environment variables set by Azure CLI or Managed Identity + // if available, otherwise it will fall back to other authentication methods. + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + panic(fmt.Errorf("failed to create Azure AD credential: %w", err)) + } + return cred +} + // grafanaConfigKey is the context key for Grafana configuration. type grafanaConfigKey struct{} @@ -72,6 +101,8 @@ type GrafanaConfig struct { // TLSConfig holds TLS configuration for all Grafana clients. TLSConfig *TLSConfig + + AADCredential *azidentity.DefaultAzureCredential // credential to use Azure AD authentication for Grafana API calls. } // WithGrafanaConfig adds Grafana configuration to the context. @@ -156,6 +187,8 @@ var ExtractGrafanaInfoFromEnv server.StdioContextFunc = func(ctx context.Context config := GrafanaConfigFromContext(ctx) config.URL = u config.APIKey = apiKey + config.AADCredential = createAADConfigFromEnv() + return WithGrafanaConfig(ctx, config) } @@ -184,6 +217,7 @@ var ExtractGrafanaInfoFromHeaders httpContextFunc = func(ctx context.Context, re config := GrafanaConfigFromContext(ctx) config.URL = u config.APIKey = apiKey + config.AADCredential = createAADConfigFromEnv() return WithGrafanaConfig(ctx, config) } @@ -209,7 +243,7 @@ func MustWithOnBehalfOfAuth(ctx context.Context, accessToken, userToken string) return ctx } -type grafanaClientKey struct{} +type grafanaClientFunctorKey struct{} func makeBasePath(path string) string { return strings.Join([]string{strings.TrimRight(path, "/"), "api"}, "/") @@ -274,8 +308,12 @@ var ExtractGrafanaClientFromEnv server.StdioContextFunc = func(ctx context.Conte } apiKey := os.Getenv(grafanaAPIEnvVar) - grafanaClient := NewGrafanaClient(ctx, grafanaURL, apiKey) - return context.WithValue(ctx, grafanaClientKey{}, grafanaClient) + if isAADEnabled() { + return WithAADGrafanaClientFunc(ctx) + } else { + grafanaClient := NewGrafanaClient(ctx, grafanaURL, apiKey) + return WithGrafanaClient(ctx, grafanaClient) + } } // ExtractGrafanaClientFromHeaders is a HTTPContextFunc that extracts Grafana configuration @@ -294,24 +332,68 @@ var ExtractGrafanaClientFromHeaders httpContextFunc = func(ctx context.Context, apiKey = apiKeyEnv } - grafanaClient := NewGrafanaClient(ctx, u, apiKey) - return WithGrafanaClient(ctx, grafanaClient) + if isAADEnabled() { + return WithAADGrafanaClientFunc(ctx) + } else { + grafanaClient := NewGrafanaClient(ctx, u, apiKey) + return WithGrafanaClient(ctx, grafanaClient) + } } // WithGrafanaClient sets the Grafana client in the context. // // It can be retrieved using GrafanaClientFromContext. -func WithGrafanaClient(ctx context.Context, client *client.GrafanaHTTPAPI) context.Context { - return context.WithValue(ctx, grafanaClientKey{}, client) +func WithGrafanaClient(ctx context.Context, clientParam *client.GrafanaHTTPAPI) context.Context { + return context.WithValue(ctx, grafanaClientFunctorKey{}, func() (*client.GrafanaHTTPAPI, error) { return clientParam, nil }) +} + +func WithAADGrafanaClientFunc(ctx context.Context) context.Context { + var mutex sync.Mutex // protects the cached token and cached client + config := GrafanaConfigFromContext(ctx) + cred := config.AADCredential + cachedToken, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{defaultGrafanaAADResource}, + }) + if err != nil { + panic(fmt.Errorf("failed to get AAD token for Grafana: %w", err)) + } + cachedGrafanaClient := NewGrafanaClient(ctx, config.URL, cachedToken.Token) + + slog.Debug("Constructed cached Grafana client with AAD authentication") + var functor func() (*client.GrafanaHTTPAPI, error) = func() (*client.GrafanaHTTPAPI, error) { + + funcCred, funcErr := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{defaultGrafanaAADResource}, + }) + // TODO: do error handling here, and update functor to return errors + if funcErr != nil { + return nil, fmt.Errorf("failed to get AAD token for Grafana: %w", funcErr) + } + // Use the cached client if the token is still valid, otherwise create a new one + mutex.Lock() + defer mutex.Unlock() + if cachedToken == funcCred { + // If the cached token is still valid, return the cached client + return cachedGrafanaClient, nil + } + + slog.Debug("Cached client didn't match need to refresh it.") + + cachedToken = funcCred + cachedGrafanaClient = NewGrafanaClient(ctx, config.URL, cachedToken.Token) + return cachedGrafanaClient, nil + } + + return context.WithValue(ctx, grafanaClientFunctorKey{}, functor) } // GrafanaClientFromContext retrieves the Grafana client from the context. -func GrafanaClientFromContext(ctx context.Context) *client.GrafanaHTTPAPI { - c, ok := ctx.Value(grafanaClientKey{}).(*client.GrafanaHTTPAPI) +func GrafanaClientFromContext(ctx context.Context) (*client.GrafanaHTTPAPI, error) { + c, ok := ctx.Value(grafanaClientFunctorKey{}).(func() (*client.GrafanaHTTPAPI, error)) if !ok { - return nil + return nil, fmt.Errorf("grafana client not found in context") } - return c + return c() } type incidentClientKey struct{} diff --git a/mcpgrafana_test.go b/mcpgrafana_test.go index 6aeef302..cb29e51e 100644 --- a/mcpgrafana_test.go +++ b/mcpgrafana_test.go @@ -123,7 +123,8 @@ func TestExtractGrafanaClientPath(t *testing.T) { t.Setenv("GRAFANA_URL", "http://my-test-url.grafana.com/") ctx := ExtractGrafanaClientFromEnv(context.Background()) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) require.NotNil(t, c) rt := c.Transport.(*client.Runtime) assert.Equal(t, "/api", rt.BasePath) @@ -133,7 +134,8 @@ func TestExtractGrafanaClientPath(t *testing.T) { t.Setenv("GRAFANA_URL", "http://my-test-url.grafana.com/grafana") ctx := ExtractGrafanaClientFromEnv(context.Background()) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) require.NotNil(t, c) rt := c.Transport.(*client.Runtime) assert.Equal(t, "/grafana/api", rt.BasePath) @@ -143,7 +145,8 @@ func TestExtractGrafanaClientPath(t *testing.T) { t.Setenv("GRAFANA_URL", "http://my-test-url.grafana.com/grafana/") ctx := ExtractGrafanaClientFromEnv(context.Background()) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) require.NotNil(t, c) rt := c.Transport.(*client.Runtime) assert.Equal(t, "/grafana/api", rt.BasePath) @@ -167,7 +170,8 @@ func TestExtractGrafanaClientFromHeaders(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) require.NoError(t, err) ctx := ExtractGrafanaClientFromHeaders(context.Background(), req) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) url := minURLFromClient(c) assert.Equal(t, "localhost:3000", url.host) assert.Equal(t, "/api", url.basePath) @@ -179,7 +183,8 @@ func TestExtractGrafanaClientFromHeaders(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) require.NoError(t, err) ctx := ExtractGrafanaClientFromHeaders(context.Background(), req) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) url := minURLFromClient(c) assert.Equal(t, "my-test-url.grafana.com", url.host) assert.Equal(t, "/api", url.basePath) @@ -190,7 +195,8 @@ func TestExtractGrafanaClientFromHeaders(t *testing.T) { require.NoError(t, err) req.Header.Set(grafanaURLHeader, "http://my-test-url.grafana.com") ctx := ExtractGrafanaClientFromHeaders(context.Background(), req) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) url := minURLFromClient(c) assert.Equal(t, "my-test-url.grafana.com", url.host) assert.Equal(t, "/api", url.basePath) @@ -204,7 +210,8 @@ func TestExtractGrafanaClientFromHeaders(t *testing.T) { require.NoError(t, err) req.Header.Set(grafanaURLHeader, "http://my-test-url.grafana.com") ctx := ExtractGrafanaClientFromHeaders(context.Background(), req) - c := GrafanaClientFromContext(ctx) + c, err := GrafanaClientFromContext(ctx) + require.NoError(t, err) url := minURLFromClient(c) assert.Equal(t, "my-test-url.grafana.com", url.host) assert.Equal(t, "/api", url.basePath) diff --git a/tools/admin.go b/tools/admin.go index d4143d33..38a3ec68 100644 --- a/tools/admin.go +++ b/tools/admin.go @@ -16,7 +16,10 @@ type ListTeamsParams struct { } func listTeams(ctx context.Context, args ListTeamsParams) (*models.SearchTeamQueryResult, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("list teams: %w", err) + } params := teams.NewSearchTeamsParamsWithContext(ctx) if args.Query != "" { params.SetQuery(&args.Query) diff --git a/tools/alerting.go b/tools/alerting.go index 6bf93a37..4bf614c2 100644 --- a/tools/alerting.go +++ b/tools/alerting.go @@ -171,7 +171,10 @@ func getAlertRuleByUID(ctx context.Context, args GetAlertRuleByUIDParams) (*mode return nil, fmt.Errorf("get alert rule by uid: %w", err) } - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("get alert rule by uid %s: %w", args.UID, err) + } alertRule, err := c.Provisioning.GetAlertRule(args.UID) if err != nil { return nil, fmt.Errorf("get alert rule by uid %s: %w", args.UID, err) @@ -211,7 +214,10 @@ func listContactPoints(ctx context.Context, args ListContactPointsParams) ([]con return nil, fmt.Errorf("list contact points: %w", err) } - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("list contact points: %w", err) + } params := provisioning.NewGetContactpointsParams().WithContext(ctx) if args.Name != nil { diff --git a/tools/dashboard.go b/tools/dashboard.go index 541397a1..95e73263 100644 --- a/tools/dashboard.go +++ b/tools/dashboard.go @@ -16,7 +16,10 @@ type GetDashboardByUIDParams struct { } func getDashboardByUID(ctx context.Context, args GetDashboardByUIDParams) (*models.DashboardFullWithMeta, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("get dashboard by uid %s: %w", args.UID, err) + } dashboard, err := c.Dashboards.GetDashboardByUID(args.UID) if err != nil { return nil, fmt.Errorf("get dashboard by uid %s: %w", args.UID, err) @@ -36,7 +39,10 @@ type UpdateDashboardParams struct { // DISCLAIMER: Large-sized dashboard JSON can exhaust context windows. We will // implement features that address this in https://github.com/grafana/mcp-grafana/issues/101. func updateDashboard(ctx context.Context, args UpdateDashboardParams) (*models.PostDashboardOKBody, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("update dashboard: %w", err) + } cmd := &models.SaveDashboardCommand{ Dashboard: args.Dashboard, FolderUID: args.FolderUID, diff --git a/tools/datasources.go b/tools/datasources.go index effc470d..fde88843 100644 --- a/tools/datasources.go +++ b/tools/datasources.go @@ -25,7 +25,10 @@ type dataSourceSummary struct { } func listDatasources(ctx context.Context, args ListDatasourcesParams) ([]dataSourceSummary, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("list datasources: %w", err) + } resp, err := c.Datasources.GetDataSources() if err != nil { return nil, fmt.Errorf("list datasources: %w", err) @@ -78,7 +81,10 @@ type GetDatasourceByUIDParams struct { } func getDatasourceByUID(ctx context.Context, args GetDatasourceByUIDParams) (*models.DataSource, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("get datasource by uid %s: %w", args.UID, err) + } datasource, err := c.Datasources.GetDataSourceByUID(args.UID) if err != nil { // Check if it's a 404 Not Found Error @@ -104,7 +110,10 @@ type GetDatasourceByNameParams struct { } func getDatasourceByName(ctx context.Context, args GetDatasourceByNameParams) (*models.DataSource, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("get datasource by name %s: %w", args.Name, err) + } datasource, err := c.Datasources.GetDataSourceByName(args.Name) if err != nil { return nil, fmt.Errorf("get datasource by name %s: %w", args.Name, err) diff --git a/tools/prometheus.go b/tools/prometheus.go index 6be84665..f40ad420 100644 --- a/tools/prometheus.go +++ b/tools/prometheus.go @@ -3,11 +3,13 @@ package tools import ( "context" "fmt" + "log/slog" "net/http" "regexp" "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/grafana/grafana-plugin-sdk-go/backend/gtime" mcpgrafana "github.com/grafana/mcp-grafana" "github.com/mark3labs/mcp-go/mcp" @@ -19,6 +21,10 @@ import ( "github.com/prometheus/prometheus/model/labels" ) +const ( + defaultPrometheusAADResource = "https://prometheus.monitor.azure.com" +) + var ( matchTypeMap = map[string]labels.MatchType{ "": labels.MatchEqual, @@ -31,7 +37,7 @@ var ( func promClientFromContext(ctx context.Context, uid string) (promv1.API, error) { // First check if the datasource exists - _, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid}) + datasource, err := getDatasourceByUID(ctx, GetDatasourceByUIDParams{UID: uid}) if err != nil { return nil, err } @@ -39,6 +45,32 @@ func promClientFromContext(ctx context.Context, uid string) (promv1.API, error) cfg := mcpgrafana.GrafanaConfigFromContext(ctx) url := fmt.Sprintf("%s/api/datasources/proxy/uid/%s", strings.TrimRight(cfg.URL, "/"), uid) + // Check if datasource has azureAuthType in JSONData + requiresAADAuth := false + if datasource.JSONData != nil { + if jsonDataMap, ok := datasource.JSONData.(map[string]interface{}); ok { + if _, hasAzureAuthType := jsonDataMap["azureAuthType"]; hasAzureAuthType { + requiresAADAuth = true + slog.Debug("Prometheus datasource requires AAD auth (azureAuthType)", "uid", uid) + } else if _, hasAzureAuth := jsonDataMap["azureAuth"]; hasAzureAuth { + requiresAADAuth = true + slog.Debug("Prometheus datasource requires AAD auth (azureAuth)", "uid", uid) + } + } + } + + // if using AADCredential and requiresAADAuth, go to prometheus url + if cfg.AADCredential != nil && requiresAADAuth { + url = datasource.URL + cred, err := cfg.AADCredential.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{defaultPrometheusAADResource}, + }) + if err != nil { + return nil, fmt.Errorf("getting AAD token for Prometheus: %w", err) + } + cfg.APIKey = cred.Token + } + // Create custom transport with TLS configuration if available rt := api.DefaultRoundTripper if tlsConfig := cfg.TLSConfig; tlsConfig != nil { diff --git a/tools/search.go b/tools/search.go index 8825ac0c..3308101d 100644 --- a/tools/search.go +++ b/tools/search.go @@ -19,7 +19,10 @@ type SearchDashboardsParams struct { } func searchDashboards(ctx context.Context, args SearchDashboardsParams) (models.HitList, error) { - c := mcpgrafana.GrafanaClientFromContext(ctx) + c, err := mcpgrafana.GrafanaClientFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("search dashboards: %w", err) + } params := search.NewSearchParamsWithContext(ctx) if args.Query != "" { params.SetQuery(&args.Query)