diff --git a/dev-proxy-abstractions/ILoggerExtensions.cs b/dev-proxy-abstractions/ILoggerExtensions.cs index fe02728e..ad906a08 100644 --- a/dev-proxy-abstractions/ILoggerExtensions.cs +++ b/dev-proxy-abstractions/ILoggerExtensions.cs @@ -1,24 +1,24 @@ -using System.Text.Json; -using Microsoft.DevProxy.Abstractions; - -#pragma warning disable IDE0130 -namespace Microsoft.Extensions.Logging; -#pragma warning restore IDE0130 - -public static class ILoggerExtensions -{ - public static void LogRequest(this ILogger logger, string[] message, MessageType messageType, LoggingContext? context = null) - { - logger.Log(new RequestLog(message, messageType, context)); - } - - public static void LogRequest(this ILogger logger, string[] message, MessageType messageType, string method, string url) - { - logger.Log(new RequestLog(message, messageType, method, url)); - } - - public static void Log(this ILogger logger, RequestLog message) - { - logger.Log(LogLevel.Information, 0, message, exception: null, (m, _) => JsonSerializer.Serialize(m)); - } +using System.Text.Json; +using Microsoft.DevProxy.Abstractions; + +#pragma warning disable IDE0130 +namespace Microsoft.Extensions.Logging; +#pragma warning restore IDE0130 + +public static class ILoggerExtensions +{ + public static void LogRequest(this ILogger logger, string message, MessageType messageType, LoggingContext? context = null) + { + logger.Log(new RequestLog(message, messageType, context)); + } + + public static void LogRequest(this ILogger logger, string message, MessageType messageType, string method, string url) + { + logger.Log(new RequestLog(message, messageType, method, url)); + } + + public static void Log(this ILogger logger, RequestLog message) + { + logger.Log(LogLevel.Information, 0, message, exception: null, (m, _) => JsonSerializer.Serialize(m)); + } } \ No newline at end of file diff --git a/dev-proxy-abstractions/IProxyLogger.cs b/dev-proxy-abstractions/IProxyLogger.cs index 37e23fba..1666d610 100644 --- a/dev-proxy-abstractions/IProxyLogger.cs +++ b/dev-proxy-abstractions/IProxyLogger.cs @@ -1,32 +1,26 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -using Titanium.Web.Proxy.EventArguments; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Abstractions; - -public enum MessageType -{ - Normal, - InterceptedRequest, - PassedThrough, - Warning, - Tip, - Failed, - Chaos, - Mocked, - InterceptedResponse, - FinishedProcessingRequest -} - -public class LoggingContext(SessionEventArgs session) -{ - public SessionEventArgs Session { get; } = session; -} - -public interface IProxyLogger : ICloneable, ILogger -{ - public LogLevel LogLevel { get; set; } - public void LogRequest(string[] message, MessageType messageType, LoggingContext? context = null); - public void LogRequest(string[] message, MessageType messageType, string method, string url); +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Titanium.Web.Proxy.EventArguments; + +namespace Microsoft.DevProxy.Abstractions; + +public enum MessageType +{ + Normal, + InterceptedRequest, + PassedThrough, + Warning, + Tip, + Failed, + Chaos, + Mocked, + InterceptedResponse, + FinishedProcessingRequest, + Skipped +} + +public class LoggingContext(SessionEventArgs session) +{ + public SessionEventArgs Session { get; } = session; } \ No newline at end of file diff --git a/dev-proxy-abstractions/PluginEvents.cs b/dev-proxy-abstractions/PluginEvents.cs index e432cea6..500b8fb0 100644 --- a/dev-proxy-abstractions/PluginEvents.cs +++ b/dev-proxy-abstractions/PluginEvents.cs @@ -1,273 +1,274 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.CommandLine; -using System.CommandLine.Invocation; -using System.Security.Cryptography.X509Certificates; -using System.Text.Json.Serialization; -using Microsoft.DevProxy.Abstractions.LanguageModel; -using Titanium.Web.Proxy; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Abstractions; - -public interface IProxyContext -{ - IProxyConfiguration Configuration { get; } - X509Certificate2? Certificate { get; } - ILanguageModelClient LanguageModelClient { get; } -} - -public class ThrottlerInfo(string throttlingKey, Func shouldThrottle, DateTime resetTime) -{ - /// - /// Throttling key used to identify which requests should be throttled. - /// Can be set to a hostname, full URL or a custom string value, that - /// represents for example a portion of the API - /// - public string ThrottlingKey { get; private set; } = throttlingKey ?? throw new ArgumentNullException(nameof(throttlingKey)); - /// - /// Function responsible for matching the request to the throttling key. - /// Takes as arguments: - /// - intercepted request - /// - the throttling key - /// Returns an instance of ThrottlingInfo that contains information - /// whether the request should be throttled or not. - /// - public Func ShouldThrottle { get; private set; } = shouldThrottle ?? throw new ArgumentNullException(nameof(shouldThrottle)); - /// - /// Time when the throttling window will be reset - /// - public DateTime ResetTime { get; set; } = resetTime; -} - -public class ThrottlingInfo(int throttleForSeconds, string retryAfterHeaderName) -{ - public int ThrottleForSeconds { get; set; } = throttleForSeconds; - public string RetryAfterHeaderName { get; set; } = retryAfterHeaderName ?? throw new ArgumentNullException(nameof(retryAfterHeaderName)); -} - -public class ProxyEventArgsBase -{ - public Dictionary SessionData { get; set; } = []; - public Dictionary GlobalData { get; set; } = []; -} - -public class ProxyHttpEventArgsBase : ProxyEventArgsBase -{ - internal ProxyHttpEventArgsBase(SessionEventArgs session) - { - Session = session ?? throw new ArgumentNullException(nameof(session)); - } - - public SessionEventArgs Session { get; } - - public bool HasRequestUrlMatch(ISet watchedUrls) - { - var match = watchedUrls.FirstOrDefault(r => r.Url.IsMatch(Session.HttpClient.Request.RequestUri.AbsoluteUri)); - return match is not null && !match.Exclude; - } -} - -public class ProxyRequestArgs(SessionEventArgs session, ResponseState responseState) : ProxyHttpEventArgsBase(session) -{ - public ResponseState ResponseState { get; } = responseState ?? throw new ArgumentNullException(nameof(responseState)); - - public bool ShouldExecute(ISet watchedUrls) => - !ResponseState.HasBeenSet - && HasRequestUrlMatch(watchedUrls); -} - -public class ProxyResponseArgs(SessionEventArgs session, ResponseState responseState) : ProxyHttpEventArgsBase(session) -{ - public ResponseState ResponseState { get; } = responseState ?? throw new ArgumentNullException(nameof(responseState)); -} - -public class InitArgs -{ - public InitArgs() - { - } -} - -public class OptionsLoadedArgs(InvocationContext context, Option[] options) -{ - public InvocationContext Context { get; set; } = context ?? throw new ArgumentNullException(nameof(context)); - public Option[] Options { get; set; } = options ?? throw new ArgumentNullException(nameof(options)); -} - -public class RequestLog -{ - public string[] MessageLines { get; set; } - public MessageType MessageType { get; set; } - [JsonIgnore] - public LoggingContext? Context { get; set; } - public string? Method { get; init; } - public string? Url { get; init; } - - public RequestLog(string[] messageLines, MessageType messageType, LoggingContext? context) : - this(messageLines, messageType, context?.Session.HttpClient.Request.Method, context?.Session.HttpClient.Request.Url, context) - { - } - - public RequestLog(string[] messageLines, MessageType messageType, string method, string url) : - this(messageLines, messageType, method, url, context: null) - { - } - - private RequestLog(string[] messageLines, MessageType messageType, string? method, string? url, LoggingContext? context) - { - MessageLines = messageLines ?? throw new ArgumentNullException(nameof(messageLines)); - MessageType = messageType; - Context = context; - Method = method; - Url = url; - } - - public void Deconstruct(out string[] message, out MessageType messageType, out LoggingContext? context, out string? method, out string? url) - { - message = MessageLines; - messageType = MessageType; - context = Context; - method = Method; - url = Url; - } -} - -public class RecordingArgs(IEnumerable requestLogs) : ProxyEventArgsBase -{ - public IEnumerable RequestLogs { get; set; } = requestLogs ?? throw new ArgumentNullException(nameof(requestLogs)); -} - -public class RequestLogArgs(RequestLog requestLog) -{ - public RequestLog RequestLog { get; set; } = requestLog ?? throw new ArgumentNullException(nameof(requestLog)); -} - -public interface IPluginEvents -{ - /// - /// Raised while starting the proxy, allows plugins to register command line options - /// - event EventHandler Init; - /// - /// Raised during startup after command line arguments have been parsed, - /// used to update the internal state of a plugin that registers command line options - /// - event EventHandler OptionsLoaded; - /// - /// Raised before a request is sent to the server. - /// Used to intercept requests. - /// - event AsyncEventHandler BeforeRequest; - /// - /// Raised after the response is received from the server. - /// Is not raised if a response is set during the BeforeRequest event. - /// Allows plugins to modify a response received from the server. - /// - event AsyncEventHandler BeforeResponse; - /// - /// Raised after a response is sent to the client. - /// Raised for all responses - /// - event AsyncEventHandler? AfterResponse; - /// - /// Raised after request message has been logged. - /// - event AsyncEventHandler? AfterRequestLog; - /// - /// Raised after recording request logs has stopped. - /// - event AsyncEventHandler? AfterRecordingStop; - /// - /// Raised when user requested issuing mock requests. - /// - event AsyncEventHandler? MockRequest; - - void RaiseInit(InitArgs args); - void RaiseOptionsLoaded(OptionsLoadedArgs args); - Task RaiseProxyBeforeRequestAsync(ProxyRequestArgs args, ExceptionHandler? exceptionFunc = null); - Task RaiseProxyBeforeResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null); - Task RaiseProxyAfterResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null); - Task RaiseRequestLoggedAsync(RequestLogArgs args, ExceptionHandler? exceptionFunc = null); - Task RaiseRecordingStoppedAsync(RecordingArgs args, ExceptionHandler? exceptionFunc = null); - Task RaiseMockRequestAsync(EventArgs args, ExceptionHandler? exceptionFunc = null); -} - -public class PluginEvents : IPluginEvents -{ - /// - public event EventHandler? Init; - /// - public event EventHandler? OptionsLoaded; - /// - public event AsyncEventHandler? BeforeRequest; - /// - public event AsyncEventHandler? BeforeResponse; - /// - public event AsyncEventHandler? AfterResponse; - /// - public event AsyncEventHandler? AfterRequestLog; - /// - public event AsyncEventHandler? AfterRecordingStop; - public event AsyncEventHandler? MockRequest; - - public void RaiseInit(InitArgs args) - { - Init?.Invoke(this, args); - } - - public void RaiseOptionsLoaded(OptionsLoadedArgs args) - { - OptionsLoaded?.Invoke(this, args); - } - - public async Task RaiseProxyBeforeRequestAsync(ProxyRequestArgs args, ExceptionHandler? exceptionFunc = null) - { - if (BeforeRequest is not null) - { - await BeforeRequest.InvokeAsync(this, args, exceptionFunc); - } - } - - public async Task RaiseProxyBeforeResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null) - { - if (BeforeResponse is not null) - { - await BeforeResponse.InvokeAsync(this, args, exceptionFunc); - } - } - - public async Task RaiseProxyAfterResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null) - { - if (AfterResponse is not null) - { - await AfterResponse.InvokeAsync(this, args, exceptionFunc); - } - } - - public async Task RaiseRequestLoggedAsync(RequestLogArgs args, ExceptionHandler? exceptionFunc = null) - { - if (AfterRequestLog is not null) - { - await AfterRequestLog.InvokeAsync(this, args, exceptionFunc); - } - } - - public async Task RaiseRecordingStoppedAsync(RecordingArgs args, ExceptionHandler? exceptionFunc = null) - { - if (AfterRecordingStop is not null) - { - await AfterRecordingStop.InvokeAsync(this, args, exceptionFunc); - } - } - - public async Task RaiseMockRequestAsync(EventArgs args, ExceptionHandler? exceptionFunc = null) - { - if (MockRequest is not null) - { - await MockRequest.InvokeAsync(this, args, exceptionFunc); - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Security.Cryptography.X509Certificates; +using System.Text.Json.Serialization; +using Microsoft.DevProxy.Abstractions.LanguageModel; +using Titanium.Web.Proxy; +using Titanium.Web.Proxy.EventArguments; +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Abstractions; + +public interface IProxyContext +{ + IProxyConfiguration Configuration { get; } + X509Certificate2? Certificate { get; } + ILanguageModelClient LanguageModelClient { get; } +} + +public class ThrottlerInfo(string throttlingKey, Func shouldThrottle, DateTime resetTime) +{ + /// + /// Throttling key used to identify which requests should be throttled. + /// Can be set to a hostname, full URL or a custom string value, that + /// represents for example a portion of the API + /// + public string ThrottlingKey { get; private set; } = throttlingKey ?? throw new ArgumentNullException(nameof(throttlingKey)); + /// + /// Function responsible for matching the request to the throttling key. + /// Takes as arguments: + /// - intercepted request + /// - the throttling key + /// Returns an instance of ThrottlingInfo that contains information + /// whether the request should be throttled or not. + /// + public Func ShouldThrottle { get; private set; } = shouldThrottle ?? throw new ArgumentNullException(nameof(shouldThrottle)); + /// + /// Time when the throttling window will be reset + /// + public DateTime ResetTime { get; set; } = resetTime; +} + +public class ThrottlingInfo(int throttleForSeconds, string retryAfterHeaderName) +{ + public int ThrottleForSeconds { get; set; } = throttleForSeconds; + public string RetryAfterHeaderName { get; set; } = retryAfterHeaderName ?? throw new ArgumentNullException(nameof(retryAfterHeaderName)); +} + +public class ProxyEventArgsBase +{ + public Dictionary SessionData { get; set; } = []; + public Dictionary GlobalData { get; set; } = []; +} + +public class ProxyHttpEventArgsBase : ProxyEventArgsBase +{ + internal ProxyHttpEventArgsBase(SessionEventArgs session) + { + Session = session ?? throw new ArgumentNullException(nameof(session)); + } + + public SessionEventArgs Session { get; } + + public bool HasRequestUrlMatch(ISet watchedUrls) + { + var match = watchedUrls.FirstOrDefault(r => r.Url.IsMatch(Session.HttpClient.Request.RequestUri.AbsoluteUri)); + return match is not null && !match.Exclude; + } +} + +public class ProxyRequestArgs(SessionEventArgs session, ResponseState responseState) : ProxyHttpEventArgsBase(session) +{ + public ResponseState ResponseState { get; } = responseState ?? throw new ArgumentNullException(nameof(responseState)); + + public bool ShouldExecute(ISet watchedUrls) => + !ResponseState.HasBeenSet + && HasRequestUrlMatch(watchedUrls); +} + +public class ProxyResponseArgs(SessionEventArgs session, ResponseState responseState) : ProxyHttpEventArgsBase(session) +{ + public ResponseState ResponseState { get; } = responseState ?? throw new ArgumentNullException(nameof(responseState)); +} + +public class InitArgs +{ + public InitArgs() + { + } +} + +public class OptionsLoadedArgs(InvocationContext context, Option[] options) +{ + public InvocationContext Context { get; set; } = context ?? throw new ArgumentNullException(nameof(context)); + public Option[] Options { get; set; } = options ?? throw new ArgumentNullException(nameof(options)); +} + +public class RequestLog +{ + public string Message { get; set; } + public MessageType MessageType { get; set; } + [JsonIgnore] + public LoggingContext? Context { get; set; } + public string? Method { get; init; } + public string? Url { get; init; } + public string? PluginName { get; set; } + + public RequestLog(string message, MessageType messageType, LoggingContext? context) : + this(message, messageType, context?.Session.HttpClient.Request.Method, context?.Session.HttpClient.Request.Url, context) + { + } + + public RequestLog(string message, MessageType messageType, string method, string url) : + this(message, messageType, method, url, context: null) + { + } + + private RequestLog(string message, MessageType messageType, string? method, string? url, LoggingContext? context) + { + Message = message ?? throw new ArgumentNullException(nameof(message)); + MessageType = messageType; + Context = context; + Method = method; + Url = url; + } + + public void Deconstruct(out string message, out MessageType messageType, out LoggingContext? context, out string? method, out string? url) + { + message = Message; + messageType = MessageType; + context = Context; + method = Method; + url = Url; + } +} + +public class RecordingArgs(IEnumerable requestLogs) : ProxyEventArgsBase +{ + public IEnumerable RequestLogs { get; set; } = requestLogs ?? throw new ArgumentNullException(nameof(requestLogs)); +} + +public class RequestLogArgs(RequestLog requestLog) +{ + public RequestLog RequestLog { get; set; } = requestLog ?? throw new ArgumentNullException(nameof(requestLog)); +} + +public interface IPluginEvents +{ + /// + /// Raised while starting the proxy, allows plugins to register command line options + /// + event EventHandler Init; + /// + /// Raised during startup after command line arguments have been parsed, + /// used to update the internal state of a plugin that registers command line options + /// + event EventHandler OptionsLoaded; + /// + /// Raised before a request is sent to the server. + /// Used to intercept requests. + /// + event AsyncEventHandler BeforeRequest; + /// + /// Raised after the response is received from the server. + /// Is not raised if a response is set during the BeforeRequest event. + /// Allows plugins to modify a response received from the server. + /// + event AsyncEventHandler BeforeResponse; + /// + /// Raised after a response is sent to the client. + /// Raised for all responses + /// + event AsyncEventHandler? AfterResponse; + /// + /// Raised after request message has been logged. + /// + event AsyncEventHandler? AfterRequestLog; + /// + /// Raised after recording request logs has stopped. + /// + event AsyncEventHandler? AfterRecordingStop; + /// + /// Raised when user requested issuing mock requests. + /// + event AsyncEventHandler? MockRequest; + + void RaiseInit(InitArgs args); + void RaiseOptionsLoaded(OptionsLoadedArgs args); + Task RaiseProxyBeforeRequestAsync(ProxyRequestArgs args, ExceptionHandler? exceptionFunc = null); + Task RaiseProxyBeforeResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null); + Task RaiseProxyAfterResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null); + Task RaiseRequestLoggedAsync(RequestLogArgs args, ExceptionHandler? exceptionFunc = null); + Task RaiseRecordingStoppedAsync(RecordingArgs args, ExceptionHandler? exceptionFunc = null); + Task RaiseMockRequestAsync(EventArgs args, ExceptionHandler? exceptionFunc = null); +} + +public class PluginEvents : IPluginEvents +{ + /// + public event EventHandler? Init; + /// + public event EventHandler? OptionsLoaded; + /// + public event AsyncEventHandler? BeforeRequest; + /// + public event AsyncEventHandler? BeforeResponse; + /// + public event AsyncEventHandler? AfterResponse; + /// + public event AsyncEventHandler? AfterRequestLog; + /// + public event AsyncEventHandler? AfterRecordingStop; + public event AsyncEventHandler? MockRequest; + + public void RaiseInit(InitArgs args) + { + Init?.Invoke(this, args); + } + + public void RaiseOptionsLoaded(OptionsLoadedArgs args) + { + OptionsLoaded?.Invoke(this, args); + } + + public async Task RaiseProxyBeforeRequestAsync(ProxyRequestArgs args, ExceptionHandler? exceptionFunc = null) + { + if (BeforeRequest is not null) + { + await BeforeRequest.InvokeAsync(this, args, exceptionFunc); + } + } + + public async Task RaiseProxyBeforeResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null) + { + if (BeforeResponse is not null) + { + await BeforeResponse.InvokeAsync(this, args, exceptionFunc); + } + } + + public async Task RaiseProxyAfterResponseAsync(ProxyResponseArgs args, ExceptionHandler? exceptionFunc = null) + { + if (AfterResponse is not null) + { + await AfterResponse.InvokeAsync(this, args, exceptionFunc); + } + } + + public async Task RaiseRequestLoggedAsync(RequestLogArgs args, ExceptionHandler? exceptionFunc = null) + { + if (AfterRequestLog is not null) + { + await AfterRequestLog.InvokeAsync(this, args, exceptionFunc); + } + } + + public async Task RaiseRecordingStoppedAsync(RecordingArgs args, ExceptionHandler? exceptionFunc = null) + { + if (AfterRecordingStop is not null) + { + await AfterRecordingStop.InvokeAsync(this, args, exceptionFunc); + } + } + + public async Task RaiseMockRequestAsync(EventArgs args, ExceptionHandler? exceptionFunc = null) + { + if (MockRequest is not null) + { + await MockRequest.InvokeAsync(this, args, exceptionFunc); + } + } +} diff --git a/dev-proxy-plugins/Behavior/RateLimitingPlugin.cs b/dev-proxy-plugins/Behavior/RateLimitingPlugin.cs index 99e18f98..8fe05fb9 100644 --- a/dev-proxy-plugins/Behavior/RateLimitingPlugin.cs +++ b/dev-proxy-plugins/Behavior/RateLimitingPlugin.cs @@ -1,297 +1,307 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using System.Net; -using System.Text.Json; -using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; - -namespace Microsoft.DevProxy.Plugins.Behavior; - -public enum RateLimitResponseWhenLimitExceeded -{ - Throttle, - Custom -} - -public enum RateLimitResetFormat -{ - SecondsLeft, - UtcEpochSeconds -} - -public class RateLimitConfiguration -{ - public string HeaderLimit { get; set; } = "RateLimit-Limit"; - public string HeaderRemaining { get; set; } = "RateLimit-Remaining"; - public string HeaderReset { get; set; } = "RateLimit-Reset"; - public string HeaderRetryAfter { get; set; } = "Retry-After"; - public RateLimitResetFormat ResetFormat { get; set; } = RateLimitResetFormat.SecondsLeft; - public int CostPerRequest { get; set; } = 2; - public int ResetTimeWindowSeconds { get; set; } = 60; - public int WarningThresholdPercent { get; set; } = 80; - public int RateLimit { get; set; } = 120; - public RateLimitResponseWhenLimitExceeded WhenLimitExceeded { get; set; } = RateLimitResponseWhenLimitExceeded.Throttle; - public string CustomResponseFile { get; set; } = "rate-limit-response.json"; - public MockResponseResponse? CustomResponse { get; set; } -} - -public class RateLimitingPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(RateLimitingPlugin); - private readonly RateLimitConfiguration _configuration = new(); - // initial values so that we know when we intercept the - // first request and can set the initial values - private int _resourcesRemaining = -1; - private DateTime _resetTime = DateTime.MinValue; - private RateLimitingCustomResponseLoader? _loader = null; - - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) - { - var throttleKeyForRequest = BuildThrottleKey(request); - return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? (int)(_resetTime - DateTime.Now).TotalSeconds : 0, _configuration.HeaderRetryAfter); - } - - private void ThrottleResponse(ProxyRequestArgs e) => UpdateProxyResponse(e, HttpStatusCode.TooManyRequests); - - private void UpdateProxyResponse(ProxyHttpEventArgsBase e, HttpStatusCode errorStatus) - { - var headers = new List(); - var body = string.Empty; - var request = e.Session.HttpClient.Request; - var response = e.Session.HttpClient.Response; - - // resources exceeded - if (errorStatus == HttpStatusCode.TooManyRequests) - { - if (ProxyUtils.IsGraphRequest(request)) - { - string requestId = Guid.NewGuid().ToString(); - string requestDate = DateTime.Now.ToString(); - headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); - - body = JsonSerializer.Serialize(new GraphErrorResponseBody( - new GraphErrorResponseError - { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = BuildApiErrorMessage(request), - InnerError = new GraphErrorResponseInnerError - { - RequestId = requestId, - Date = requestDate - } - }), - ProxyUtils.JsonSerializerOptions - ); - } - - headers.Add(new(_configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString())); - if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) - { - headers.Add(new("Access-Control-Allow-Origin", "*")); - headers.Add(new("Access-Control-Expose-Headers", _configuration.HeaderRetryAfter)); - } - - e.Session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value)).ToArray()); - return; - } - - if (e.SessionData.TryGetValue(Name, out var pluginData) && - pluginData is List rateLimitingHeaders) - { - ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); - } - - // add headers to the original API response, avoiding duplicates - headers.ForEach(h => e.Session.HttpClient.Response.Headers.RemoveHeader(h.Name)); - e.Session.HttpClient.Response.Headers.AddHeaders(headers.Select(h => new HttpHeader(h.Name, h.Value)).ToArray()); - } - private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : String.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; - - private static string BuildThrottleKey(Request r) - { - if (ProxyUtils.IsGraphRequest(r)) - { - return GraphUtils.BuildThrottleKey(r); - } - else - { - return r.RequestUri.Host; - } - } - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - if (_configuration.WhenLimitExceeded == RateLimitResponseWhenLimitExceeded.Custom) - { - _configuration.CustomResponseFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.CustomResponseFile), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); - _loader = new RateLimitingCustomResponseLoader(Logger, _configuration); - // load the responses from the configured mocks file - _loader.InitResponsesWatcher(); - } - - PluginEvents.BeforeRequest += OnRequestAsync; - PluginEvents.BeforeResponse += OnResponseAsync; - } - - // add rate limiting headers to the response from the API - private Task OnResponseAsync(object? sender, ProxyResponseArgs e) - { - if (UrlsToWatch is null || - !e.HasRequestUrlMatch(UrlsToWatch)) - { - return Task.CompletedTask; - } - - UpdateProxyResponse(e, HttpStatusCode.OK); - return Task.CompletedTask; - } - - private Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - var session = e.Session; - var state = e.ResponseState; - if (e.ResponseState.HasBeenSet || - UrlsToWatch is null || - !e.ShouldExecute(UrlsToWatch)) - { - return Task.CompletedTask; - } - - // set the initial values for the first request - if (_resetTime == DateTime.MinValue) - { - _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); - } - if (_resourcesRemaining == -1) - { - _resourcesRemaining = _configuration.RateLimit; - } - - // see if we passed the reset time window - if (DateTime.Now > _resetTime) - { - _resourcesRemaining = _configuration.RateLimit; - _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); - } - - // subtract the cost of the request - _resourcesRemaining -= _configuration.CostPerRequest; - if (_resourcesRemaining < 0) - { - _resourcesRemaining = 0; - var request = e.Session.HttpClient.Request; - - Logger.LogRequest([$"Exceeded resource limit when calling {request.Url}.", "Request will be throttled"], MessageType.Failed, new LoggingContext(e.Session)); - if (_configuration.WhenLimitExceeded == RateLimitResponseWhenLimitExceeded.Throttle) - { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) - { - value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); - } - - var throttledRequests = value as List; - throttledRequests?.Add(new ThrottlerInfo( - BuildThrottleKey(request), - ShouldThrottle, - _resetTime - )); - ThrottleResponse(e); - state.HasBeenSet = true; - } - else - { - if (_configuration.CustomResponse is not null) - { - var headersList = _configuration.CustomResponse.Headers is not null ? - _configuration.CustomResponse.Headers.Select(h => new HttpHeader(h.Name, h.Value)).ToList() : - []; - - var retryAfterHeader = headersList.FirstOrDefault(h => h.Name.Equals(_configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase)); - if (retryAfterHeader is not null && retryAfterHeader.Value == "@dynamic") - { - headersList.Add(new HttpHeader(_configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString())); - headersList.Remove(retryAfterHeader); - } - - var headers = headersList.ToArray(); - - // allow custom throttling response - var responseCode = (HttpStatusCode)(_configuration.CustomResponse.StatusCode ?? 200); - if (responseCode == HttpStatusCode.TooManyRequests) - { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) - { - value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); - } - - var throttledRequests = value as List; - throttledRequests?.Add(new ThrottlerInfo( - BuildThrottleKey(request), - ShouldThrottle, - _resetTime - )); - } - - string body = _configuration.CustomResponse.Body is not null ? - JsonSerializer.Serialize(_configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : - ""; - e.Session.GenericResponse(body, responseCode, headers); - state.HasBeenSet = true; - } - else - { - Logger.LogRequest([$"Custom behavior not set. {_configuration.CustomResponseFile} not found."], MessageType.Failed, new LoggingContext(e.Session)); - } - } - } - - StoreRateLimitingHeaders(e); - return Task.CompletedTask; - } - - private void StoreRateLimitingHeaders(ProxyRequestArgs e) - { - // add rate limiting headers if reached the threshold percentage - if (_resourcesRemaining > _configuration.RateLimit - (_configuration.RateLimit * _configuration.WarningThresholdPercent / 100)) - { - return; - } - - var headers = new List(); - var reset = _configuration.ResetFormat == RateLimitResetFormat.SecondsLeft ? - (_resetTime - DateTime.Now).TotalSeconds.ToString("N0") : // drop decimals - new DateTimeOffset(_resetTime).ToUnixTimeSeconds().ToString(); - headers.AddRange( - [ - new(_configuration.HeaderLimit, _configuration.RateLimit.ToString()), - new(_configuration.HeaderRemaining, _resourcesRemaining.ToString()), - new(_configuration.HeaderReset, reset) - ]); - - ExposeRateLimitingForCors(headers, e); - - e.SessionData.Add(Name, headers); - } - - private void ExposeRateLimitingForCors(List headers, ProxyRequestArgs e) - { - var request = e.Session.HttpClient.Request; - if (request.Headers.FirstOrDefault((h) => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase)) is null) - { - return; - } - - headers.Add(new("Access-Control-Allow-Origin", "*")); - headers.Add(new("Access-Control-Expose-Headers", $"{_configuration.HeaderLimit}, {_configuration.HeaderRemaining}, {_configuration.HeaderReset}, {_configuration.HeaderRetryAfter}")); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using System.Net; +using System.Text.Json; +using System.Text.RegularExpressions; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; + +namespace Microsoft.DevProxy.Plugins.Behavior; + +public enum RateLimitResponseWhenLimitExceeded +{ + Throttle, + Custom +} + +public enum RateLimitResetFormat +{ + SecondsLeft, + UtcEpochSeconds +} + +public class RateLimitConfiguration +{ + public string HeaderLimit { get; set; } = "RateLimit-Limit"; + public string HeaderRemaining { get; set; } = "RateLimit-Remaining"; + public string HeaderReset { get; set; } = "RateLimit-Reset"; + public string HeaderRetryAfter { get; set; } = "Retry-After"; + public RateLimitResetFormat ResetFormat { get; set; } = RateLimitResetFormat.SecondsLeft; + public int CostPerRequest { get; set; } = 2; + public int ResetTimeWindowSeconds { get; set; } = 60; + public int WarningThresholdPercent { get; set; } = 80; + public int RateLimit { get; set; } = 120; + public RateLimitResponseWhenLimitExceeded WhenLimitExceeded { get; set; } = RateLimitResponseWhenLimitExceeded.Throttle; + public string CustomResponseFile { get; set; } = "rate-limit-response.json"; + public MockResponseResponse? CustomResponse { get; set; } +} + +public class RateLimitingPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(RateLimitingPlugin); + private readonly RateLimitConfiguration _configuration = new(); + // initial values so that we know when we intercept the + // first request and can set the initial values + private int _resourcesRemaining = -1; + private DateTime _resetTime = DateTime.MinValue; + private RateLimitingCustomResponseLoader? _loader = null; + + private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + { + var throttleKeyForRequest = BuildThrottleKey(request); + return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? (int)(_resetTime - DateTime.Now).TotalSeconds : 0, _configuration.HeaderRetryAfter); + } + + private void ThrottleResponse(ProxyRequestArgs e) => UpdateProxyResponse(e, HttpStatusCode.TooManyRequests); + + private void UpdateProxyResponse(ProxyHttpEventArgsBase e, HttpStatusCode errorStatus) + { + var headers = new List(); + var body = string.Empty; + var request = e.Session.HttpClient.Request; + var response = e.Session.HttpClient.Response; + + // resources exceeded + if (errorStatus == HttpStatusCode.TooManyRequests) + { + if (ProxyUtils.IsGraphRequest(request)) + { + string requestId = Guid.NewGuid().ToString(); + string requestDate = DateTime.Now.ToString(); + headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); + + body = JsonSerializer.Serialize(new GraphErrorResponseBody( + new GraphErrorResponseError + { + Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = BuildApiErrorMessage(request), + InnerError = new GraphErrorResponseInnerError + { + RequestId = requestId, + Date = requestDate + } + }), + ProxyUtils.JsonSerializerOptions + ); + } + + headers.Add(new(_configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString())); + if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) + { + headers.Add(new("Access-Control-Allow-Origin", "*")); + headers.Add(new("Access-Control-Expose-Headers", _configuration.HeaderRetryAfter)); + } + + e.Session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value)).ToArray()); + return; + } + + if (e.SessionData.TryGetValue(Name, out var pluginData) && + pluginData is List rateLimitingHeaders) + { + ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); + } + + // add headers to the original API response, avoiding duplicates + headers.ForEach(h => e.Session.HttpClient.Response.Headers.RemoveHeader(h.Name)); + e.Session.HttpClient.Response.Headers.AddHeaders(headers.Select(h => new HttpHeader(h.Name, h.Value)).ToArray()); + } + private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : String.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; + + private static string BuildThrottleKey(Request r) + { + if (ProxyUtils.IsGraphRequest(r)) + { + return GraphUtils.BuildThrottleKey(r); + } + else + { + return r.RequestUri.Host; + } + } + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + if (_configuration.WhenLimitExceeded == RateLimitResponseWhenLimitExceeded.Custom) + { + _configuration.CustomResponseFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.CustomResponseFile), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); + _loader = new RateLimitingCustomResponseLoader(Logger, _configuration); + // load the responses from the configured mocks file + _loader.InitResponsesWatcher(); + } + + PluginEvents.BeforeRequest += OnRequestAsync; + PluginEvents.BeforeResponse += OnResponseAsync; + } + + // add rate limiting headers to the response from the API + private Task OnResponseAsync(object? sender, ProxyResponseArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + UpdateProxyResponse(e, HttpStatusCode.OK); + return Task.CompletedTask; + } + + private Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + var session = e.Session; + var state = e.ResponseState; + if (state.HasBeenSet) + { + Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + // set the initial values for the first request + if (_resetTime == DateTime.MinValue) + { + _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); + } + if (_resourcesRemaining == -1) + { + _resourcesRemaining = _configuration.RateLimit; + } + + // see if we passed the reset time window + if (DateTime.Now > _resetTime) + { + _resourcesRemaining = _configuration.RateLimit; + _resetTime = DateTime.Now.AddSeconds(_configuration.ResetTimeWindowSeconds); + } + + // subtract the cost of the request + _resourcesRemaining -= _configuration.CostPerRequest; + if (_resourcesRemaining < 0) + { + _resourcesRemaining = 0; + var request = e.Session.HttpClient.Request; + + Logger.LogRequest($"Exceeded resource limit when calling {request.Url}. Request will be throttled", MessageType.Failed, new LoggingContext(e.Session)); + if (_configuration.WhenLimitExceeded == RateLimitResponseWhenLimitExceeded.Throttle) + { + if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) + { + value = new List(); + e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + } + + var throttledRequests = value as List; + throttledRequests?.Add(new ThrottlerInfo( + BuildThrottleKey(request), + ShouldThrottle, + _resetTime + )); + ThrottleResponse(e); + state.HasBeenSet = true; + } + else + { + if (_configuration.CustomResponse is not null) + { + var headersList = _configuration.CustomResponse.Headers is not null ? + _configuration.CustomResponse.Headers.Select(h => new HttpHeader(h.Name, h.Value)).ToList() : + []; + + var retryAfterHeader = headersList.FirstOrDefault(h => h.Name.Equals(_configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase)); + if (retryAfterHeader is not null && retryAfterHeader.Value == "@dynamic") + { + headersList.Add(new HttpHeader(_configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString())); + headersList.Remove(retryAfterHeader); + } + + var headers = headersList.ToArray(); + + // allow custom throttling response + var responseCode = (HttpStatusCode)(_configuration.CustomResponse.StatusCode ?? 200); + if (responseCode == HttpStatusCode.TooManyRequests) + { + if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) + { + value = new List(); + e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + } + + var throttledRequests = value as List; + throttledRequests?.Add(new ThrottlerInfo( + BuildThrottleKey(request), + ShouldThrottle, + _resetTime + )); + } + + string body = _configuration.CustomResponse.Body is not null ? + JsonSerializer.Serialize(_configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : + ""; + e.Session.GenericResponse(body, responseCode, headers); + state.HasBeenSet = true; + } + else + { + Logger.LogRequest($"Custom behavior not set. {_configuration.CustomResponseFile} not found.", MessageType.Failed, new LoggingContext(e.Session)); + } + } + } + else + { + Logger.LogRequest($"Resources remaining: {_resourcesRemaining}", MessageType.Skipped, new LoggingContext(e.Session)); + } + + StoreRateLimitingHeaders(e); + return Task.CompletedTask; + } + + private void StoreRateLimitingHeaders(ProxyRequestArgs e) + { + // add rate limiting headers if reached the threshold percentage + if (_resourcesRemaining > _configuration.RateLimit - (_configuration.RateLimit * _configuration.WarningThresholdPercent / 100)) + { + return; + } + + var headers = new List(); + var reset = _configuration.ResetFormat == RateLimitResetFormat.SecondsLeft ? + (_resetTime - DateTime.Now).TotalSeconds.ToString("N0") : // drop decimals + new DateTimeOffset(_resetTime).ToUnixTimeSeconds().ToString(); + headers.AddRange( + [ + new(_configuration.HeaderLimit, _configuration.RateLimit.ToString()), + new(_configuration.HeaderRemaining, _resourcesRemaining.ToString()), + new(_configuration.HeaderReset, reset) + ]); + + ExposeRateLimitingForCors(headers, e); + + e.SessionData.Add(Name, headers); + } + + private void ExposeRateLimitingForCors(List headers, ProxyRequestArgs e) + { + var request = e.Session.HttpClient.Request; + if (request.Headers.FirstOrDefault((h) => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase)) is null) + { + return; + } + + headers.Add(new("Access-Control-Allow-Origin", "*")); + headers.Add(new("Access-Control-Expose-Headers", $"{_configuration.HeaderLimit}, {_configuration.HeaderRemaining}, {_configuration.HeaderReset}, {_configuration.HeaderRetryAfter}")); + } +} diff --git a/dev-proxy-plugins/Behavior/RetryAfterPlugin.cs b/dev-proxy-plugins/Behavior/RetryAfterPlugin.cs index 550a8d3c..24b43ff6 100644 --- a/dev-proxy-plugins/Behavior/RetryAfterPlugin.cs +++ b/dev-proxy-plugins/Behavior/RetryAfterPlugin.cs @@ -1,124 +1,138 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using System.Net; -using System.Text.Json; -using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; - -namespace Microsoft.DevProxy.Plugins.Behavior; - -public class RetryAfterPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(RetryAfterPlugin); - public static readonly string ThrottledRequestsKey = "ThrottledRequests"; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.BeforeRequest += OnRequestAsync; - } - - private Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - if (e.ResponseState.HasBeenSet || - UrlsToWatch is null || - string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) || - !e.ShouldExecute(UrlsToWatch)) - { - return Task.CompletedTask; - } - - ThrottleIfNecessary(e); - return Task.CompletedTask; - } - - private void ThrottleIfNecessary(ProxyRequestArgs e) - { - var request = e.Session.HttpClient.Request; - if (!e.GlobalData.TryGetValue(ThrottledRequestsKey, out object? value)) - { - return; - } - - if (value is not List throttledRequests) - { - return; - } - - var expiredThrottlers = throttledRequests.Where(t => t.ResetTime < DateTime.Now).ToArray(); - foreach (var throttler in expiredThrottlers) - { - throttledRequests.Remove(throttler); - } - - if (throttledRequests.Count == 0) - { - return; - } - - foreach (var throttler in throttledRequests) - { - var throttleInfo = throttler.ShouldThrottle(request, throttler.ThrottlingKey); - if (throttleInfo.ThrottleForSeconds > 0) - { - var messageLines = new[] { $"Calling {request.Url} before waiting for the Retry-After period.", "Request will be throttled.", $"Throttling on {throttler.ThrottlingKey}." }; - Logger.LogRequest(messageLines, MessageType.Failed, new LoggingContext(e.Session)); - - throttler.ResetTime = DateTime.Now.AddSeconds(throttleInfo.ThrottleForSeconds); - UpdateProxyResponse(e, throttleInfo, string.Join(' ', messageLines)); - break; - } - } - } - - private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throttlingInfo, string message) - { - var headers = new List(); - var body = string.Empty; - var request = e.Session.HttpClient.Request; - - // override the response body and headers for the error response - if (ProxyUtils.IsGraphRequest(request)) - { - string requestId = Guid.NewGuid().ToString(); - string requestDate = DateTime.Now.ToString(); - headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); - - body = JsonSerializer.Serialize(new GraphErrorResponseBody( - new GraphErrorResponseError - { - Code = new Regex("([A-Z])").Replace(HttpStatusCode.TooManyRequests.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = BuildApiErrorMessage(request, message), - InnerError = new GraphErrorResponseInnerError - { - RequestId = requestId, - Date = requestDate - } - }), - ProxyUtils.JsonSerializerOptions - ); - } - else - { - // ProxyUtils.BuildGraphResponseHeaders already includes CORS headers - if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) - { - headers.Add(new("Access-Control-Allow-Origin", "*")); - headers.Add(new("Access-Control-Expose-Headers", throttlingInfo.RetryAfterHeaderName)); - } - } - - headers.Add(new(throttlingInfo.RetryAfterHeaderName, throttlingInfo.ThrottleForSeconds.ToString())); - - e.Session.GenericResponse(body ?? string.Empty, HttpStatusCode.TooManyRequests, headers.Select(h => new HttpHeader(h.Name, h.Value))); - e.ResponseState.HasBeenSet = true; - } - - private static string BuildApiErrorMessage(Request r, string message) => $"{message} {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using System.Net; +using System.Text.Json; +using System.Text.RegularExpressions; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; + +namespace Microsoft.DevProxy.Plugins.Behavior; + +public class RetryAfterPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(RetryAfterPlugin); + public static readonly string ThrottledRequestsKey = "ThrottledRequests"; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.BeforeRequest += OnRequestAsync; + } + + private Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + if (e.ResponseState.HasBeenSet) + { + Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + ThrottleIfNecessary(e); + return Task.CompletedTask; + } + + private void ThrottleIfNecessary(ProxyRequestArgs e) + { + var request = e.Session.HttpClient.Request; + if (!e.GlobalData.TryGetValue(ThrottledRequestsKey, out object? value)) + { + Logger.LogRequest("Request not throttled", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + if (value is not List throttledRequests) + { + Logger.LogRequest("Request not throttled", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + var expiredThrottlers = throttledRequests.Where(t => t.ResetTime < DateTime.Now).ToArray(); + foreach (var throttler in expiredThrottlers) + { + throttledRequests.Remove(throttler); + } + + if (throttledRequests.Count == 0) + { + Logger.LogRequest("Request not throttled", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + foreach (var throttler in throttledRequests) + { + var throttleInfo = throttler.ShouldThrottle(request, throttler.ThrottlingKey); + if (throttleInfo.ThrottleForSeconds > 0) + { + var message = $"Calling {request.Url} before waiting for the Retry-After period. Request will be throttled. Throttling on {throttler.ThrottlingKey}."; + Logger.LogRequest(message, MessageType.Failed, new LoggingContext(e.Session)); + + throttler.ResetTime = DateTime.Now.AddSeconds(throttleInfo.ThrottleForSeconds); + UpdateProxyResponse(e, throttleInfo, string.Join(' ', message)); + return; + } + } + + Logger.LogRequest("Request not throttled", MessageType.Skipped, new LoggingContext(e.Session)); + } + + private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throttlingInfo, string message) + { + var headers = new List(); + var body = string.Empty; + var request = e.Session.HttpClient.Request; + + // override the response body and headers for the error response + if (ProxyUtils.IsGraphRequest(request)) + { + string requestId = Guid.NewGuid().ToString(); + string requestDate = DateTime.Now.ToString(); + headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); + + body = JsonSerializer.Serialize(new GraphErrorResponseBody( + new GraphErrorResponseError + { + Code = new Regex("([A-Z])").Replace(HttpStatusCode.TooManyRequests.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = BuildApiErrorMessage(request, message), + InnerError = new GraphErrorResponseInnerError + { + RequestId = requestId, + Date = requestDate + } + }), + ProxyUtils.JsonSerializerOptions + ); + } + else + { + // ProxyUtils.BuildGraphResponseHeaders already includes CORS headers + if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) + { + headers.Add(new("Access-Control-Allow-Origin", "*")); + headers.Add(new("Access-Control-Expose-Headers", throttlingInfo.RetryAfterHeaderName)); + } + } + + headers.Add(new(throttlingInfo.RetryAfterHeaderName, throttlingInfo.ThrottleForSeconds.ToString())); + + e.Session.GenericResponse(body ?? string.Empty, HttpStatusCode.TooManyRequests, headers.Select(h => new HttpHeader(h.Name, h.Value))); + e.ResponseState.HasBeenSet = true; + } + + private static string BuildApiErrorMessage(Request r, string message) => $"{message} {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; +} diff --git a/dev-proxy-plugins/Guidance/CachingGuidancePlugin.cs b/dev-proxy-plugins/Guidance/CachingGuidancePlugin.cs index d44de2ca..ed7a3364 100644 --- a/dev-proxy-plugins/Guidance/CachingGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/CachingGuidancePlugin.cs @@ -1,66 +1,73 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class CachingGuidancePluginConfiguration -{ - public int CacheThresholdSeconds { get; set; } = 5; -} - -public class CachingGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(CachingGuidancePlugin); - private readonly CachingGuidancePluginConfiguration _configuration = new(); - private Dictionary _interceptedRequests = []; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - PluginEvents.BeforeRequest += BeforeRequestAsync; - } - - private Task BeforeRequestAsync(object? sender, ProxyRequestArgs e) - { - if (UrlsToWatch is null || - !e.HasRequestUrlMatch(UrlsToWatch) || - string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) - { - return Task.CompletedTask; - } - - Request request = e.Session.HttpClient.Request; - var url = request.RequestUri.AbsoluteUri; - var now = DateTime.Now; - - if (!_interceptedRequests.TryGetValue(url, out DateTime value)) - { - value = now; - _interceptedRequests.Add(url, value); - return Task.CompletedTask; - } - - var lastIntercepted = value; - var secondsSinceLastIntercepted = (now - lastIntercepted).TotalSeconds; - if (secondsSinceLastIntercepted <= _configuration.CacheThresholdSeconds) - { - Logger.LogRequest(BuildCacheWarningMessage(request, _configuration.CacheThresholdSeconds, lastIntercepted), MessageType.Warning, new LoggingContext(e.Session)); - } - - _interceptedRequests[url] = now; - return Task.CompletedTask; - } - - private static string[] BuildCacheWarningMessage(Request r, int _warningSeconds, DateTime lastIntercepted) => [ - $"Another request to {r.RequestUri.PathAndQuery} intercepted within {_warningSeconds} seconds.", - $"Last intercepted at {lastIntercepted}.", - "Consider using cache to avoid calling the API too often." - ]; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class CachingGuidancePluginConfiguration +{ + public int CacheThresholdSeconds { get; set; } = 5; +} + +public class CachingGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(CachingGuidancePlugin); + private readonly CachingGuidancePluginConfiguration _configuration = new(); + private Dictionary _interceptedRequests = []; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + PluginEvents.BeforeRequest += BeforeRequestAsync; + } + + private Task BeforeRequestAsync(object? sender, ProxyRequestArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + Request request = e.Session.HttpClient.Request; + var url = request.RequestUri.AbsoluteUri; + var now = DateTime.Now; + + if (!_interceptedRequests.TryGetValue(url, out DateTime value)) + { + value = now; + _interceptedRequests.Add(url, value); + Logger.LogRequest("First request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + var lastIntercepted = value; + var secondsSinceLastIntercepted = (now - lastIntercepted).TotalSeconds; + if (secondsSinceLastIntercepted <= _configuration.CacheThresholdSeconds) + { + Logger.LogRequest(BuildCacheWarningMessage(request, _configuration.CacheThresholdSeconds, lastIntercepted), MessageType.Warning, new LoggingContext(e.Session)); + } + else + { + Logger.LogRequest("Request outside of cache window", MessageType.Skipped, new LoggingContext(e.Session)); + } + + _interceptedRequests[url] = now; + return Task.CompletedTask; + } + + private static string BuildCacheWarningMessage(Request r, int _warningSeconds, DateTime lastIntercepted) => + $"Another request to {r.RequestUri.PathAndQuery} intercepted within {_warningSeconds} seconds. Last intercepted at {lastIntercepted}. Consider using cache to avoid calling the API too often."; +} diff --git a/dev-proxy-plugins/Guidance/GraphBetaSupportGuidancePlugin.cs b/dev-proxy-plugins/Guidance/GraphBetaSupportGuidancePlugin.cs index 45980c1f..c6726ba7 100644 --- a/dev-proxy-plugins/Guidance/GraphBetaSupportGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/GraphBetaSupportGuidancePlugin.cs @@ -1,38 +1,49 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class GraphBetaSupportGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphBetaSupportGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.AfterResponse += AfterResponseAsync; - } - - private Task AfterResponseAsync(object? sender, ProxyResponseArgs e) - { - Request request = e.Session.HttpClient.Request; - if (UrlsToWatch is not null && - e.HasRequestUrlMatch(UrlsToWatch) && - !string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) && - ProxyUtils.IsGraphBetaRequest(request)) - Logger.LogRequest(BuildBetaSupportMessage(), MessageType.Warning, new LoggingContext(e.Session)); - return Task.CompletedTask; - } - - private static string GetBetaSupportGuidanceUrl() => "https://aka.ms/devproxy/guidance/beta-support"; - private static string[] BuildBetaSupportMessage() - { - return [$"Don't use beta APIs in production because they can change or be deprecated.", $"More info at {GetBetaSupportGuidanceUrl()}"]; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class GraphBetaSupportGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphBetaSupportGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.AfterResponse += AfterResponseAsync; + } + + private Task AfterResponseAsync(object? sender, ProxyResponseArgs e) + { + Request request = e.Session.HttpClient.Request; + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (!ProxyUtils.IsGraphBetaRequest(request)) + { + Logger.LogRequest("Not a Microsoft Graph beta request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + Logger.LogRequest(BuildBetaSupportMessage(), MessageType.Warning, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + private static string GetBetaSupportGuidanceUrl() => "https://aka.ms/devproxy/guidance/beta-support"; + private static string BuildBetaSupportMessage() => + $"Don't use beta APIs in production because they can change or be deprecated. More info at {GetBetaSupportGuidanceUrl()}"; +} diff --git a/dev-proxy-plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs b/dev-proxy-plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs index 1ced5fe1..45c78ae6 100644 --- a/dev-proxy-plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs @@ -1,51 +1,61 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class GraphClientRequestIdGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphClientRequestIdGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.BeforeRequest += BeforeRequestAsync; - } - - private Task BeforeRequestAsync(object? sender, ProxyRequestArgs e) - { - Request request = e.Session.HttpClient.Request; - if (UrlsToWatch is not null && - e.HasRequestUrlMatch(UrlsToWatch) && - !string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) && - WarnNoClientRequestId(request)) - { - Logger.LogRequest(BuildAddClientRequestIdMessage(), MessageType.Warning, new LoggingContext(e.Session)); - - if (!ProxyUtils.IsSdkRequest(request)) - { - Logger.LogRequest(MessageUtils.BuildUseSdkMessage(request), MessageType.Tip, new LoggingContext(e.Session)); - } - } - - return Task.CompletedTask; - } - - private static bool WarnNoClientRequestId(Request request) => - ProxyUtils.IsGraphRequest(request) && - !request.Headers.HeaderExists("client-request-id"); - - private static string GetClientRequestIdGuidanceUrl() => "https://aka.ms/devproxy/guidance/client-request-id"; - private static string[] BuildAddClientRequestIdMessage() => [ - $"To help Microsoft investigate errors, to each request to Microsoft Graph", - "add the client-request-id header with a unique GUID.", - $"More info at {GetClientRequestIdGuidanceUrl()}" - ]; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class GraphClientRequestIdGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphClientRequestIdGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.BeforeRequest += BeforeRequestAsync; + } + + private Task BeforeRequestAsync(object? sender, ProxyRequestArgs e) + { + Request request = e.Session.HttpClient.Request; + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + if (WarnNoClientRequestId(request)) + { + Logger.LogRequest(BuildAddClientRequestIdMessage(), MessageType.Warning, new LoggingContext(e.Session)); + + if (!ProxyUtils.IsSdkRequest(request)) + { + Logger.LogRequest(MessageUtils.BuildUseSdkMessage(request), MessageType.Tip, new LoggingContext(e.Session)); + } + } + else + { + Logger.LogRequest("client-request-id header present", MessageType.Skipped, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private static bool WarnNoClientRequestId(Request request) => + ProxyUtils.IsGraphRequest(request) && + !request.Headers.HeaderExists("client-request-id"); + + private static string GetClientRequestIdGuidanceUrl() => "https://aka.ms/devproxy/guidance/client-request-id"; + private static string BuildAddClientRequestIdMessage() => + $"To help Microsoft investigate errors, to each request to Microsoft Graph add the client-request-id header with a unique GUID. More info at {GetClientRequestIdGuidanceUrl()}"; +} diff --git a/dev-proxy-plugins/Guidance/GraphConnectorGuidancePlugin.cs b/dev-proxy-plugins/Guidance/GraphConnectorGuidancePlugin.cs index b6f8a1fb..34da3574 100644 --- a/dev-proxy-plugins/Guidance/GraphConnectorGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/GraphConnectorGuidancePlugin.cs @@ -1,112 +1,117 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using System.Text.Json; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -class ExternalConnectionSchema -{ - public string? BaseType { get; set; } - public ExternalConnectionSchemaProperty[]? Properties { get; set; } -} - -class ExternalConnectionSchemaProperty -{ - public string[]? Aliases { get; set; } - public bool? IsQueryable { get; set; } - public bool? IsRefinable { get; set; } - public bool? IsRetrievable { get; set; } - public bool? IsSearchable { get; set; } - public string[]? Labels { get; set; } - public string? Name { get; set; } - public string? Type { get; set; } -} - -public class GraphConnectorGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphConnectorGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.BeforeRequest += BeforeRequestAsync; - } - - private Task BeforeRequestAsync(object sender, ProxyRequestArgs e) - { - if (UrlsToWatch is null || - !e.HasRequestUrlMatch(UrlsToWatch) || - !string.Equals(e.Session.HttpClient.Request.Method, "PATCH", StringComparison.OrdinalIgnoreCase)) - { - return Task.CompletedTask; - } - - try - { - var schemaString = e.Session.HttpClient.Request.BodyString; - if (string.IsNullOrEmpty(schemaString)) - { - Logger.LogRequest([ "No schema found in the request body." ], MessageType.Failed, new LoggingContext(e.Session)); - return Task.CompletedTask; - } - - var schema = JsonSerializer.Deserialize(schemaString, ProxyUtils.JsonSerializerOptions); - if (schema is null || schema.Properties is null) - { - Logger.LogRequest([ "Invalid schema found in the request body." ], MessageType.Failed, new LoggingContext(e.Session)); - return Task.CompletedTask; - } - - bool hasTitle = false, hasIconUrl = false, hasUrl = false; - foreach (var property in schema.Properties) - { - if (property.Labels is null) - { - continue; - } - - if (property.Labels.Contains("title", StringComparer.OrdinalIgnoreCase)) - { - hasTitle = true; - } - if (property.Labels.Contains("iconUrl", StringComparer.OrdinalIgnoreCase)) - { - hasIconUrl = true; - } - if (property.Labels.Contains("url", StringComparer.OrdinalIgnoreCase)) - { - hasUrl = true; - } - } - - if (!hasTitle || !hasIconUrl || !hasUrl) - { - string[] missingLabels = [ - !hasTitle ? "title" : "", - !hasIconUrl ? "iconUrl" : "", - !hasUrl ? "url" : "" - ]; - - Logger.LogRequest( - [ - $"The schema is missing the following semantic labels: {string.Join(", ", missingLabels.Where(s => s != ""))}.", - "Ingested content might not show up in Microsoft Copilot for Microsoft 365.", - "More information: https://aka.ms/devproxy/guidance/gc/ux" - ], - MessageType.Failed, new LoggingContext(e.Session) - ); - } - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while deserializing the request body"); - } - - return Task.CompletedTask; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using System.Text.Json; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +class ExternalConnectionSchema +{ + public string? BaseType { get; set; } + public ExternalConnectionSchemaProperty[]? Properties { get; set; } +} + +class ExternalConnectionSchemaProperty +{ + public string[]? Aliases { get; set; } + public bool? IsQueryable { get; set; } + public bool? IsRefinable { get; set; } + public bool? IsRetrievable { get; set; } + public bool? IsSearchable { get; set; } + public string[]? Labels { get; set; } + public string? Name { get; set; } + public string? Type { get; set; } +} + +public class GraphConnectorGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphConnectorGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.BeforeRequest += BeforeRequestAsync; + } + + private Task BeforeRequestAsync(object sender, ProxyRequestArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (!string.Equals(e.Session.HttpClient.Request.Method, "PATCH", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping non-PATCH request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + try + { + var schemaString = e.Session.HttpClient.Request.BodyString; + if (string.IsNullOrEmpty(schemaString)) + { + Logger.LogRequest("No schema found in the request body.", MessageType.Failed, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + var schema = JsonSerializer.Deserialize(schemaString, ProxyUtils.JsonSerializerOptions); + if (schema is null || schema.Properties is null) + { + Logger.LogRequest("Invalid schema found in the request body.", MessageType.Failed, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + bool hasTitle = false, hasIconUrl = false, hasUrl = false; + foreach (var property in schema.Properties) + { + if (property.Labels is null) + { + continue; + } + + if (property.Labels.Contains("title", StringComparer.OrdinalIgnoreCase)) + { + hasTitle = true; + } + if (property.Labels.Contains("iconUrl", StringComparer.OrdinalIgnoreCase)) + { + hasIconUrl = true; + } + if (property.Labels.Contains("url", StringComparer.OrdinalIgnoreCase)) + { + hasUrl = true; + } + } + + if (!hasTitle || !hasIconUrl || !hasUrl) + { + string[] missingLabels = [ + !hasTitle ? "title" : "", + !hasIconUrl ? "iconUrl" : "", + !hasUrl ? "url" : "" + ]; + + Logger.LogRequest( + $"The schema is missing the following semantic labels: {string.Join(", ", missingLabels.Where(s => s != ""))}. Ingested content might not show up in Microsoft Copilot for Microsoft 365. More information: https://aka.ms/devproxy/guidance/gc/ux", + MessageType.Failed, new LoggingContext(e.Session) + ); + } + else + { + Logger.LogRequest("The schema contains all the required semantic labels.", MessageType.Skipped, new LoggingContext(e.Session)); + } + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while deserializing the request body"); + } + + return Task.CompletedTask; + } +} diff --git a/dev-proxy-plugins/Guidance/GraphSdkGuidancePlugin.cs b/dev-proxy-plugins/Guidance/GraphSdkGuidancePlugin.cs index 7a44c04a..d27e93f5 100644 --- a/dev-proxy-plugins/Guidance/GraphSdkGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/GraphSdkGuidancePlugin.cs @@ -1,39 +1,58 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class GraphSdkGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphSdkGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.AfterResponse += OnAfterResponseAsync; - } - - private Task OnAfterResponseAsync(object? sender, ProxyResponseArgs e) - { - Request request = e.Session.HttpClient.Request; - // only show the message if there is an error. - if (e.Session.HttpClient.Response.StatusCode >= 400 && - UrlsToWatch is not null && - e.HasRequestUrlMatch(UrlsToWatch) && - !string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) && - WarnNoSdk(request)) - { - Logger.LogRequest(MessageUtils.BuildUseSdkForErrorsMessage(request), MessageType.Tip, new LoggingContext(e.Session)); - } - - return Task.CompletedTask; - } - - private static bool WarnNoSdk(Request request) => ProxyUtils.IsGraphRequest(request) && !ProxyUtils.IsSdkRequest(request); -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class GraphSdkGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphSdkGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.AfterResponse += OnAfterResponseAsync; + } + + private Task OnAfterResponseAsync(object? sender, ProxyResponseArgs e) + { + Request request = e.Session.HttpClient.Request; + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + // only show the message if there is an error. + if (e.Session.HttpClient.Response.StatusCode >= 400) + { + if (WarnNoSdk(request)) + { + Logger.LogRequest(MessageUtils.BuildUseSdkForErrorsMessage(request), MessageType.Tip, new LoggingContext(e.Session)); + } + else + { + Logger.LogRequest("Request issued using SDK", MessageType.Skipped, new LoggingContext(e.Session)); + } + } + else + { + Logger.LogRequest("Skipping non-error response", MessageType.Skipped, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private static bool WarnNoSdk(Request request) => ProxyUtils.IsGraphRequest(request) && !ProxyUtils.IsSdkRequest(request); +} diff --git a/dev-proxy-plugins/Guidance/GraphSelectGuidancePlugin.cs b/dev-proxy-plugins/Guidance/GraphSelectGuidancePlugin.cs index a36e44b0..fb8db8cb 100644 --- a/dev-proxy-plugins/Guidance/GraphSelectGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/GraphSelectGuidancePlugin.cs @@ -1,90 +1,105 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class GraphSelectGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphSelectGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.AfterResponse += AfterResponseAsync; - - // let's not await so that it doesn't block the proxy startup - _ = MSGraphDbUtils.GenerateMSGraphDbAsync(Logger, true); - } - - private Task AfterResponseAsync(object? sender, ProxyResponseArgs e) - { - Request request = e.Session.HttpClient.Request; - if (UrlsToWatch is not null && - e.HasRequestUrlMatch(UrlsToWatch) && - !string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) && - WarnNoSelect(request)) - Logger.LogRequest(BuildUseSelectMessage(), MessageType.Warning, new LoggingContext(e.Session)); - - return Task.CompletedTask; - } - - private bool WarnNoSelect(Request request) - { - if (!ProxyUtils.IsGraphRequest(request) || - request.Method != "GET") - { - return false; - } - - var graphVersion = ProxyUtils.GetGraphVersion(request.RequestUri.AbsoluteUri); - var tokenizedUrl = GetTokenizedUrl(request.RequestUri.AbsoluteUri); - - if (EndpointSupportsSelect(graphVersion, tokenizedUrl)) - { - return !request.Url.Contains("$select", StringComparison.OrdinalIgnoreCase) && - !request.Url.Contains("%24select", StringComparison.OrdinalIgnoreCase); - } - else - { - return false; - } - } - - private bool EndpointSupportsSelect(string graphVersion, string relativeUrl) - { - var fallback = relativeUrl.Contains("$value", StringComparison.OrdinalIgnoreCase); - - try - { - var dbConnection = MSGraphDbUtils.MSGraphDbConnection; - // lookup information from the database - var selectEndpoint = dbConnection.CreateCommand(); - selectEndpoint.CommandText = "SELECT hasSelect FROM endpoints WHERE path = @path AND graphVersion = @graphVersion"; - selectEndpoint.Parameters.AddWithValue("@path", relativeUrl); - selectEndpoint.Parameters.AddWithValue("@graphVersion", graphVersion); - var result = selectEndpoint.ExecuteScalar(); - var hasSelect = result != null && Convert.ToInt32(result) == 1; - return hasSelect; - } - catch (Exception ex) - { - Logger.LogError(ex, "Error looking up endpoint in database"); - return fallback; - } - } - - private static string GetSelectParameterGuidanceUrl() => "https://aka.ms/devproxy/guidance/select"; - private static string[] BuildUseSelectMessage() => [$"To improve performance of your application, use the $select parameter.", $"More info at {GetSelectParameterGuidanceUrl()}"]; - - private static string GetTokenizedUrl(string absoluteUrl) - { - var sanitizedUrl = ProxyUtils.SanitizeUrl(absoluteUrl); - return "/" + string.Join("", new Uri(sanitizedUrl).Segments.Skip(2).Select(Uri.UnescapeDataString)); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; +using Microsoft.Extensions.Logging; +using Titanium.Web.Proxy.EventArguments; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class GraphSelectGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphSelectGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.AfterResponse += AfterResponseAsync; + + // let's not await so that it doesn't block the proxy startup + _ = MSGraphDbUtils.GenerateMSGraphDbAsync(Logger, true); + } + + private Task AfterResponseAsync(object? sender, ProxyResponseArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + if (WarnNoSelect(e.Session)) + { + Logger.LogRequest(BuildUseSelectMessage(), MessageType.Warning, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private bool WarnNoSelect(SessionEventArgs session) + { + var request = session.HttpClient.Request; + if (!ProxyUtils.IsGraphRequest(request) || + request.Method != "GET") + { + Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, new LoggingContext(session)); + return false; + } + + var graphVersion = ProxyUtils.GetGraphVersion(request.RequestUri.AbsoluteUri); + var tokenizedUrl = GetTokenizedUrl(request.RequestUri.AbsoluteUri); + + if (EndpointSupportsSelect(graphVersion, tokenizedUrl)) + { + return !request.Url.Contains("$select", StringComparison.OrdinalIgnoreCase) && + !request.Url.Contains("%24select", StringComparison.OrdinalIgnoreCase); + } + else + { + Logger.LogRequest("Endpoint does not support $select", MessageType.Skipped, new LoggingContext(session)); + return false; + } + } + + private bool EndpointSupportsSelect(string graphVersion, string relativeUrl) + { + var fallback = relativeUrl.Contains("$value", StringComparison.OrdinalIgnoreCase); + + try + { + var dbConnection = MSGraphDbUtils.MSGraphDbConnection; + // lookup information from the database + var selectEndpoint = dbConnection.CreateCommand(); + selectEndpoint.CommandText = "SELECT hasSelect FROM endpoints WHERE path = @path AND graphVersion = @graphVersion"; + selectEndpoint.Parameters.AddWithValue("@path", relativeUrl); + selectEndpoint.Parameters.AddWithValue("@graphVersion", graphVersion); + var result = selectEndpoint.ExecuteScalar(); + var hasSelect = result != null && Convert.ToInt32(result) == 1; + return hasSelect; + } + catch (Exception ex) + { + Logger.LogError(ex, "Error looking up endpoint in database"); + return fallback; + } + } + + private static string GetSelectParameterGuidanceUrl() => "https://aka.ms/devproxy/guidance/select"; + private static string BuildUseSelectMessage() => + $"To improve performance of your application, use the $select parameter. More info at {GetSelectParameterGuidanceUrl()}"; + + private static string GetTokenizedUrl(string absoluteUrl) + { + var sanitizedUrl = ProxyUtils.SanitizeUrl(absoluteUrl); + return "/" + string.Join("", new Uri(sanitizedUrl).Segments.Skip(2).Select(Uri.UnescapeDataString)); + } +} diff --git a/dev-proxy-plugins/Guidance/ODSPSearchGuidancePlugin.cs b/dev-proxy-plugins/Guidance/ODSPSearchGuidancePlugin.cs index 2bbee42d..9debf8ab 100644 --- a/dev-proxy-plugins/Guidance/ODSPSearchGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/ODSPSearchGuidancePlugin.cs @@ -1,61 +1,76 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class ODSPSearchGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(ODSPSearchGuidancePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.BeforeRequest += BeforeRequestAsync; - } - - private Task BeforeRequestAsync(object sender, ProxyRequestArgs e) - { - Request request = e.Session.HttpClient.Request; - if (UrlsToWatch is not null && - e.HasRequestUrlMatch(UrlsToWatch) && - !string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) && - WarnDeprecatedSearch(request)) - Logger.LogRequest(BuildUseGraphSearchMessage(), MessageType.Warning, new LoggingContext(e.Session)); - - return Task.CompletedTask; - } - - private static bool WarnDeprecatedSearch(Request request) - { - if (!ProxyUtils.IsGraphRequest(request) || - request.Method != "GET") - { - return false; - } - - // graph.microsoft.com/{version}/drives/{drive-id}/root/search(q='{search-text}') - // graph.microsoft.com/{version}/groups/{group-id}/drive/root/search(q='{search-text}') - // graph.microsoft.com/{version}/me/drive/root/search(q='{search-text}') - // graph.microsoft.com/{version}/sites/{site-id}/drive/root/search(q='{search-text}') - // graph.microsoft.com/{version}/users/{user-id}/drive/root/search(q='{search-text}') - // graph.microsoft.com/{version}/sites?search={query} - if (request.RequestUri.AbsolutePath.Contains("/search(q=", StringComparison.OrdinalIgnoreCase) || - (request.RequestUri.AbsolutePath.EndsWith("/sites", StringComparison.OrdinalIgnoreCase) && - request.RequestUri.Query.Contains("search=", StringComparison.OrdinalIgnoreCase))) - { - return true; - } - else - { - return false; - } - } - - private static string[] BuildUseGraphSearchMessage() => [$"To get the best search experience, use the Microsoft Search APIs in Microsoft Graph.", $"More info at https://aka.ms/devproxy/guidance/odspsearch"]; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.EventArguments; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class ODSPSearchGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(ODSPSearchGuidancePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.BeforeRequest += BeforeRequestAsync; + } + + private Task BeforeRequestAsync(object sender, ProxyRequestArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + if (WarnDeprecatedSearch(e.Session)) + { + Logger.LogRequest(BuildUseGraphSearchMessage(), MessageType.Warning, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private bool WarnDeprecatedSearch(SessionEventArgs session) + { + Request request = session.HttpClient.Request; + if (!ProxyUtils.IsGraphRequest(request) || + request.Method != "GET") + { + Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, new LoggingContext(session)); + return false; + } + + // graph.microsoft.com/{version}/drives/{drive-id}/root/search(q='{search-text}') + // graph.microsoft.com/{version}/groups/{group-id}/drive/root/search(q='{search-text}') + // graph.microsoft.com/{version}/me/drive/root/search(q='{search-text}') + // graph.microsoft.com/{version}/sites/{site-id}/drive/root/search(q='{search-text}') + // graph.microsoft.com/{version}/users/{user-id}/drive/root/search(q='{search-text}') + // graph.microsoft.com/{version}/sites?search={query} + if (request.RequestUri.AbsolutePath.Contains("/search(q=", StringComparison.OrdinalIgnoreCase) || + (request.RequestUri.AbsolutePath.EndsWith("/sites", StringComparison.OrdinalIgnoreCase) && + request.RequestUri.Query.Contains("search=", StringComparison.OrdinalIgnoreCase))) + { + return true; + } + else + { + Logger.LogRequest("Not a SharePoint search request", MessageType.Skipped, new LoggingContext(session)); + return false; + } + } + + private static string BuildUseGraphSearchMessage() => + $"To get the best search experience, use the Microsoft Search APIs in Microsoft Graph. More info at https://aka.ms/devproxy/guidance/odspsearch"; +} diff --git a/dev-proxy-plugins/Guidance/ODataPagingGuidancePlugin.cs b/dev-proxy-plugins/Guidance/ODataPagingGuidancePlugin.cs index 1b173c81..6484ac41 100644 --- a/dev-proxy-plugins/Guidance/ODataPagingGuidancePlugin.cs +++ b/dev-proxy-plugins/Guidance/ODataPagingGuidancePlugin.cs @@ -1,134 +1,164 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using System.Text.Json; -using System.Xml.Linq; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.Guidance; - -public class ODataPagingGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(ODataPagingGuidancePlugin); - private readonly IList pagingUrls = []; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.BeforeRequest += OnBeforeRequestAsync; - PluginEvents.BeforeResponse += OnBeforeResponseAsync; - } - - private Task OnBeforeRequestAsync(object? sender, ProxyRequestArgs e) - { - if (UrlsToWatch is null || - e.Session.HttpClient.Request.Method != "GET" || - !e.HasRequestUrlMatch(UrlsToWatch)) - { - return Task.CompletedTask; - } - - if (IsODataPagingUrl(e.Session.HttpClient.Request.RequestUri) && - !pagingUrls.Contains(e.Session.HttpClient.Request.Url)) - { - Logger.LogRequest(BuildIncorrectPagingUrlMessage(), MessageType.Warning, new LoggingContext(e.Session)); - } - - return Task.CompletedTask; - } - - private async Task OnBeforeResponseAsync(object? sender, ProxyResponseArgs e) - { - if (UrlsToWatch is null || - !e.HasRequestUrlMatch(UrlsToWatch) || - e.Session.HttpClient.Request.Method != "GET" || - e.Session.HttpClient.Response.StatusCode >= 300 || - e.Session.HttpClient.Response.ContentType is null || - (!e.Session.HttpClient.Response.ContentType.Contains("json") && - !e.Session.HttpClient.Response.ContentType.Contains("application/atom+xml")) || - !e.Session.HttpClient.Response.HasBody) - { - return; - } - - e.Session.HttpClient.Response.KeepBody = true; - - var nextLink = string.Empty; - var bodyString = await e.Session.GetResponseBodyAsString(); - if (string.IsNullOrEmpty(bodyString)) - { - return; - } - - var contentType = e.Session.HttpClient.Response.ContentType; - if (contentType.Contains("json")) - { - nextLink = GetNextLinkFromJson(bodyString); - } - else if (contentType.Contains("application/atom+xml")) - { - nextLink = GetNextLinkFromXml(bodyString); - } - - if (!string.IsNullOrEmpty(nextLink)) - { - pagingUrls.Add(nextLink); - } - } - - private string GetNextLinkFromJson(string responseBody) - { - var nextLink = string.Empty; - - try - { - var response = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (response.TryGetProperty("@odata.nextLink", out var nextLinkProperty)) - { - nextLink = nextLinkProperty.GetString() ?? string.Empty; - } - } - catch (Exception e) - { - Logger.LogDebug(e, "An error has occurred while parsing the response body"); - } - - return nextLink; - } - - private string GetNextLinkFromXml(string responseBody) - { - var nextLink = string.Empty; - - try - { - var doc = XDocument.Parse(responseBody); - nextLink = doc - .Descendants() - .Where(e => e.Name.LocalName == "link" && e.Attribute("rel")?.Value == "next") - .FirstOrDefault() - ?.Attribute("href")?.Value ?? string.Empty; - } - catch (Exception e) - { - Logger.LogError(e.Message); - } - - return nextLink; - } - - private static bool IsODataPagingUrl(Uri uri) => - uri.Query.Contains("$skip") || - uri.Query.Contains("%24skip") || - uri.Query.Contains("$skiptoken") || - uri.Query.Contains("%24skiptoken"); - - private static string[] BuildIncorrectPagingUrlMessage() => [ - "This paging URL seems to be created manually and is not aligned with paging information from the API.", - "This could lead to incorrect data in your app.", - "For more information about paging see https://aka.ms/devproxy/guidance/paging" - ]; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using System.Text.Json; +using System.Xml.Linq; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.Guidance; + +public class ODataPagingGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(ODataPagingGuidancePlugin); + private readonly IList pagingUrls = []; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.BeforeRequest += OnBeforeRequestAsync; + PluginEvents.BeforeResponse += OnBeforeResponseAsync; + } + + private Task OnBeforeRequestAsync(object? sender, ProxyRequestArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (!string.Equals(e.Session.HttpClient.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + if (IsODataPagingUrl(e.Session.HttpClient.Request.RequestUri)) + { + if (!pagingUrls.Contains(e.Session.HttpClient.Request.Url)) + { + Logger.LogRequest(BuildIncorrectPagingUrlMessage(), MessageType.Warning, new LoggingContext(e.Session)); + } + else + { + Logger.LogRequest("Paging URL is correct", MessageType.Skipped, new LoggingContext(e.Session)); + } + } + else + { + Logger.LogRequest("Not an OData paging URL", MessageType.Skipped, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private async Task OnBeforeResponseAsync(object? sender, ProxyResponseArgs e) + { + if (UrlsToWatch is null || + !e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + if (!string.Equals(e.Session.HttpClient.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + if (e.Session.HttpClient.Response.StatusCode >= 300) + { + Logger.LogRequest("Skipping non-success response", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + if (e.Session.HttpClient.Response.ContentType is null || + (!e.Session.HttpClient.Response.ContentType.Contains("json") && + !e.Session.HttpClient.Response.ContentType.Contains("application/atom+xml")) || + !e.Session.HttpClient.Response.HasBody) + { + Logger.LogRequest("Skipping response with unsupported body type", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + e.Session.HttpClient.Response.KeepBody = true; + + var nextLink = string.Empty; + var bodyString = await e.Session.GetResponseBodyAsString(); + if (string.IsNullOrEmpty(bodyString)) + { + Logger.LogRequest("Skipping empty response body", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + var contentType = e.Session.HttpClient.Response.ContentType; + if (contentType.Contains("json")) + { + nextLink = GetNextLinkFromJson(bodyString); + } + else if (contentType.Contains("application/atom+xml")) + { + nextLink = GetNextLinkFromXml(bodyString); + } + + if (!string.IsNullOrEmpty(nextLink)) + { + pagingUrls.Add(nextLink); + } + else + { + Logger.LogRequest("No next link found in the response", MessageType.Skipped, new LoggingContext(e.Session)); + } + } + + private string GetNextLinkFromJson(string responseBody) + { + var nextLink = string.Empty; + + try + { + var response = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); + if (response.TryGetProperty("@odata.nextLink", out var nextLinkProperty)) + { + nextLink = nextLinkProperty.GetString() ?? string.Empty; + } + } + catch (Exception e) + { + Logger.LogDebug(e, "An error has occurred while parsing the response body"); + } + + return nextLink; + } + + private string GetNextLinkFromXml(string responseBody) + { + var nextLink = string.Empty; + + try + { + var doc = XDocument.Parse(responseBody); + nextLink = doc + .Descendants() + .Where(e => e.Name.LocalName == "link" && e.Attribute("rel")?.Value == "next") + .FirstOrDefault() + ?.Attribute("href")?.Value ?? string.Empty; + } + catch (Exception e) + { + Logger.LogError(e.Message); + } + + return nextLink; + } + + private static bool IsODataPagingUrl(Uri uri) => + uri.Query.Contains("$skip") || + uri.Query.Contains("%24skip") || + uri.Query.Contains("$skiptoken") || + uri.Query.Contains("%24skiptoken"); + + private static string BuildIncorrectPagingUrlMessage() => + "This paging URL seems to be created manually and is not aligned with paging information from the API. This could lead to incorrect data in your app. For more information about paging see https://aka.ms/devproxy/guidance/paging"; +} diff --git a/dev-proxy-plugins/Inspection/DevToolsPlugin.cs b/dev-proxy-plugins/Inspection/DevToolsPlugin.cs index 95ea199b..94845b50 100644 --- a/dev-proxy-plugins/Inspection/DevToolsPlugin.cs +++ b/dev-proxy-plugins/Inspection/DevToolsPlugin.cs @@ -1,381 +1,381 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using System.Net; -using System.Net.Sockets; -using System.Runtime.InteropServices; -using System.Text.Json; -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.Inspection.CDP; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.Inspection; - -public enum PreferredBrowser -{ - Edge, - Chrome, - EdgeDev, - EdgeBeta -} - -public class DevToolsPluginConfiguration -{ - public PreferredBrowser PreferredBrowser { get; set; } = PreferredBrowser.Edge; - - /// - /// Path to the browser executable. If not set, the plugin will try to find - /// the browser executable based on the PreferredBrowser. - /// - /// Use this value when you install the browser in a non-standard - /// path. - public string PreferredBrowserPath { get; set; } = string.Empty; -} - -public class DevToolsPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private WebSocketServer? webSocket; - private readonly Dictionary responseBody = []; - - public override string Name => nameof(DevToolsPlugin); - private readonly DevToolsPluginConfiguration _configuration = new(); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - ConfigSection?.Bind(_configuration); - - InitInspector(); - - PluginEvents.BeforeRequest += BeforeRequestAsync; - PluginEvents.AfterResponse += AfterResponseAsync; - PluginEvents.AfterRequestLog += AfterRequestLogAsync; - } - - private static int GetFreePort() - { - using var listener = new TcpListener(IPAddress.Loopback, 0); - listener.Start(); - var port = ((IPEndPoint)listener.LocalEndpoint).Port; - listener.Stop(); - return port; - } - - private string GetBrowserPath(DevToolsPluginConfiguration configuration) - { - if (!string.IsNullOrEmpty(configuration.PreferredBrowserPath)) - { - Logger.LogInformation("{preferredBrowserPath} was set to {path}. Ignoring {preferredBrowser} setting.", nameof(configuration.PreferredBrowserPath), configuration.PreferredBrowserPath, nameof(configuration.PreferredBrowser)); - return configuration.PreferredBrowserPath; - } - - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - return configuration.PreferredBrowser switch - { - PreferredBrowser.Chrome => Environment.ExpandEnvironmentVariables(@"%ProgramFiles%\Google\Chrome\Application\chrome.exe"), - PreferredBrowser.Edge => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe"), - PreferredBrowser.EdgeDev => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge Dev\Application\msedge.exe"), - PreferredBrowser.EdgeBeta => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge Beta\Application\msedge.exe"), - _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") - }; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return configuration.PreferredBrowser switch - { - PreferredBrowser.Chrome => "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome", - PreferredBrowser.Edge => "/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge", - PreferredBrowser.EdgeDev => "/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Dev", - PreferredBrowser.EdgeBeta => "/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Beta", - _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") - }; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return configuration.PreferredBrowser switch - { - PreferredBrowser.Chrome => "/opt/google/chrome/chrome", - PreferredBrowser.Edge => "/opt/microsoft/msedge/msedge", - PreferredBrowser.EdgeDev => "/opt/microsoft/msedge-dev/msedge", - PreferredBrowser.EdgeBeta => "/opt/microsoft/msedge-beta/msedge", - _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") - }; - } - else - { - throw new NotSupportedException("Unsupported operating system."); - } - } - - private Process[] GetBrowserProcesses(string browserPath) - { - return Process.GetProcesses().Where(p => { - try - { - return p.MainModule is not null && p.MainModule.FileName == browserPath; - } - catch (Exception ex) - { - Logger.LogDebug("Error while checking process: {Ex}", ex.Message); - return false; - } - }).ToArray(); - } - - private void InitInspector() - { - var browserPath = string.Empty; - - try - { - browserPath = GetBrowserPath(_configuration); - } - catch (NotSupportedException ex) - { - Logger.LogError(ex, "Error starting {plugin}. Error finding the browser.", Name); - return; - } - - // check if the browser is installed - if (!File.Exists(browserPath)) - { - Logger.LogError("Error starting {plugin}. Browser executable not found at {browserPath}", Name, browserPath); - return; - } - - // find if the process is already running - var processes = GetBrowserProcesses(browserPath); - - if (processes.Any()) - { - var ids = string.Join(", ", processes.Select(p => p.Id.ToString())); - Logger.LogError("Found existing browser process {processName} with IDs {processIds}. Could not start {plugin}. Please close existing browser processes and restart Dev Proxy", browserPath, ids, Name); - return; - } - - var port = GetFreePort(); - webSocket = new WebSocketServer(port, Logger); - webSocket.MessageReceived += SocketMessageReceived; - _ = webSocket.StartAsync(); - - var inspectionUrl = $"http://localhost:9222/devtools/inspector.html?ws=localhost:{port}"; - var args = $"{inspectionUrl} --remote-debugging-port=9222 --profile-directory=devproxy"; - - Logger.LogInformation("{name} available at {inspectionUrl}", Name, inspectionUrl); - - var process = new Process - { - StartInfo = new() - { - FileName = browserPath, - Arguments = args, - // suppress default output - RedirectStandardError = true, - RedirectStandardOutput = true, - UseShellExecute = false - } - }; - process.Start(); - } - - private void SocketMessageReceived(string msg) - { - if (webSocket is null) - { - return; - } - - try - { - var message = JsonSerializer.Deserialize(msg, ProxyUtils.JsonSerializerOptions); - if (message?.Method == "Network.getResponseBody") - { - var requestId = message.Params?.RequestId; - if (requestId is null || - !responseBody.TryGetValue(requestId, out GetResponseBodyResultParams? value) || - // should never happen because the message is sent from devtools - // and Id is required on all socket messages but theoretically - // it is possible - message.Id is null) - { - return; - } - - var result = new GetResponseBodyResult - { - Id = (int)message.Id, - Result = new() - { - Body = value.Body, - Base64Encoded = value.Base64Encoded - } - }; - _ = webSocket.SendAsync(result); - } - } - catch { } - } - - private static string GetRequestId(Titanium.Web.Proxy.Http.Request? request) - { - if (request is null) - { - return string.Empty; - } - - return request.GetHashCode().ToString(); - } - - private async Task BeforeRequestAsync(object sender, ProxyRequestArgs e) - { - if (webSocket?.IsConnected != true) - { - return; - } - - var requestId = GetRequestId(e.Session.HttpClient.Request); - var headers = e.Session.HttpClient.Request.Headers - .ToDictionary(h => h.Name, h => h.Value); - - var requestWillBeSentMessage = new RequestWillBeSentMessage - { - Params = new() - { - RequestId = requestId, - LoaderId = "1", - DocumentUrl = e.Session.HttpClient.Request.Url, - Request = new() - { - Url = e.Session.HttpClient.Request.Url, - Method = e.Session.HttpClient.Request.Method, - Headers = headers, - PostData = e.Session.HttpClient.Request.HasBody ? e.Session.HttpClient.Request.BodyString : null - }, - Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, - WallTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), - Initiator = new() - { - Type = "other" - } - } - }; - await webSocket.SendAsync(requestWillBeSentMessage); - - // must be included to avoid the "Provisional headers are shown" warning - var requestWillBeSentExtraInfoMessage = new RequestWillBeSentExtraInfoMessage - { - Params = new() - { - RequestId = requestId, - // must be included in the message or the message will be rejected - AssociatedCookies = [], - Headers = headers - } - }; - await webSocket.SendAsync(requestWillBeSentExtraInfoMessage); - } - - private async Task AfterResponseAsync(object sender, ProxyResponseArgs e) - { - if (webSocket?.IsConnected != true) - { - return; - } - - var body = new GetResponseBodyResultParams - { - Body = string.Empty, - Base64Encoded = false - }; - if (IsTextResponse(e.Session.HttpClient.Response.ContentType)) - { - body.Body = e.Session.HttpClient.Response.BodyString; - body.Base64Encoded = false; - } - else - { - body.Body = Convert.ToBase64String(e.Session.HttpClient.Response.Body); - body.Base64Encoded = true; - } - responseBody.Add(e.Session.HttpClient.Request.GetHashCode().ToString(), body); - - var requestId = GetRequestId(e.Session.HttpClient.Request); - - var responseReceivedMessage = new ResponseReceivedMessage - { - Params = new() - { - RequestId = requestId, - LoaderId = "1", - Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, - Type = "XHR", - Response = new() - { - Url = e.Session.HttpClient.Request.Url, - Status = e.Session.HttpClient.Response.StatusCode, - StatusText = e.Session.HttpClient.Response.StatusDescription, - Headers = e.Session.HttpClient.Response.Headers - .ToDictionary(h => h.Name, h => h.Value), - MimeType = e.Session.HttpClient.Response.ContentType - } - } - }; - - await webSocket.SendAsync(responseReceivedMessage); - - var loadingFinishedMessage = new LoadingFinishedMessage - { - Params = new() - { - RequestId = requestId, - Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, - EncodedDataLength = e.Session.HttpClient.Response.HasBody ? e.Session.HttpClient.Response.Body.Length : 0 - } - }; - await webSocket.SendAsync(loadingFinishedMessage); - } - - private static bool IsTextResponse(string? contentType) - { - var isTextResponse = false; - - if (contentType is not null && - (contentType.IndexOf("text") > -1 || - contentType.IndexOf("json") > -1)) - { - isTextResponse = true; - } - - return isTextResponse; - } - - private async Task AfterRequestLogAsync(object? sender, RequestLogArgs e) - { - if (webSocket?.IsConnected != true || - e.RequestLog.MessageType == MessageType.InterceptedRequest || - e.RequestLog.MessageType == MessageType.InterceptedResponse) - { - return; - } - - var message = new EntryAddedMessage - { - Params = new() - { - Entry = new() - { - Source = "network", - Text = string.Join(" ", e.RequestLog.MessageLines), - Level = Entry.GetLevel(e.RequestLog.MessageType), - Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), - Url = e.RequestLog.Context?.Session.HttpClient.Request.Url, - NetworkRequestId = GetRequestId(e.RequestLog.Context?.Session.HttpClient.Request) - } - } - }; - await webSocket.SendAsync(message); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Text.Json; +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.Inspection.CDP; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.Inspection; + +public enum PreferredBrowser +{ + Edge, + Chrome, + EdgeDev, + EdgeBeta +} + +public class DevToolsPluginConfiguration +{ + public PreferredBrowser PreferredBrowser { get; set; } = PreferredBrowser.Edge; + + /// + /// Path to the browser executable. If not set, the plugin will try to find + /// the browser executable based on the PreferredBrowser. + /// + /// Use this value when you install the browser in a non-standard + /// path. + public string PreferredBrowserPath { get; set; } = string.Empty; +} + +public class DevToolsPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private WebSocketServer? webSocket; + private readonly Dictionary responseBody = []; + + public override string Name => nameof(DevToolsPlugin); + private readonly DevToolsPluginConfiguration _configuration = new(); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + ConfigSection?.Bind(_configuration); + + InitInspector(); + + PluginEvents.BeforeRequest += BeforeRequestAsync; + PluginEvents.AfterResponse += AfterResponseAsync; + PluginEvents.AfterRequestLog += AfterRequestLogAsync; + } + + private static int GetFreePort() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + listener.Stop(); + return port; + } + + private string GetBrowserPath(DevToolsPluginConfiguration configuration) + { + if (!string.IsNullOrEmpty(configuration.PreferredBrowserPath)) + { + Logger.LogInformation("{preferredBrowserPath} was set to {path}. Ignoring {preferredBrowser} setting.", nameof(configuration.PreferredBrowserPath), configuration.PreferredBrowserPath, nameof(configuration.PreferredBrowser)); + return configuration.PreferredBrowserPath; + } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return configuration.PreferredBrowser switch + { + PreferredBrowser.Chrome => Environment.ExpandEnvironmentVariables(@"%ProgramFiles%\Google\Chrome\Application\chrome.exe"), + PreferredBrowser.Edge => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge\Application\msedge.exe"), + PreferredBrowser.EdgeDev => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge Dev\Application\msedge.exe"), + PreferredBrowser.EdgeBeta => Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft\Edge Beta\Application\msedge.exe"), + _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") + }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return configuration.PreferredBrowser switch + { + PreferredBrowser.Chrome => "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome", + PreferredBrowser.Edge => "/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge", + PreferredBrowser.EdgeDev => "/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Dev", + PreferredBrowser.EdgeBeta => "/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Beta", + _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") + }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return configuration.PreferredBrowser switch + { + PreferredBrowser.Chrome => "/opt/google/chrome/chrome", + PreferredBrowser.Edge => "/opt/microsoft/msedge/msedge", + PreferredBrowser.EdgeDev => "/opt/microsoft/msedge-dev/msedge", + PreferredBrowser.EdgeBeta => "/opt/microsoft/msedge-beta/msedge", + _ => throw new NotSupportedException($"{configuration.PreferredBrowser} is an unsupported browser. Please change your PreferredBrowser setting for {Name}.") + }; + } + else + { + throw new NotSupportedException("Unsupported operating system."); + } + } + + private Process[] GetBrowserProcesses(string browserPath) + { + return Process.GetProcesses().Where(p => { + try + { + return p.MainModule is not null && p.MainModule.FileName == browserPath; + } + catch (Exception ex) + { + Logger.LogDebug("Error while checking process: {Ex}", ex.Message); + return false; + } + }).ToArray(); + } + + private void InitInspector() + { + var browserPath = string.Empty; + + try + { + browserPath = GetBrowserPath(_configuration); + } + catch (NotSupportedException ex) + { + Logger.LogError(ex, "Error starting {plugin}. Error finding the browser.", Name); + return; + } + + // check if the browser is installed + if (!File.Exists(browserPath)) + { + Logger.LogError("Error starting {plugin}. Browser executable not found at {browserPath}", Name, browserPath); + return; + } + + // find if the process is already running + var processes = GetBrowserProcesses(browserPath); + + if (processes.Any()) + { + var ids = string.Join(", ", processes.Select(p => p.Id.ToString())); + Logger.LogError("Found existing browser process {processName} with IDs {processIds}. Could not start {plugin}. Please close existing browser processes and restart Dev Proxy", browserPath, ids, Name); + return; + } + + var port = GetFreePort(); + webSocket = new WebSocketServer(port, Logger); + webSocket.MessageReceived += SocketMessageReceived; + _ = webSocket.StartAsync(); + + var inspectionUrl = $"http://localhost:9222/devtools/inspector.html?ws=localhost:{port}"; + var args = $"{inspectionUrl} --remote-debugging-port=9222 --profile-directory=devproxy"; + + Logger.LogInformation("{name} available at {inspectionUrl}", Name, inspectionUrl); + + var process = new Process + { + StartInfo = new() + { + FileName = browserPath, + Arguments = args, + // suppress default output + RedirectStandardError = true, + RedirectStandardOutput = true, + UseShellExecute = false + } + }; + process.Start(); + } + + private void SocketMessageReceived(string msg) + { + if (webSocket is null) + { + return; + } + + try + { + var message = JsonSerializer.Deserialize(msg, ProxyUtils.JsonSerializerOptions); + if (message?.Method == "Network.getResponseBody") + { + var requestId = message.Params?.RequestId; + if (requestId is null || + !responseBody.TryGetValue(requestId, out GetResponseBodyResultParams? value) || + // should never happen because the message is sent from devtools + // and Id is required on all socket messages but theoretically + // it is possible + message.Id is null) + { + return; + } + + var result = new GetResponseBodyResult + { + Id = (int)message.Id, + Result = new() + { + Body = value.Body, + Base64Encoded = value.Base64Encoded + } + }; + _ = webSocket.SendAsync(result); + } + } + catch { } + } + + private static string GetRequestId(Titanium.Web.Proxy.Http.Request? request) + { + if (request is null) + { + return string.Empty; + } + + return request.GetHashCode().ToString(); + } + + private async Task BeforeRequestAsync(object sender, ProxyRequestArgs e) + { + if (webSocket?.IsConnected != true) + { + return; + } + + var requestId = GetRequestId(e.Session.HttpClient.Request); + var headers = e.Session.HttpClient.Request.Headers + .ToDictionary(h => h.Name, h => h.Value); + + var requestWillBeSentMessage = new RequestWillBeSentMessage + { + Params = new() + { + RequestId = requestId, + LoaderId = "1", + DocumentUrl = e.Session.HttpClient.Request.Url, + Request = new() + { + Url = e.Session.HttpClient.Request.Url, + Method = e.Session.HttpClient.Request.Method, + Headers = headers, + PostData = e.Session.HttpClient.Request.HasBody ? e.Session.HttpClient.Request.BodyString : null + }, + Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, + WallTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), + Initiator = new() + { + Type = "other" + } + } + }; + await webSocket.SendAsync(requestWillBeSentMessage); + + // must be included to avoid the "Provisional headers are shown" warning + var requestWillBeSentExtraInfoMessage = new RequestWillBeSentExtraInfoMessage + { + Params = new() + { + RequestId = requestId, + // must be included in the message or the message will be rejected + AssociatedCookies = [], + Headers = headers + } + }; + await webSocket.SendAsync(requestWillBeSentExtraInfoMessage); + } + + private async Task AfterResponseAsync(object sender, ProxyResponseArgs e) + { + if (webSocket?.IsConnected != true) + { + return; + } + + var body = new GetResponseBodyResultParams + { + Body = string.Empty, + Base64Encoded = false + }; + if (IsTextResponse(e.Session.HttpClient.Response.ContentType)) + { + body.Body = e.Session.HttpClient.Response.BodyString; + body.Base64Encoded = false; + } + else + { + body.Body = Convert.ToBase64String(e.Session.HttpClient.Response.Body); + body.Base64Encoded = true; + } + responseBody.Add(e.Session.HttpClient.Request.GetHashCode().ToString(), body); + + var requestId = GetRequestId(e.Session.HttpClient.Request); + + var responseReceivedMessage = new ResponseReceivedMessage + { + Params = new() + { + RequestId = requestId, + LoaderId = "1", + Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, + Type = "XHR", + Response = new() + { + Url = e.Session.HttpClient.Request.Url, + Status = e.Session.HttpClient.Response.StatusCode, + StatusText = e.Session.HttpClient.Response.StatusDescription, + Headers = e.Session.HttpClient.Response.Headers + .ToDictionary(h => h.Name, h => h.Value), + MimeType = e.Session.HttpClient.Response.ContentType + } + } + }; + + await webSocket.SendAsync(responseReceivedMessage); + + var loadingFinishedMessage = new LoadingFinishedMessage + { + Params = new() + { + RequestId = requestId, + Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, + EncodedDataLength = e.Session.HttpClient.Response.HasBody ? e.Session.HttpClient.Response.Body.Length : 0 + } + }; + await webSocket.SendAsync(loadingFinishedMessage); + } + + private static bool IsTextResponse(string? contentType) + { + var isTextResponse = false; + + if (contentType is not null && + (contentType.IndexOf("text") > -1 || + contentType.IndexOf("json") > -1)) + { + isTextResponse = true; + } + + return isTextResponse; + } + + private async Task AfterRequestLogAsync(object? sender, RequestLogArgs e) + { + if (webSocket?.IsConnected != true || + e.RequestLog.MessageType == MessageType.InterceptedRequest || + e.RequestLog.MessageType == MessageType.InterceptedResponse) + { + return; + } + + var message = new EntryAddedMessage + { + Params = new() + { + Entry = new() + { + Source = "network", + Text = string.Join(" ", e.RequestLog.Message), + Level = Entry.GetLevel(e.RequestLog.MessageType), + Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), + Url = e.RequestLog.Context?.Session.HttpClient.Request.Url, + NetworkRequestId = GetRequestId(e.RequestLog.Context?.Session.HttpClient.Request) + } + } + }; + await webSocket.SendAsync(message); + } +} diff --git a/dev-proxy-plugins/MessageUtils.cs b/dev-proxy-plugins/MessageUtils.cs index 8e4bfb1f..ca3dc268 100644 --- a/dev-proxy-plugins/MessageUtils.cs +++ b/dev-proxy-plugins/MessageUtils.cs @@ -1,23 +1,21 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Titanium.Web.Proxy.Http; - -namespace Microsoft.DevProxy.Plugins; - -internal class MessageUtils -{ - public static string[] BuildUseSdkForErrorsMessage(Request r) => ["To handle API errors more easily, use the Microsoft Graph SDK.", $"More info at {GetMoveToSdkUrl(r)}"]; - - public static string[] BuildUseSdkMessage(Request r) => [ - "To more easily follow best practices for working with Microsoft Graph, ", - "use the Microsoft Graph SDK.", - $"More info at {GetMoveToSdkUrl(r)}" - ]; - - public static string GetMoveToSdkUrl(Request request) - { - // TODO: return language-specific guidance links based on the language detected from the User-Agent - return "https://aka.ms/devproxy/guidance/move-to-js-sdk"; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Titanium.Web.Proxy.Http; + +namespace Microsoft.DevProxy.Plugins; + +internal class MessageUtils +{ + public static string BuildUseSdkForErrorsMessage(Request r) => + $"To handle API errors more easily, use the Microsoft Graph SDK. More info at {GetMoveToSdkUrl(r)}"; + + public static string BuildUseSdkMessage(Request r) => + $"To more easily follow best practices for working with Microsoft Graph, use the Microsoft Graph SDK. More info at {GetMoveToSdkUrl(r)}"; + + public static string GetMoveToSdkUrl(Request request) + { + // TODO: return language-specific guidance links based on the language detected from the User-Agent + return "https://aka.ms/devproxy/guidance/move-to-js-sdk"; + } +} diff --git a/dev-proxy-plugins/Mocks/AuthPlugin.cs b/dev-proxy-plugins/Mocks/AuthPlugin.cs index 43592ce7..88278bcf 100644 --- a/dev-proxy-plugins/Mocks/AuthPlugin.cs +++ b/dev-proxy-plugins/Mocks/AuthPlugin.cs @@ -167,6 +167,7 @@ private async Task OnBeforeRequestAsync(object sender, ProxyRequestArgs e) { if (UrlsToWatch is null || !e.ShouldExecute(UrlsToWatch)) { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); return; } @@ -177,7 +178,7 @@ private async Task OnBeforeRequestAsync(object sender, ProxyRequestArgs e) } else { - Logger.LogRequest(["Request authorized"], MessageType.Normal, new LoggingContext(e.Session)); + Logger.LogRequest("Request authorized", MessageType.Normal, new LoggingContext(e.Session)); } } @@ -205,14 +206,14 @@ private bool AuthorizeApiKeyRequest(SessionEventArgs session) var apiKey = GetApiKey(session); if (apiKey is null) { - Logger.LogRequest(["401 Unauthorized", "API key not found."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. API key not found.", MessageType.Failed, new LoggingContext(session)); return false; } var isKeyValid = _configuration.ApiKey.AllowedKeys.Contains(apiKey); if (!isKeyValid) { - Logger.LogRequest(["401 Unauthorized", $"API key {apiKey} is not allowed."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. API key {apiKey} is not allowed.", MessageType.Failed, new LoggingContext(session)); } return isKeyValid; @@ -264,7 +265,7 @@ private bool AuthorizeOAuth2Request(SessionEventArgs session) } catch (Exception ex) { - Logger.LogRequest(["401 Unauthorized", $"The specified token is not valid: {ex.Message}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, new LoggingContext(session)); return false; } } @@ -303,14 +304,14 @@ private bool ValidatePrincipals(ClaimsPrincipal claimsPrincipal, SessionEventArg var principalId = claimsPrincipal.FindFirst("http://schemas.microsoft.com/identity/claims/objectidentifier")?.Value; if (principalId is null) { - Logger.LogRequest(["401 Unauthorized", "The specified token doesn't have the oid claim."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the oid claim.", MessageType.Failed, new LoggingContext(session)); return false; } if (!_configuration.OAuth2.AllowedPrincipals.Contains(principalId)) { var principals = string.Join(", ", _configuration.OAuth2.AllowedPrincipals); - Logger.LogRequest(["401 Unauthorized", $"The specified token is not issued for an allowed principal. Allowed principals: {principals}, found: {principalId}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued for an allowed principal. Allowed principals: {principals}, found: {principalId}", MessageType.Failed, new LoggingContext(session)); return false; } @@ -333,21 +334,21 @@ private bool ValidateApplications(ClaimsPrincipal claimsPrincipal, SessionEventA var tokenVersion = claimsPrincipal.FindFirst("ver")?.Value; if (tokenVersion is null) { - Logger.LogRequest(["401 Unauthorized", "The specified token doesn't have the ver claim."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the ver claim.", MessageType.Failed, new LoggingContext(session)); return false; } var appId = claimsPrincipal.FindFirst(tokenVersion == "1.0" ? "appid" : "azp")?.Value; if (appId is null) { - Logger.LogRequest(["401 Unauthorized", $"The specified token doesn't have the {(tokenVersion == "v1.0" ? "appid" : "azp")} claim."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token doesn't have the {(tokenVersion == "v1.0" ? "appid" : "azp")} claim.", MessageType.Failed, new LoggingContext(session)); return false; } if (!_configuration.OAuth2.AllowedApplications.Contains(appId)) { var applications = string.Join(", ", _configuration.OAuth2.AllowedApplications); - Logger.LogRequest(["401 Unauthorized", $"The specified token is not issued by an allowed application. Allowed applications: {applications}, found: {appId}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed application. Allowed applications: {applications}, found: {appId}", MessageType.Failed, new LoggingContext(session)); return false; } @@ -370,14 +371,14 @@ private bool ValidateTenants(ClaimsPrincipal claimsPrincipal, SessionEventArgs s var tenantId = claimsPrincipal.FindFirst("http://schemas.microsoft.com/identity/claims/tenantid")?.Value; if (tenantId is null) { - Logger.LogRequest(["401 Unauthorized", "The specified token doesn't have the tid claim."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the tid claim.", MessageType.Failed, new LoggingContext(session)); return false; } if (!_configuration.OAuth2.AllowedTenants.Contains(tenantId)) { var tenants = string.Join(", ", _configuration.OAuth2.AllowedTenants); - Logger.LogRequest(["401 Unauthorized", $"The specified token is not issued by an allowed tenant. Allowed tenants: {tenants}, found: {tenantId}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed tenant. Allowed tenants: {tenants}, found: {tenantId}", MessageType.Failed, new LoggingContext(session)); return false; } @@ -404,7 +405,7 @@ private bool ValidateRoles(ClaimsPrincipal claimsPrincipal, SessionEventArgs ses var rolesRequired = string.Join(", ", _configuration.OAuth2.Roles); if (!_configuration.OAuth2.Roles.Any(r => HasPermission(r, rolesFromTheToken))) { - Logger.LogRequest(["401 Unauthorized", $"The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, new LoggingContext(session)); return false; } @@ -431,7 +432,7 @@ private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs se var scopesRequired = string.Join(", ", _configuration.OAuth2.Scopes); if (!_configuration.OAuth2.Scopes.Any(s => HasPermission(s, scopesFromTheToken))) { - Logger.LogRequest(["401 Unauthorized", $"The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}"], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, new LoggingContext(session)); return false; } @@ -462,13 +463,13 @@ private static bool HasPermission(string permission, string permissionString) if (tokenParts is null) { - Logger.LogRequest(["401 Unauthorized", "Authorization header not found."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. Authorization header not found.", MessageType.Failed, new LoggingContext(session)); return null; } if (tokenParts.Length != 2 || tokenParts[0] != "Bearer") { - Logger.LogRequest(["401 Unauthorized", "The specified token is not a valid Bearer token."], MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, new LoggingContext(session)); return null; } diff --git a/dev-proxy-plugins/Mocks/CrudApiPlugin.cs b/dev-proxy-plugins/Mocks/CrudApiPlugin.cs index ba94dfe1..6f88b80a 100644 --- a/dev-proxy-plugins/Mocks/CrudApiPlugin.cs +++ b/dev-proxy-plugins/Mocks/CrudApiPlugin.cs @@ -1,527 +1,535 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using System.Net; -using System.Text.Json.Serialization; -using System.Text.RegularExpressions; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; -using Microsoft.Extensions.Logging; -using System.IdentityModel.Tokens.Jwt; -using System.Diagnostics; -using Microsoft.IdentityModel.Tokens; -using Microsoft.IdentityModel.Protocols.OpenIdConnect; -using Microsoft.IdentityModel.Protocols; -using System.Security.Claims; - -namespace Microsoft.DevProxy.Plugins.Mocks; - -public enum CrudApiActionType -{ - Create, - GetAll, - GetOne, - GetMany, - Merge, - Update, - Delete -} - -public enum CrudApiAuthType -{ - None, - Entra -} - -public class CrudApiEntraAuth -{ - public string Audience { get; set; } = string.Empty; - public string Issuer { get; set; } = string.Empty; - public string[] Scopes { get; set; } = []; - public string[] Roles { get; set; } = []; - public bool ValidateLifetime { get; set; } = false; - public bool ValidateSigningKey { get; set; } = false; -} - -public class CrudApiAction -{ - [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] - public CrudApiActionType Action { get; set; } = CrudApiActionType.GetAll; - public string Url { get; set; } = string.Empty; - public string? Method { get; set; } - public string Query { get; set; } = string.Empty; - [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] - public CrudApiAuthType Auth { get; set; } = CrudApiAuthType.None; - public CrudApiEntraAuth? EntraAuthConfig { get; set; } -} - -public class CrudApiConfiguration -{ - public string ApiFile { get; set; } = "api.json"; - public string BaseUrl { get; set; } = string.Empty; - public string DataFile { get; set; } = string.Empty; - public IEnumerable Actions { get; set; } = []; - [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] - public CrudApiAuthType Auth { get; set; } = CrudApiAuthType.None; - public CrudApiEntraAuth? EntraAuthConfig { get; set; } -} - -public class CrudApiPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - protected CrudApiConfiguration _configuration = new(); - private CrudApiDefinitionLoader? _loader = null; - public override string Name => nameof(CrudApiPlugin); - private IProxyConfiguration? _proxyConfiguration; - private JArray? _data; - private OpenIdConnectConfiguration? _openIdConnectConfiguration; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - PluginEvents.BeforeRequest += OnRequestAsync; - - _proxyConfiguration = Context.Configuration; - - _configuration.ApiFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.ApiFile), Path.GetDirectoryName(_proxyConfiguration?.ConfigFile ?? string.Empty) ?? string.Empty); - - _loader = new CrudApiDefinitionLoader(Logger, _configuration); - _loader?.InitApiDefinitionWatcher(); - - if (_configuration.Auth == CrudApiAuthType.Entra && - _configuration.EntraAuthConfig is null) - { - Logger.LogError("Entra auth is enabled but no configuration is provided. API will work anonymously."); - _configuration.Auth = CrudApiAuthType.None; - } - - LoadData(); - await SetupOpenIdConnectConfigurationAsync(); - } - - private async Task SetupOpenIdConnectConfigurationAsync() - { - try - { - var retriever = new OpenIdConnectConfigurationRetriever(); - var configurationManager = new ConfigurationManager("https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration", retriever); - _openIdConnectConfiguration = await configurationManager.GetConfigurationAsync(); - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while loading OpenIdConnectConfiguration"); - } - } - - private void LoadData() - { - try - { - var dataFilePath = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.DataFile), Path.GetDirectoryName(_configuration.ApiFile) ?? string.Empty); - if (!File.Exists(dataFilePath)) - { - Logger.LogError($"Data file '{dataFilePath}' does not exist. The {_configuration.BaseUrl} API will be disabled."); - _configuration.Actions = []; - return; - } - - var dataString = File.ReadAllText(dataFilePath); - _data = JArray.Parse(dataString); - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while reading {configFile}", _configuration.DataFile); - } - } - - protected virtual Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - Request request = e.Session.HttpClient.Request; - ResponseState state = e.ResponseState; - - if (UrlsToWatch is not null && e.ShouldExecute(UrlsToWatch)) - { - if (!AuthorizeRequest(e)) - { - SendUnauthorizedResponse(e.Session); - state.HasBeenSet = true; - return Task.CompletedTask; - } - - var actionAndParams = GetMatchingActionHandler(request); - if (actionAndParams is not null) - { - if (!AuthorizeRequest(e, actionAndParams.Value.action)) - { - SendUnauthorizedResponse(e.Session); - state.HasBeenSet = true; - return Task.CompletedTask; - } - - actionAndParams.Value.handler(e.Session, actionAndParams.Value.action, actionAndParams.Value.parameters); - state.HasBeenSet = true; - } - } - - return Task.CompletedTask; - } - - private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) - { - var authType = action is null ? _configuration.Auth : action.Auth; - var authConfig = action is null ? _configuration.EntraAuthConfig : action.EntraAuthConfig; - - if (authType == CrudApiAuthType.None) - { - if (action is null) - { - Logger.LogDebug("No auth is required for this API."); - } - return true; - } - - Debug.Assert(authConfig is not null, "EntraAuthConfig is null when auth is required."); - - var token = e.Session.HttpClient.Request.Headers.FirstOrDefault(h => h.Name.Equals("Authorization", StringComparison.OrdinalIgnoreCase))?.Value; - // is there a token - if (string.IsNullOrEmpty(token)) - { - Logger.LogRequest(["401 Unauthorized", "No token found on the request."], MessageType.Failed, new LoggingContext(e.Session)); - return false; - } - - // does the token has a valid format - var tokenHeaderParts = token.Split(' '); - if (tokenHeaderParts.Length != 2 || tokenHeaderParts[0] != "Bearer") - { - Logger.LogRequest(["401 Unauthorized", "The specified token is not a valid Bearer token."], MessageType.Failed, new LoggingContext(e.Session)); - return false; - } - - var handler = new JwtSecurityTokenHandler(); - var validationParameters = new TokenValidationParameters - { - IssuerSigningKeys = _openIdConnectConfiguration?.SigningKeys, - ValidateIssuer = !string.IsNullOrEmpty(authConfig.Issuer), - ValidIssuer = authConfig.Issuer, - ValidateAudience = !string.IsNullOrEmpty(authConfig.Audience), - ValidAudience = authConfig.Audience, - ValidateLifetime = authConfig.ValidateLifetime, - ValidateIssuerSigningKey = authConfig.ValidateSigningKey - }; - if (!authConfig.ValidateSigningKey) - { - // suppress token validation - validationParameters.SignatureValidator = delegate (string token, TokenValidationParameters parameters) - { - var jwt = new JwtSecurityToken(token); - return jwt; - }; - } - - try - { - var claimsPrincipal = handler.ValidateToken(tokenHeaderParts[1], validationParameters, out _); - - // does the token has valid roles/scopes - if (authConfig.Roles.Length != 0) - { - var rolesFromTheToken = string.Join(' ', claimsPrincipal.Claims - .Where(c => c.Type == ClaimTypes.Role) - .Select(c => c.Value)); - - if (!authConfig.Roles.Any(r => HasPermission(r, rolesFromTheToken))) - { - var rolesRequired = string.Join(", ", authConfig.Roles); - - Logger.LogRequest(["401 Unauthorized", $"The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}"], MessageType.Failed, new LoggingContext(e.Session)); - return false; - } - - return true; - } - if (authConfig.Scopes.Length != 0) - { - var scopesFromTheToken = string.Join(' ', claimsPrincipal.Claims - .Where(c => c.Type == "http://schemas.microsoft.com/identity/claims/scope") - .Select(c => c.Value)); - - if (!authConfig.Scopes.Any(s => HasPermission(s, scopesFromTheToken))) - { - var scopesRequired = string.Join(", ", authConfig.Scopes); - - Logger.LogRequest(["401 Unauthorized", $"The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}"], MessageType.Failed, new LoggingContext(e.Session)); - return false; - } - - return true; - } - } - catch (Exception ex) - { - Logger.LogRequest(["401 Unauthorized", $"The specified token is not valid: {ex.Message}"], MessageType.Failed, new LoggingContext(e.Session)); - return false; - } - - return true; - } - - private static bool HasPermission(string permission, string permissionString) - { - if (string.IsNullOrEmpty(permissionString)) - { - return false; - } - - var permissions = permissionString.Split(' '); - return permissions.Contains(permission, StringComparer.OrdinalIgnoreCase); - } - - private static void SendUnauthorizedResponse(SessionEventArgs e) - { - SendJsonResponse("{\"error\":{\"message\":\"Unauthorized\"}}", HttpStatusCode.Unauthorized, e); - } - - private static void SendNotFoundResponse(SessionEventArgs e) - { - SendJsonResponse("{\"error\":{\"message\":\"Not found\"}}", HttpStatusCode.NotFound, e); - } - - private static string ReplaceParams(string query, IDictionary parameters) - { - var result = Regex.Replace(query, "{([^}]+)}", new MatchEvaluator(m => - { - return $"{{{m.Groups[1].Value.Replace('-', '_')}}}"; - })); - foreach (var param in parameters) - { - result = result.Replace($"{{{param.Key}}}", param.Value); - } - return result; - } - - private static void SendEmptyResponse(HttpStatusCode statusCode, SessionEventArgs e) - { - var headers = new List(); - if (e.HttpClient.Request.Headers.Any(h => h.Name == "Origin")) - { - headers.Add(new HttpHeader("access-control-allow-origin", "*")); - } - e.GenericResponse("", statusCode, headers); - } - - private static void SendJsonResponse(string body, HttpStatusCode statusCode, SessionEventArgs e) - { - var headers = new List { - new("content-type", "application/json; charset=utf-8") - }; - if (e.HttpClient.Request.Headers.Any(h => h.Name == "Origin")) - { - headers.Add(new HttpHeader("access-control-allow-origin", "*")); - } - e.GenericResponse(body, statusCode, headers); - } - - private void GetAll(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - SendJsonResponse(JsonConvert.SerializeObject(_data, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest([$"200 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - - private void GetOne(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); - if (item is null) - { - SendNotFoundResponse(e); - Logger.LogRequest([$"404 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - return; - } - - SendJsonResponse(JsonConvert.SerializeObject(item, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest([$"200 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private void GetMany(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - var items = (_data?.SelectTokens(ReplaceParams(action.Query, parameters))) ?? []; - SendJsonResponse(JsonConvert.SerializeObject(items, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest([$"200 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private void Create(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - _data?.Add(JObject.Parse(e.HttpClient.Request.BodyString)); - SendJsonResponse(JsonConvert.SerializeObject(e.HttpClient.Request.BodyString, Formatting.Indented), HttpStatusCode.Created, e); - Logger.LogRequest([$"201 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private void Merge(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); - if (item is null) - { - SendNotFoundResponse(e); - Logger.LogRequest([$"404 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - return; - } - var update = JObject.Parse(e.HttpClient.Request.BodyString); - ((JContainer)item)?.Merge(update); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest([$"204 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private void Update(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); - if (item is null) - { - SendNotFoundResponse(e); - Logger.LogRequest([$"404 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - return; - } - var update = JObject.Parse(e.HttpClient.Request.BodyString); - ((JContainer)item)?.Replace(update); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest([$"204 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private void Delete(SessionEventArgs e, CrudApiAction action, IDictionary parameters) - { - try - { - var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); - if (item is null) - { - SendNotFoundResponse(e); - Logger.LogRequest([$"404 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - return; - } - - item?.Remove(); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest([$"204 {action.Url}"], MessageType.Mocked, new LoggingContext(e)); - } - catch (Exception ex) - { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest([$"500 {action.Url}"], MessageType.Failed, new LoggingContext(e)); - } - } - - private (Action> handler, CrudApiAction action, IDictionary parameters)? GetMatchingActionHandler(Request request) - { - if (_configuration.Actions is null || - !_configuration.Actions.Any()) - { - return null; - } - - var parameterMatchEvaluator = new MatchEvaluator(m => - { - var paramName = m.Value.Trim('{', '}').Replace('-', '_'); - return $"(?<{paramName}>[^/&]+)"; - }); - - var parameters = new Dictionary(); - var action = _configuration.Actions.FirstOrDefault(action => - { - if (action.Method != request.Method) return false; - var absoluteActionUrl = (_configuration.BaseUrl + action.Url).Replace("//", "/", 8); - - if (absoluteActionUrl == request.Url) - { - return true; - } - - // check if the action contains parameters - // if it doesn't, it's not a match for the current request for sure - if (!absoluteActionUrl.Contains('{')) - { - return false; - } - - // convert parameters into named regex groups - var urlRegex = Regex.Replace(Regex.Escape(absoluteActionUrl).Replace("\\{", "{"), "({[^}]+})", parameterMatchEvaluator); - var match = Regex.Match(request.Url, urlRegex); - if (!match.Success) - { - return false; - } - - foreach (var groupName in match.Groups.Keys) - { - if (groupName == "0") - { - continue; - } - parameters.Add(groupName, Uri.UnescapeDataString(match.Groups[groupName].Value)); - } - return true; - }); - - if (action is null) - { - return null; - } - - return (handler: action.Action switch - { - CrudApiActionType.Create => Create, - CrudApiActionType.GetAll => GetAll, - CrudApiActionType.GetOne => GetOne, - CrudApiActionType.GetMany => GetMany, - CrudApiActionType.Merge => Merge, - CrudApiActionType.Update => Update, - CrudApiActionType.Delete => Delete, - _ => throw new NotImplementedException() - }, action, parameters); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using System.Net; +using System.Text.Json.Serialization; +using System.Text.RegularExpressions; +using Titanium.Web.Proxy.EventArguments; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Microsoft.Extensions.Logging; +using System.IdentityModel.Tokens.Jwt; +using System.Diagnostics; +using Microsoft.IdentityModel.Tokens; +using Microsoft.IdentityModel.Protocols.OpenIdConnect; +using Microsoft.IdentityModel.Protocols; +using System.Security.Claims; + +namespace Microsoft.DevProxy.Plugins.Mocks; + +public enum CrudApiActionType +{ + Create, + GetAll, + GetOne, + GetMany, + Merge, + Update, + Delete +} + +public enum CrudApiAuthType +{ + None, + Entra +} + +public class CrudApiEntraAuth +{ + public string Audience { get; set; } = string.Empty; + public string Issuer { get; set; } = string.Empty; + public string[] Scopes { get; set; } = []; + public string[] Roles { get; set; } = []; + public bool ValidateLifetime { get; set; } = false; + public bool ValidateSigningKey { get; set; } = false; +} + +public class CrudApiAction +{ + [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] + public CrudApiActionType Action { get; set; } = CrudApiActionType.GetAll; + public string Url { get; set; } = string.Empty; + public string? Method { get; set; } + public string Query { get; set; } = string.Empty; + [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] + public CrudApiAuthType Auth { get; set; } = CrudApiAuthType.None; + public CrudApiEntraAuth? EntraAuthConfig { get; set; } +} + +public class CrudApiConfiguration +{ + public string ApiFile { get; set; } = "api.json"; + public string BaseUrl { get; set; } = string.Empty; + public string DataFile { get; set; } = string.Empty; + public IEnumerable Actions { get; set; } = []; + [System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] + public CrudApiAuthType Auth { get; set; } = CrudApiAuthType.None; + public CrudApiEntraAuth? EntraAuthConfig { get; set; } +} + +public class CrudApiPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + protected CrudApiConfiguration _configuration = new(); + private CrudApiDefinitionLoader? _loader = null; + public override string Name => nameof(CrudApiPlugin); + private IProxyConfiguration? _proxyConfiguration; + private JArray? _data; + private OpenIdConnectConfiguration? _openIdConnectConfiguration; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + PluginEvents.BeforeRequest += OnRequestAsync; + + _proxyConfiguration = Context.Configuration; + + _configuration.ApiFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.ApiFile), Path.GetDirectoryName(_proxyConfiguration?.ConfigFile ?? string.Empty) ?? string.Empty); + + _loader = new CrudApiDefinitionLoader(Logger, _configuration); + _loader?.InitApiDefinitionWatcher(); + + if (_configuration.Auth == CrudApiAuthType.Entra && + _configuration.EntraAuthConfig is null) + { + Logger.LogError("Entra auth is enabled but no configuration is provided. API will work anonymously."); + _configuration.Auth = CrudApiAuthType.None; + } + + LoadData(); + await SetupOpenIdConnectConfigurationAsync(); + } + + private async Task SetupOpenIdConnectConfigurationAsync() + { + try + { + var retriever = new OpenIdConnectConfigurationRetriever(); + var configurationManager = new ConfigurationManager("https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration", retriever); + _openIdConnectConfiguration = await configurationManager.GetConfigurationAsync(); + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while loading OpenIdConnectConfiguration"); + } + } + + private void LoadData() + { + try + { + var dataFilePath = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.DataFile), Path.GetDirectoryName(_configuration.ApiFile) ?? string.Empty); + if (!File.Exists(dataFilePath)) + { + Logger.LogError($"Data file '{dataFilePath}' does not exist. The {_configuration.BaseUrl} API will be disabled."); + _configuration.Actions = []; + return; + } + + var dataString = File.ReadAllText(dataFilePath); + _data = JArray.Parse(dataString); + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while reading {configFile}", _configuration.DataFile); + } + } + + protected virtual Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + Request request = e.Session.HttpClient.Request; + ResponseState state = e.ResponseState; + + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + if (!AuthorizeRequest(e)) + { + SendUnauthorizedResponse(e.Session); + state.HasBeenSet = true; + return Task.CompletedTask; + } + + var actionAndParams = GetMatchingActionHandler(request); + if (actionAndParams is not null) + { + if (!AuthorizeRequest(e, actionAndParams.Value.action)) + { + SendUnauthorizedResponse(e.Session); + state.HasBeenSet = true; + return Task.CompletedTask; + } + + actionAndParams.Value.handler(e.Session, actionAndParams.Value.action, actionAndParams.Value.parameters); + state.HasBeenSet = true; + } + else + { + Logger.LogRequest("Did not match any action", MessageType.Skipped, new LoggingContext(e.Session)); + } + + return Task.CompletedTask; + } + + private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) + { + var authType = action is null ? _configuration.Auth : action.Auth; + var authConfig = action is null ? _configuration.EntraAuthConfig : action.EntraAuthConfig; + + if (authType == CrudApiAuthType.None) + { + if (action is null) + { + Logger.LogDebug("No auth is required for this API."); + } + return true; + } + + Debug.Assert(authConfig is not null, "EntraAuthConfig is null when auth is required."); + + var token = e.Session.HttpClient.Request.Headers.FirstOrDefault(h => h.Name.Equals("Authorization", StringComparison.OrdinalIgnoreCase))?.Value; + // is there a token + if (string.IsNullOrEmpty(token)) + { + Logger.LogRequest("401 Unauthorized. No token found on the request.", MessageType.Failed, new LoggingContext(e.Session)); + return false; + } + + // does the token has a valid format + var tokenHeaderParts = token.Split(' '); + if (tokenHeaderParts.Length != 2 || tokenHeaderParts[0] != "Bearer") + { + Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, new LoggingContext(e.Session)); + return false; + } + + var handler = new JwtSecurityTokenHandler(); + var validationParameters = new TokenValidationParameters + { + IssuerSigningKeys = _openIdConnectConfiguration?.SigningKeys, + ValidateIssuer = !string.IsNullOrEmpty(authConfig.Issuer), + ValidIssuer = authConfig.Issuer, + ValidateAudience = !string.IsNullOrEmpty(authConfig.Audience), + ValidAudience = authConfig.Audience, + ValidateLifetime = authConfig.ValidateLifetime, + ValidateIssuerSigningKey = authConfig.ValidateSigningKey + }; + if (!authConfig.ValidateSigningKey) + { + // suppress token validation + validationParameters.SignatureValidator = delegate (string token, TokenValidationParameters parameters) + { + var jwt = new JwtSecurityToken(token); + return jwt; + }; + } + + try + { + var claimsPrincipal = handler.ValidateToken(tokenHeaderParts[1], validationParameters, out _); + + // does the token has valid roles/scopes + if (authConfig.Roles.Length != 0) + { + var rolesFromTheToken = string.Join(' ', claimsPrincipal.Claims + .Where(c => c.Type == ClaimTypes.Role) + .Select(c => c.Value)); + + if (!authConfig.Roles.Any(r => HasPermission(r, rolesFromTheToken))) + { + var rolesRequired = string.Join(", ", authConfig.Roles); + + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, new LoggingContext(e.Session)); + return false; + } + + return true; + } + if (authConfig.Scopes.Length != 0) + { + var scopesFromTheToken = string.Join(' ', claimsPrincipal.Claims + .Where(c => c.Type == "http://schemas.microsoft.com/identity/claims/scope") + .Select(c => c.Value)); + + if (!authConfig.Scopes.Any(s => HasPermission(s, scopesFromTheToken))) + { + var scopesRequired = string.Join(", ", authConfig.Scopes); + + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, new LoggingContext(e.Session)); + return false; + } + + return true; + } + } + catch (Exception ex) + { + Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, new LoggingContext(e.Session)); + return false; + } + + return true; + } + + private static bool HasPermission(string permission, string permissionString) + { + if (string.IsNullOrEmpty(permissionString)) + { + return false; + } + + var permissions = permissionString.Split(' '); + return permissions.Contains(permission, StringComparer.OrdinalIgnoreCase); + } + + private static void SendUnauthorizedResponse(SessionEventArgs e) + { + SendJsonResponse("{\"error\":{\"message\":\"Unauthorized\"}}", HttpStatusCode.Unauthorized, e); + } + + private static void SendNotFoundResponse(SessionEventArgs e) + { + SendJsonResponse("{\"error\":{\"message\":\"Not found\"}}", HttpStatusCode.NotFound, e); + } + + private static string ReplaceParams(string query, IDictionary parameters) + { + var result = Regex.Replace(query, "{([^}]+)}", new MatchEvaluator(m => + { + return $"{{{m.Groups[1].Value.Replace('-', '_')}}}"; + })); + foreach (var param in parameters) + { + result = result.Replace($"{{{param.Key}}}", param.Value); + } + return result; + } + + private static void SendEmptyResponse(HttpStatusCode statusCode, SessionEventArgs e) + { + var headers = new List(); + if (e.HttpClient.Request.Headers.Any(h => h.Name == "Origin")) + { + headers.Add(new HttpHeader("access-control-allow-origin", "*")); + } + e.GenericResponse("", statusCode, headers); + } + + private static void SendJsonResponse(string body, HttpStatusCode statusCode, SessionEventArgs e) + { + var headers = new List { + new("content-type", "application/json; charset=utf-8") + }; + if (e.HttpClient.Request.Headers.Any(h => h.Name == "Origin")) + { + headers.Add(new HttpHeader("access-control-allow-origin", "*")); + } + e.GenericResponse(body, statusCode, headers); + } + + private void GetAll(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + SendJsonResponse(JsonConvert.SerializeObject(_data, Formatting.Indented), HttpStatusCode.OK, e); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + + private void GetOne(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); + if (item is null) + { + SendNotFoundResponse(e); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + return; + } + + SendJsonResponse(JsonConvert.SerializeObject(item, Formatting.Indented), HttpStatusCode.OK, e); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private void GetMany(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + var items = (_data?.SelectTokens(ReplaceParams(action.Query, parameters))) ?? []; + SendJsonResponse(JsonConvert.SerializeObject(items, Formatting.Indented), HttpStatusCode.OK, e); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private void Create(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + _data?.Add(JObject.Parse(e.HttpClient.Request.BodyString)); + SendJsonResponse(JsonConvert.SerializeObject(e.HttpClient.Request.BodyString, Formatting.Indented), HttpStatusCode.Created, e); + Logger.LogRequest($"201 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private void Merge(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); + if (item is null) + { + SendNotFoundResponse(e); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + return; + } + var update = JObject.Parse(e.HttpClient.Request.BodyString); + ((JContainer)item)?.Merge(update); + SendEmptyResponse(HttpStatusCode.NoContent, e); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private void Update(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); + if (item is null) + { + SendNotFoundResponse(e); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + return; + } + var update = JObject.Parse(e.HttpClient.Request.BodyString); + ((JContainer)item)?.Replace(update); + SendEmptyResponse(HttpStatusCode.NoContent, e); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private void Delete(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + { + try + { + var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); + if (item is null) + { + SendNotFoundResponse(e); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + return; + } + + item?.Remove(); + SendEmptyResponse(HttpStatusCode.NoContent, e); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new LoggingContext(e)); + } + catch (Exception ex) + { + SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new LoggingContext(e)); + } + } + + private (Action> handler, CrudApiAction action, IDictionary parameters)? GetMatchingActionHandler(Request request) + { + if (_configuration.Actions is null || + !_configuration.Actions.Any()) + { + return null; + } + + var parameterMatchEvaluator = new MatchEvaluator(m => + { + var paramName = m.Value.Trim('{', '}').Replace('-', '_'); + return $"(?<{paramName}>[^/&]+)"; + }); + + var parameters = new Dictionary(); + var action = _configuration.Actions.FirstOrDefault(action => + { + if (action.Method != request.Method) return false; + var absoluteActionUrl = (_configuration.BaseUrl + action.Url).Replace("//", "/", 8); + + if (absoluteActionUrl == request.Url) + { + return true; + } + + // check if the action contains parameters + // if it doesn't, it's not a match for the current request for sure + if (!absoluteActionUrl.Contains('{')) + { + return false; + } + + // convert parameters into named regex groups + var urlRegex = Regex.Replace(Regex.Escape(absoluteActionUrl).Replace("\\{", "{"), "({[^}]+})", parameterMatchEvaluator); + var match = Regex.Match(request.Url, urlRegex); + if (!match.Success) + { + return false; + } + + foreach (var groupName in match.Groups.Keys) + { + if (groupName == "0") + { + continue; + } + parameters.Add(groupName, Uri.UnescapeDataString(match.Groups[groupName].Value)); + } + return true; + }); + + if (action is null) + { + return null; + } + + return (handler: action.Action switch + { + CrudApiActionType.Create => Create, + CrudApiActionType.GetAll => GetAll, + CrudApiActionType.GetOne => GetOne, + CrudApiActionType.GetMany => GetMany, + CrudApiActionType.Merge => Merge, + CrudApiActionType.Update => Update, + CrudApiActionType.Delete => Delete, + _ => throw new NotImplementedException() + }, action, parameters); + } +} diff --git a/dev-proxy-plugins/Mocks/GraphConnectorNotificationPlugin.cs b/dev-proxy-plugins/Mocks/GraphConnectorNotificationPlugin.cs index 00c53e75..d944ba94 100644 --- a/dev-proxy-plugins/Mocks/GraphConnectorNotificationPlugin.cs +++ b/dev-proxy-plugins/Mocks/GraphConnectorNotificationPlugin.cs @@ -1,186 +1,194 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.IdentityModel.Tokens.Jwt; -using System.Net; -using System.Text; -using System.Text.Json; -using Microsoft.DevProxy.Abstractions; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.IdentityModel.Tokens; -using Titanium.Web.Proxy.EventArguments; - -namespace Microsoft.DevProxy.Plugins.Mocks; - -public class GraphConnectorNotificationConfiguration : MockRequestConfiguration -{ - public string? Audience { get; set; } - public string? Tenant { get; set; } -} - -public class GraphConnectorNotificationPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : MockRequestPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private string? _ticket = null; - private readonly GraphConnectorNotificationConfiguration _graphConnectorConfiguration = new(); - - public override string Name => nameof(GraphConnectorNotificationPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - ConfigSection?.Bind(_graphConnectorConfiguration); - _graphConnectorConfiguration.MockFile = _configuration.MockFile; - _graphConnectorConfiguration.Request = _configuration.Request; - - PluginEvents.BeforeRequest += OnBeforeRequestAsync; - } - - private Task OnBeforeRequestAsync(object sender, ProxyRequestArgs e) - { - if (!ProxyUtils.IsGraphRequest(e.Session.HttpClient.Request)) - { - return Task.CompletedTask; - } - - VerifyTicket(e.Session); - return Task.CompletedTask; - } - - private void VerifyTicket(SessionEventArgs session) - { - if (_ticket is null) - { - return; - } - - var request = session.HttpClient.Request; - - if (request.Method != "POST" && request.Method != "DELETE") - { - return; - } - - if ((request.Method == "POST" && - !request.RequestUri.AbsolutePath.EndsWith("/external/connections", StringComparison.OrdinalIgnoreCase)) || - (request.Method == "DELETE" && - !request.RequestUri.AbsolutePath.Contains("/external/connections/", StringComparison.OrdinalIgnoreCase))) - { - return; - } - - var ticketFromHeader = request.Headers.FirstOrDefault(h => h.Name.Equals("GraphConnectors-Ticket", StringComparison.OrdinalIgnoreCase))?.Value; - if (ticketFromHeader is null) - { - Logger.LogRequest(["No ticket header found in the Graph connector notification"], MessageType.Failed, new LoggingContext(session)); - return; - } - - if (ticketFromHeader != _ticket) - { - Logger.LogRequest([$"Ticket on the request does not match the expected ticket. Expected: {_ticket}. Request: {ticketFromHeader}"], MessageType.Failed, new LoggingContext(session)); - } - } - - protected override async Task OnMockRequestAsync(object sender, EventArgs e) - { - if (_configuration.Request is null) - { - Logger.LogDebug("No mock request is configured. Skipping."); - return; - } - - using var httpClient = new HttpClient(); - var requestMessage = GetRequestMessage(); - if (requestMessage.Content is null) - { - Logger.LogError("No body found in the mock request. Skipping."); - return; - } - var requestBody = await requestMessage.Content.ReadAsStringAsync(); - requestBody = requestBody.Replace("@dynamic.validationToken", GetJwtToken()); - requestMessage.Content = new StringContent(requestBody, Encoding.UTF8, "application/json"); - LoadTicket(); - - try - { - Logger.LogRequest(["Sending Graph connector notification"], MessageType.Mocked, _configuration.Request.Method, _configuration.Request.Url); - - var response = await httpClient.SendAsync(requestMessage); - - if (response.StatusCode != HttpStatusCode.Accepted) - { - Logger.LogRequest([$"Incorrect response status code {(int)response.StatusCode} {response.StatusCode}. Expected: 202 Accepted"], MessageType.Failed, _configuration.Request.Method, _configuration.Request.Url); - } - - if (response.Content is not null) - { - var content = await response.Content.ReadAsStringAsync(); - if (!string.IsNullOrEmpty(content)) - { - Logger.LogRequest(["Received response body while empty response expected"], MessageType.Failed, _configuration.Request.Method, _configuration.Request.Url); - } - } - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while sending the Graph connector notification to {url}", _configuration.Request.Url); - } - } - - private string GetJwtToken() - { - var signingCredentials = new X509SigningCredentials(Context.Certificate, SecurityAlgorithms.RsaSha256); - - var tokenHandler = new JwtSecurityTokenHandler(); - var tokenDescriptor = new SecurityTokenDescriptor - { - Claims = new Dictionary - { - // Microsoft Graph Change Tracking - { "azp", "0bf30f3b-4a52-48df-9a82-234910c4a086" }, - // client cert auth - { "azpacr", "2" }, - { "tid", _graphConnectorConfiguration.Tenant ?? "" }, - { "ver", "2.0" } - - }, - Expires = DateTime.UtcNow.AddMinutes(60), - Issuer = $"https://login.microsoftonline.com/{_graphConnectorConfiguration.Tenant}/v2.0", - Audience = _graphConnectorConfiguration.Audience, - SigningCredentials = signingCredentials - }; - - var token = tokenHandler.CreateToken(tokenDescriptor); - return tokenHandler.WriteToken(token); - } - - private void LoadTicket() - { - if (_ticket is not null) - { - return; - } - - if (_configuration.Request?.Body is null) - { - Logger.LogWarning("No body found in the Graph connector notification. Ticket will not be loaded."); - return; - } - - try - { - var body = (JsonElement)_configuration.Request.Body; - _ticket = body.Get("value")?.Get(0)?.Get("resourceData")?.Get("connectorsTicket")?.GetString(); - - if (string.IsNullOrEmpty(_ticket)) - { - Logger.LogError("No ticket found in the Graph connector notification body"); - } - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while reading the ticket from the Graph connector notification body"); - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.IdentityModel.Tokens.Jwt; +using System.Net; +using System.Text; +using System.Text.Json; +using Microsoft.DevProxy.Abstractions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.IdentityModel.Tokens; +using Titanium.Web.Proxy.EventArguments; + +namespace Microsoft.DevProxy.Plugins.Mocks; + +public class GraphConnectorNotificationConfiguration : MockRequestConfiguration +{ + public string? Audience { get; set; } + public string? Tenant { get; set; } +} + +public class GraphConnectorNotificationPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : MockRequestPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private string? _ticket = null; + private readonly GraphConnectorNotificationConfiguration _graphConnectorConfiguration = new(); + + public override string Name => nameof(GraphConnectorNotificationPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + ConfigSection?.Bind(_graphConnectorConfiguration); + _graphConnectorConfiguration.MockFile = _configuration.MockFile; + _graphConnectorConfiguration.Request = _configuration.Request; + + PluginEvents.BeforeRequest += OnBeforeRequestAsync; + } + + private Task OnBeforeRequestAsync(object sender, ProxyRequestArgs e) + { + if (!ProxyUtils.IsGraphRequest(e.Session.HttpClient.Request)) + { + Logger.LogRequest("Request is not a Microsoft Graph request", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + VerifyTicket(e.Session); + return Task.CompletedTask; + } + + private void VerifyTicket(SessionEventArgs session) + { + if (_ticket is null) + { + Logger.LogRequest("No ticket found in the Graph request", MessageType.Skipped, new LoggingContext(session)); + return; + } + + var request = session.HttpClient.Request; + + if (request.Method != "POST" && request.Method != "DELETE") + { + Logger.LogRequest("Skipping non-POST and -DELETE request", MessageType.Skipped, new LoggingContext(session)); + return; + } + + if ((request.Method == "POST" && + !request.RequestUri.AbsolutePath.EndsWith("/external/connections", StringComparison.OrdinalIgnoreCase)) || + (request.Method == "DELETE" && + !request.RequestUri.AbsolutePath.Contains("/external/connections/", StringComparison.OrdinalIgnoreCase))) + { + Logger.LogRequest("Skipping non-connection request", MessageType.Skipped, new LoggingContext(session)); + return; + } + + var ticketFromHeader = request.Headers.FirstOrDefault(h => h.Name.Equals("GraphConnectors-Ticket", StringComparison.OrdinalIgnoreCase))?.Value; + if (ticketFromHeader is null) + { + Logger.LogRequest("No ticket header found in the Graph connector notification", MessageType.Failed, new LoggingContext(session)); + return; + } + + if (ticketFromHeader != _ticket) + { + Logger.LogRequest($"Ticket on the request does not match the expected ticket. Expected: {_ticket}. Request: {ticketFromHeader}", MessageType.Failed, new LoggingContext(session)); + } + else + { + Logger.LogRequest("Ticket verified", MessageType.Normal, new LoggingContext(session)); + } + } + + protected override async Task OnMockRequestAsync(object sender, EventArgs e) + { + if (_configuration.Request is null) + { + Logger.LogDebug("No mock request is configured. Skipping."); + return; + } + + using var httpClient = new HttpClient(); + var requestMessage = GetRequestMessage(); + if (requestMessage.Content is null) + { + Logger.LogError("No body found in the mock request. Skipping."); + return; + } + var requestBody = await requestMessage.Content.ReadAsStringAsync(); + requestBody = requestBody.Replace("@dynamic.validationToken", GetJwtToken()); + requestMessage.Content = new StringContent(requestBody, Encoding.UTF8, "application/json"); + LoadTicket(); + + try + { + Logger.LogRequest("Sending Graph connector notification", MessageType.Mocked, _configuration.Request.Method, _configuration.Request.Url); + + var response = await httpClient.SendAsync(requestMessage); + + if (response.StatusCode != HttpStatusCode.Accepted) + { + Logger.LogRequest($"Incorrect response status code {(int)response.StatusCode} {response.StatusCode}. Expected: 202 Accepted", MessageType.Failed, _configuration.Request.Method, _configuration.Request.Url); + } + + if (response.Content is not null) + { + var content = await response.Content.ReadAsStringAsync(); + if (!string.IsNullOrEmpty(content)) + { + Logger.LogRequest("Received response body while empty response expected", MessageType.Failed, _configuration.Request.Method, _configuration.Request.Url); + } + } + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while sending the Graph connector notification to {url}", _configuration.Request.Url); + } + } + + private string GetJwtToken() + { + var signingCredentials = new X509SigningCredentials(Context.Certificate, SecurityAlgorithms.RsaSha256); + + var tokenHandler = new JwtSecurityTokenHandler(); + var tokenDescriptor = new SecurityTokenDescriptor + { + Claims = new Dictionary + { + // Microsoft Graph Change Tracking + { "azp", "0bf30f3b-4a52-48df-9a82-234910c4a086" }, + // client cert auth + { "azpacr", "2" }, + { "tid", _graphConnectorConfiguration.Tenant ?? "" }, + { "ver", "2.0" } + + }, + Expires = DateTime.UtcNow.AddMinutes(60), + Issuer = $"https://login.microsoftonline.com/{_graphConnectorConfiguration.Tenant}/v2.0", + Audience = _graphConnectorConfiguration.Audience, + SigningCredentials = signingCredentials + }; + + var token = tokenHandler.CreateToken(tokenDescriptor); + return tokenHandler.WriteToken(token); + } + + private void LoadTicket() + { + if (_ticket is not null) + { + return; + } + + if (_configuration.Request?.Body is null) + { + Logger.LogWarning("No body found in the Graph connector notification. Ticket will not be loaded."); + return; + } + + try + { + var body = (JsonElement)_configuration.Request.Body; + _ticket = body.Get("value")?.Get(0)?.Get("resourceData")?.Get("connectorsTicket")?.GetString(); + + if (string.IsNullOrEmpty(_ticket)) + { + Logger.LogError("No ticket found in the Graph connector notification body"); + } + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while reading the ticket from the Graph connector notification body"); + } + } +} diff --git a/dev-proxy-plugins/Mocks/GraphMockResponsePlugin.cs b/dev-proxy-plugins/Mocks/GraphMockResponsePlugin.cs index 0684f27d..6f20368d 100644 --- a/dev-proxy-plugins/Mocks/GraphMockResponsePlugin.cs +++ b/dev-proxy-plugins/Mocks/GraphMockResponsePlugin.cs @@ -1,187 +1,188 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Net; -using System.Text.Json; -using System.Text.RegularExpressions; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.Behavior; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.Models; - -namespace Microsoft.DevProxy.Plugins.Mocks; - -public class GraphMockResponsePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : MockResponsePlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphMockResponsePlugin); - - protected override async Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - if (_configuration.NoMocks) - { - // mocking has been disabled. Nothing to do - return; - } - - if (!ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) - { - // not a batch request, use the basic mock functionality - await base.OnRequestAsync(sender, e); - return; - } - - var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); - if (batch == null) - { - await base.OnRequestAsync(sender, e); - return; - } - - var responses = new List(); - foreach (var request in batch.Requests) - { - GraphBatchResponsePayloadResponse? response = null; - var requestId = Guid.NewGuid().ToString(); - var requestDate = DateTime.Now.ToString(); - var headers = ProxyUtils - .BuildGraphResponseHeaders(e.Session.HttpClient.Request, requestId, requestDate); - - if (e.SessionData.TryGetValue(nameof(RateLimitingPlugin), out var pluginData) && - pluginData is List rateLimitingHeaders) - { - ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); - } - - var mockResponse = GetMatchingMockResponse(request, e.Session.HttpClient.Request.RequestUri); - if (mockResponse == null) - { - response = new GraphBatchResponsePayloadResponse - { - Id = request.Id, - Status = (int)HttpStatusCode.BadGateway, - Headers = headers.ToDictionary(h => h.Name, h => h.Value), - Body = new GraphBatchResponsePayloadResponseBody - { - Error = new GraphBatchResponsePayloadResponseBodyError - { - Code = "BadGateway", - Message = "No mock response found for this request" - } - } - }; - - Logger.LogRequest([$"502 {request.Url}"], MessageType.Mocked, new LoggingContext(e.Session)); - } - else - { - dynamic? body = null; - var statusCode = HttpStatusCode.OK; - if (mockResponse.Response?.StatusCode is not null) - { - statusCode = (HttpStatusCode)mockResponse.Response.StatusCode; - } - - if (mockResponse.Response?.Headers is not null) - { - ProxyUtils.MergeHeaders(headers, mockResponse.Response.Headers); - } - - // default the content type to application/json unless set in the mock response - if (!headers.Any(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase))) - { - headers.Add(new("content-type", "application/json")); - } - - if (mockResponse.Response?.Body is not null) - { - var bodyString = JsonSerializer.Serialize(mockResponse.Response.Body, ProxyUtils.JsonSerializerOptions) as string; - // we get a JSON string so need to start with the opening quote - if (bodyString?.StartsWith("\"@") ?? false) - { - // we've got a mock body starting with @-token which means we're sending - // a response from a file on disk - // if we can read the file, we can immediately send the response and - // skip the rest of the logic in this method - // remove the surrounding quotes and the @-token - var filePath = Path.Combine(Path.GetDirectoryName(_configuration.MocksFile) ?? "", ProxyUtils.ReplacePathTokens(bodyString.Trim('"').Substring(1))); - if (!File.Exists(filePath)) - { - Logger.LogError("File {filePath} not found. Serving file path in the mock response", filePath); - body = bodyString; - } - else - { - var bodyBytes = File.ReadAllBytes(filePath); - body = Convert.ToBase64String(bodyBytes); - } - } - else - { - body = mockResponse.Response.Body; - } - } - response = new GraphBatchResponsePayloadResponse - { - Id = request.Id, - Status = (int)statusCode, - Headers = headers.ToDictionary(h => h.Name, h => h.Value), - Body = body - }; - - Logger.LogRequest([$"{mockResponse.Response?.StatusCode ?? 200} {mockResponse.Request?.Url}"], MessageType.Mocked, new LoggingContext(e.Session)); - } - - responses.Add(response); - } - - var batchRequestId = Guid.NewGuid().ToString(); - var batchRequestDate = DateTime.Now.ToString(); - var batchHeaders = ProxyUtils.BuildGraphResponseHeaders(e.Session.HttpClient.Request, batchRequestId, batchRequestDate); - var batchResponse = new GraphBatchResponsePayload - { - Responses = [.. responses] - }; - var batchResponseString = JsonSerializer.Serialize(batchResponse, ProxyUtils.JsonSerializerOptions); - ProcessMockResponse(ref batchResponseString, batchHeaders, e, null); - e.Session.GenericResponse(batchResponseString ?? string.Empty, HttpStatusCode.OK, batchHeaders.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest([$"200 {e.Session.HttpClient.Request.RequestUri}"], MessageType.Mocked, new LoggingContext(e.Session)); - e.ResponseState.HasBeenSet = true; - } - - protected MockResponse? GetMatchingMockResponse(GraphBatchRequestPayloadRequest request, Uri batchRequestUri) - { - if (_configuration.NoMocks || - _configuration.Mocks is null || - !_configuration.Mocks.Any()) - { - return null; - } - - var mockResponse = _configuration.Mocks.FirstOrDefault(mockResponse => - { - if (mockResponse.Request?.Method != request.Method) return false; - // URLs in batch are relative to Graph version number so we need - // to make them absolute using the batch request URL - var absoluteRequestFromBatchUrl = ProxyUtils - .GetAbsoluteRequestUrlFromBatch(batchRequestUri, request.Url) - .ToString(); - if (mockResponse.Request.Url == absoluteRequestFromBatchUrl) - { - return true; - } - - // check if the URL contains a wildcard - // if it doesn't, it's not a match for the current request for sure - if (!mockResponse.Request.Url.Contains('*')) - { - return false; - } - - //turn mock URL with wildcard into a regex and match against the request URL - var mockResponseUrlRegex = Regex.Escape(mockResponse.Request.Url).Replace("\\*", ".*"); - return Regex.IsMatch(absoluteRequestFromBatchUrl, $"^{mockResponseUrlRegex}$"); - }); - return mockResponse; - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Text.Json; +using System.Text.RegularExpressions; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.Behavior; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Titanium.Web.Proxy.Models; + +namespace Microsoft.DevProxy.Plugins.Mocks; + +public class GraphMockResponsePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : MockResponsePlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphMockResponsePlugin); + + protected override async Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + if (_configuration.NoMocks) + { + Logger.LogRequest("Mocks are disabled", MessageType.Skipped, new LoggingContext(e.Session)); + // mocking has been disabled. Nothing to do + return; + } + + if (!ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) + { + // not a batch request, use the basic mock functionality + await base.OnRequestAsync(sender, e); + return; + } + + var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); + if (batch == null) + { + await base.OnRequestAsync(sender, e); + return; + } + + var responses = new List(); + foreach (var request in batch.Requests) + { + GraphBatchResponsePayloadResponse? response = null; + var requestId = Guid.NewGuid().ToString(); + var requestDate = DateTime.Now.ToString(); + var headers = ProxyUtils + .BuildGraphResponseHeaders(e.Session.HttpClient.Request, requestId, requestDate); + + if (e.SessionData.TryGetValue(nameof(RateLimitingPlugin), out var pluginData) && + pluginData is List rateLimitingHeaders) + { + ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); + } + + var mockResponse = GetMatchingMockResponse(request, e.Session.HttpClient.Request.RequestUri); + if (mockResponse == null) + { + response = new GraphBatchResponsePayloadResponse + { + Id = request.Id, + Status = (int)HttpStatusCode.BadGateway, + Headers = headers.ToDictionary(h => h.Name, h => h.Value), + Body = new GraphBatchResponsePayloadResponseBody + { + Error = new GraphBatchResponsePayloadResponseBodyError + { + Code = "BadGateway", + Message = "No mock response found for this request" + } + } + }; + + Logger.LogRequest($"502 {request.Url}", MessageType.Mocked, new LoggingContext(e.Session)); + } + else + { + dynamic? body = null; + var statusCode = HttpStatusCode.OK; + if (mockResponse.Response?.StatusCode is not null) + { + statusCode = (HttpStatusCode)mockResponse.Response.StatusCode; + } + + if (mockResponse.Response?.Headers is not null) + { + ProxyUtils.MergeHeaders(headers, mockResponse.Response.Headers); + } + + // default the content type to application/json unless set in the mock response + if (!headers.Any(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase))) + { + headers.Add(new("content-type", "application/json")); + } + + if (mockResponse.Response?.Body is not null) + { + var bodyString = JsonSerializer.Serialize(mockResponse.Response.Body, ProxyUtils.JsonSerializerOptions) as string; + // we get a JSON string so need to start with the opening quote + if (bodyString?.StartsWith("\"@") ?? false) + { + // we've got a mock body starting with @-token which means we're sending + // a response from a file on disk + // if we can read the file, we can immediately send the response and + // skip the rest of the logic in this method + // remove the surrounding quotes and the @-token + var filePath = Path.Combine(Path.GetDirectoryName(_configuration.MocksFile) ?? "", ProxyUtils.ReplacePathTokens(bodyString.Trim('"').Substring(1))); + if (!File.Exists(filePath)) + { + Logger.LogError("File {filePath} not found. Serving file path in the mock response", filePath); + body = bodyString; + } + else + { + var bodyBytes = File.ReadAllBytes(filePath); + body = Convert.ToBase64String(bodyBytes); + } + } + else + { + body = mockResponse.Response.Body; + } + } + response = new GraphBatchResponsePayloadResponse + { + Id = request.Id, + Status = (int)statusCode, + Headers = headers.ToDictionary(h => h.Name, h => h.Value), + Body = body + }; + + Logger.LogRequest($"{mockResponse.Response?.StatusCode ?? 200} {mockResponse.Request?.Url}", MessageType.Mocked, new LoggingContext(e.Session)); + } + + responses.Add(response); + } + + var batchRequestId = Guid.NewGuid().ToString(); + var batchRequestDate = DateTime.Now.ToString(); + var batchHeaders = ProxyUtils.BuildGraphResponseHeaders(e.Session.HttpClient.Request, batchRequestId, batchRequestDate); + var batchResponse = new GraphBatchResponsePayload + { + Responses = [.. responses] + }; + var batchResponseString = JsonSerializer.Serialize(batchResponse, ProxyUtils.JsonSerializerOptions); + ProcessMockResponse(ref batchResponseString, batchHeaders, e, null); + e.Session.GenericResponse(batchResponseString ?? string.Empty, HttpStatusCode.OK, batchHeaders.Select(h => new HttpHeader(h.Name, h.Value))); + Logger.LogRequest($"200 {e.Session.HttpClient.Request.RequestUri}", MessageType.Mocked, new LoggingContext(e.Session)); + e.ResponseState.HasBeenSet = true; + } + + protected MockResponse? GetMatchingMockResponse(GraphBatchRequestPayloadRequest request, Uri batchRequestUri) + { + if (_configuration.NoMocks || + _configuration.Mocks is null || + !_configuration.Mocks.Any()) + { + return null; + } + + var mockResponse = _configuration.Mocks.FirstOrDefault(mockResponse => + { + if (mockResponse.Request?.Method != request.Method) return false; + // URLs in batch are relative to Graph version number so we need + // to make them absolute using the batch request URL + var absoluteRequestFromBatchUrl = ProxyUtils + .GetAbsoluteRequestUrlFromBatch(batchRequestUri, request.Url) + .ToString(); + if (mockResponse.Request.Url == absoluteRequestFromBatchUrl) + { + return true; + } + + // check if the URL contains a wildcard + // if it doesn't, it's not a match for the current request for sure + if (!mockResponse.Request.Url.Contains('*')) + { + return false; + } + + //turn mock URL with wildcard into a regex and match against the request URL + var mockResponseUrlRegex = Regex.Escape(mockResponse.Request.Url).Replace("\\*", ".*"); + return Regex.IsMatch(absoluteRequestFromBatchUrl, $"^{mockResponseUrlRegex}$"); + }); + return mockResponse; + } } \ No newline at end of file diff --git a/dev-proxy-plugins/Mocks/MockRequestPlugin.cs b/dev-proxy-plugins/Mocks/MockRequestPlugin.cs index e3ea80a3..3fd8fabb 100644 --- a/dev-proxy-plugins/Mocks/MockRequestPlugin.cs +++ b/dev-proxy-plugins/Mocks/MockRequestPlugin.cs @@ -1,111 +1,111 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using System.Text; -using System.Text.Json; -using System.Text.Json.Serialization; -using Microsoft.DevProxy.Abstractions; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.Mocks; - -public class MockRequestConfiguration -{ - [JsonIgnore] - public string MockFile { get; set; } = "mock-request.json"; - public MockRequest? Request { get; set; } -} - -public class MockRequestPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - protected MockRequestConfiguration _configuration = new(); - private MockRequestLoader? _loader = null; - - public override string Name => nameof(MockRequestPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - _loader = new MockRequestLoader(Logger, _configuration); - - PluginEvents.MockRequest += OnMockRequestAsync; - - // make the mock file path relative to the configuration file - _configuration.MockFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.MockFile), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); - - // load the request from the configured mock file - _loader.InitResponsesWatcher(); - } - - protected HttpRequestMessage GetRequestMessage() - { - Debug.Assert(_configuration.Request is not null, "The mock request is not configured"); - - Logger.LogDebug("Preparing mock {method} request to {url}", _configuration.Request.Method, _configuration.Request.Url); - var requestMessage = new HttpRequestMessage - { - RequestUri = new Uri(_configuration.Request.Url), - Method = new HttpMethod(_configuration.Request.Method) - }; - - var contentType = ""; - if (_configuration.Request.Headers is not null) - { - Logger.LogDebug("Adding headers to the mock request"); - - foreach (var header in _configuration.Request.Headers) - { - if (header.Name.Equals("content-type", StringComparison.CurrentCultureIgnoreCase)) - { - contentType = header.Value; - continue; - } - - requestMessage.Headers.Add(header.Name, header.Value); - } - } - - if (_configuration.Request.Body is not null) - { - Logger.LogDebug("Adding body to the mock request"); - - if (_configuration.Request.Body is string) - { - requestMessage.Content = new StringContent(_configuration.Request.Body, Encoding.UTF8, contentType); - } - else - { - requestMessage.Content = new StringContent(JsonSerializer.Serialize(_configuration.Request.Body, ProxyUtils.JsonSerializerOptions), Encoding.UTF8, "application/json"); - } - } - - return requestMessage; - } - - protected virtual async Task OnMockRequestAsync(object sender, EventArgs e) - { - if (_configuration.Request is null) - { - Logger.LogDebug("No mock request is configured. Skipping."); - return; - } - - using var httpClient = new HttpClient(); - var requestMessage = GetRequestMessage(); - - try - { - Logger.LogRequest(["Sending mock request"], MessageType.Mocked, _configuration.Request.Method, _configuration.Request.Url); - - await httpClient.SendAsync(requestMessage); - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while sending the mock request to {url}", _configuration.Request.Url); - } - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.DevProxy.Abstractions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.Mocks; + +public class MockRequestConfiguration +{ + [JsonIgnore] + public string MockFile { get; set; } = "mock-request.json"; + public MockRequest? Request { get; set; } +} + +public class MockRequestPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + protected MockRequestConfiguration _configuration = new(); + private MockRequestLoader? _loader = null; + + public override string Name => nameof(MockRequestPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + _loader = new MockRequestLoader(Logger, _configuration); + + PluginEvents.MockRequest += OnMockRequestAsync; + + // make the mock file path relative to the configuration file + _configuration.MockFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.MockFile), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); + + // load the request from the configured mock file + _loader.InitResponsesWatcher(); + } + + protected HttpRequestMessage GetRequestMessage() + { + Debug.Assert(_configuration.Request is not null, "The mock request is not configured"); + + Logger.LogDebug("Preparing mock {method} request to {url}", _configuration.Request.Method, _configuration.Request.Url); + var requestMessage = new HttpRequestMessage + { + RequestUri = new Uri(_configuration.Request.Url), + Method = new HttpMethod(_configuration.Request.Method) + }; + + var contentType = ""; + if (_configuration.Request.Headers is not null) + { + Logger.LogDebug("Adding headers to the mock request"); + + foreach (var header in _configuration.Request.Headers) + { + if (header.Name.Equals("content-type", StringComparison.CurrentCultureIgnoreCase)) + { + contentType = header.Value; + continue; + } + + requestMessage.Headers.Add(header.Name, header.Value); + } + } + + if (_configuration.Request.Body is not null) + { + Logger.LogDebug("Adding body to the mock request"); + + if (_configuration.Request.Body is string) + { + requestMessage.Content = new StringContent(_configuration.Request.Body, Encoding.UTF8, contentType); + } + else + { + requestMessage.Content = new StringContent(JsonSerializer.Serialize(_configuration.Request.Body, ProxyUtils.JsonSerializerOptions), Encoding.UTF8, "application/json"); + } + } + + return requestMessage; + } + + protected virtual async Task OnMockRequestAsync(object sender, EventArgs e) + { + if (_configuration.Request is null) + { + Logger.LogDebug("No mock request is configured. Skipping."); + return; + } + + using var httpClient = new HttpClient(); + var requestMessage = GetRequestMessage(); + + try + { + Logger.LogRequest("Sending mock request", MessageType.Mocked, _configuration.Request.Method, _configuration.Request.Url); + + await httpClient.SendAsync(requestMessage); + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while sending the mock request to {url}", _configuration.Request.Url); + } + } } \ No newline at end of file diff --git a/dev-proxy-plugins/Mocks/MockResponsePlugin.cs b/dev-proxy-plugins/Mocks/MockResponsePlugin.cs index 315a3183..0b3a0d76 100644 --- a/dev-proxy-plugins/Mocks/MockResponsePlugin.cs +++ b/dev-proxy-plugins/Mocks/MockResponsePlugin.cs @@ -105,37 +105,49 @@ protected virtual Task OnRequestAsync(object? sender, ProxyRequestArgs e) { Request request = e.Session.HttpClient.Request; ResponseState state = e.ResponseState; - if (!_configuration.NoMocks && UrlsToWatch is not null && e.ShouldExecute(UrlsToWatch)) + if (_configuration.NoMocks) { - var matchingResponse = GetMatchingMockResponse(request); - if (matchingResponse is not null) - { - ProcessMockResponseInternal(e, matchingResponse); - state.HasBeenSet = true; - } - else if (_configuration.BlockUnmockedRequests) + Logger.LogRequest("Mocks disabled", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + var matchingResponse = GetMatchingMockResponse(request); + if (matchingResponse is not null) + { + ProcessMockResponseInternal(e, matchingResponse); + state.HasBeenSet = true; + return Task.CompletedTask; + } + else if (_configuration.BlockUnmockedRequests) + { + ProcessMockResponseInternal(e, new MockResponse { - ProcessMockResponseInternal(e, new MockResponse + Request = new() { - Request = new() - { - Url = request.Url, - Method = request.Method ?? "" - }, - Response = new() + Url = request.Url, + Method = request.Method ?? "" + }, + Response = new() + { + StatusCode = 502, + Body = new GraphErrorResponseBody(new GraphErrorResponseError { - StatusCode = 502, - Body = new GraphErrorResponseBody(new GraphErrorResponseError - { - Code = "Bad Gateway", - Message = $"No mock response found for {request.Method} {request.Url}" - }) - } - }); - state.HasBeenSet = true; - } + Code = "Bad Gateway", + Message = $"No mock response found for {request.Method} {request.Url}" + }) + } + }); + state.HasBeenSet = true; + return Task.CompletedTask; } + Logger.LogRequest("No matching mock response found", MessageType.Skipped, new LoggingContext(e.Session)); return Task.CompletedTask; } @@ -294,7 +306,7 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi var bodyBytes = File.ReadAllBytes(filePath); ProcessMockResponse(ref bodyBytes, headers, e, matchingResponse); e.Session.GenericResponse(bodyBytes, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest([$"{matchingResponse.Response.StatusCode ?? 200} {matchingResponse.Request?.Url}"], MessageType.Mocked, new LoggingContext(e.Session)); + Logger.LogRequest($"{matchingResponse.Response.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, new LoggingContext(e.Session)); return; } } @@ -303,7 +315,8 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi body = bodyString; } } - else { + else + { // we need to remove the content-type header if the body is empty // some clients fail on empty body + content-type var contentTypeHeader = headers.FirstOrDefault(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase)); @@ -315,6 +328,6 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi ProcessMockResponse(ref body, headers, e, matchingResponse); e.Session.GenericResponse(body ?? string.Empty, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest([$"{matchingResponse.Response?.StatusCode ?? 200} {matchingResponse.Request?.Url}"], MessageType.Mocked, new LoggingContext(e.Session)); + Logger.LogRequest($"{matchingResponse.Response?.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, new LoggingContext(e.Session)); } } diff --git a/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs b/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs index 50839cd5..b493c569 100644 --- a/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs +++ b/dev-proxy-plugins/Mocks/OpenAIMockResponsePlugin.cs @@ -1,369 +1,369 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Net; -using System.Text.Json; -using System.Text.Json.Serialization; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Abstractions.LanguageModel; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.Models; - -namespace Microsoft.DevProxy.Plugins.Mocks; - -public class OpenAIMockResponsePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(OpenAIMockResponsePlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - using var scope = Logger.BeginScope(Name); - - Logger.LogInformation("Checking language model availability..."); - if (!await Context.LanguageModelClient.IsEnabledAsync()) - { - Logger.LogError("Local language model is not enabled. The {plugin} will not be used.", Name); - return; - } - - PluginEvents.BeforeRequest += OnRequestAsync; - } - - private async Task OnRequestAsync(object sender, ProxyRequestArgs e) - { - using var scope = Logger.BeginScope(Name); - - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) - { - return; - } - - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) - { - return; - } - - if (openAiRequest is OpenAICompletionRequest completionRequest) - { - if ((await Context.LanguageModelClient.GenerateCompletionAsync(completionRequest.Prompt)) is not OllamaLanguageModelCompletionResponse ollamaResponse) - { - return; - } - if (ollamaResponse.Error is not null) - { - Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error); - return; - } - - var openAiResponse = ollamaResponse.ConvertToOpenAIResponse(); - SendMockResponse(openAiResponse, ollamaResponse.RequestUrl, e); - } - else if (openAiRequest is OpenAIChatCompletionRequest chatRequest) - { - if ((await Context.LanguageModelClient - .GenerateChatCompletionAsync(chatRequest.Messages.ConvertToLanguageModelChatCompletionMessage())) is not OllamaLanguageModelChatCompletionResponse ollamaResponse) - { - return; - } - if (ollamaResponse.Error is not null) - { - Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error); - return; - } - - var openAiResponse = ollamaResponse.ConvertToOpenAIResponse(); - SendMockResponse(openAiResponse, ollamaResponse.RequestUrl, e); - } - else - { - Logger.LogError("Unknown OpenAI request type."); - } - } - - private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) - { - request = null; - - if (string.IsNullOrEmpty(content)) - { - return false; - } - - try - { - Logger.LogDebug("Checking if the request is an OpenAI request..."); - - var rawRequest = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - - if (rawRequest.TryGetProperty("prompt", out _)) - { - Logger.LogDebug("Request is a completion request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - if (rawRequest.TryGetProperty("messages", out _)) - { - Logger.LogDebug("Request is a chat completion request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - Logger.LogDebug("Request is not an OpenAI request."); - return false; - } - catch (JsonException ex) - { - Logger.LogDebug(ex, "Failed to deserialize OpenAI request."); - return false; - } - } - - private void SendMockResponse(OpenAIResponse response, string localLmUrl, ProxyRequestArgs e) where TResponse : OpenAIResponse - { - e.Session.GenericResponse( - // we need this cast or else the JsonSerializer drops derived properties - JsonSerializer.Serialize((TResponse)response, ProxyUtils.JsonSerializerOptions), - HttpStatusCode.OK, - [ - new HttpHeader("content-type", "application/json"), - new HttpHeader("access-control-allow-origin", "*") - ] - ); - e.ResponseState.HasBeenSet = true; - Logger.LogRequest([$"200 {localLmUrl}"], MessageType.Mocked, new LoggingContext(e.Session)); - } -} - -#region models - -internal abstract class OpenAIRequest -{ - [JsonPropertyName("frequency_penalty")] - public long FrequencyPenalty { get; set; } - [JsonPropertyName("max_tokens")] - public long MaxTokens { get; set; } - [JsonPropertyName("presence_penalty")] - public long PresencePenalty { get; set; } - public object? Stop { get; set; } - public bool Stream { get; set; } - public long Temperature { get; set; } - [JsonPropertyName("top_p")] - public double TopP { get; set; } -} - -internal abstract class OpenAIResponse -{ - public long Created { get; set; } - public string Id { get; set; } = string.Empty; - public string Model { get; set; } = string.Empty; - public string Object { get; set; } = "text_completion"; - [JsonPropertyName("prompt_filter_results")] - public OpenAIResponsePromptFilterResult[] PromptFilterResults { get; set; } = []; - public OpenAIResponseUsage Usage { get; set; } = new(); -} - -internal abstract class OpenAIResponse : OpenAIResponse -{ - public TChoice[] Choices { get; set; } = []; -} - -internal class OpenAIResponseUsage -{ - [JsonPropertyName("completion_tokens")] - public long CompletionTokens { get; set; } - [JsonPropertyName("prompt_tokens")] - public long PromptTokens { get; set; } - [JsonPropertyName("total_tokens")] - public long TotalTokens { get; set; } -} - -internal abstract class OpenAIResponseChoice -{ - [JsonPropertyName("content_filter_results")] - public Dictionary ContentFilterResults { get; set; } = new(); - [JsonPropertyName("finish_reason")] - public string FinishReason { get; set; } = "length"; - public long Index { get; set; } - [JsonIgnore(Condition = JsonIgnoreCondition.Never)] - public object? Logprobs { get; set; } -} - -internal class OpenAIResponsePromptFilterResult -{ - [JsonPropertyName("content_filter_results")] - public Dictionary ContentFilterResults { get; set; } = new(); - [JsonPropertyName("prompt_index")] - public long PromptIndex { get; set; } -} - -internal class OpenAIResponseContentFilterResult -{ - public bool Filtered { get; set; } - public string Severity { get; set; } = "safe"; -} - -internal class OpenAICompletionRequest : OpenAIRequest -{ - public string Prompt { get; set; } = string.Empty; -} - -internal class OpenAICompletionResponse : OpenAIResponse -{ -} - -internal class OpenAICompletionResponseChoice : OpenAIResponseChoice -{ - public string Text { get; set; } = string.Empty; -} - -internal class OpenAIChatCompletionRequest : OpenAIRequest -{ - public OpenAIChatMessage[] Messages { get; set; } = []; -} - -internal class OpenAIChatMessage -{ - public string Content { get; set; } = string.Empty; - public string Role { get; set; } = string.Empty; -} - -internal class OpenAIChatCompletionResponse : OpenAIResponse -{ -} - -internal class OpenAIChatCompletionResponseChoice : OpenAIResponseChoice -{ - public OpenAIChatCompletionResponseChoiceMessage Message { get; set; } = new(); -} - -internal class OpenAIChatCompletionResponseChoiceMessage -{ - public string Content { get; set; } = string.Empty; - public string Role { get; set; } = string.Empty; -} - -#endregion - -#region extensions - -internal static class OllamaLanguageModelCompletionResponseExtensions -{ - public static OpenAICompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelCompletionResponse response) - { - return new OpenAICompletionResponse - { - Id = Guid.NewGuid().ToString(), - Object = "text_completion", - Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(), - Model = response.Model, - PromptFilterResults = - [ - new OpenAIResponsePromptFilterResult - { - PromptIndex = 0, - ContentFilterResults = new Dictionary - { - { "hate", new() { Filtered = false, Severity = "safe" } }, - { "self_harm", new() { Filtered = false, Severity = "safe" } }, - { "sexual", new() { Filtered = false, Severity = "safe" } }, - { "violence", new() { Filtered = false, Severity = "safe" } } - } - } - ], - Choices = - [ - new OpenAICompletionResponseChoice - { - Text = response.Response ?? string.Empty, - Index = 0, - FinishReason = "length", - ContentFilterResults = new Dictionary - { - { "hate", new() { Filtered = false, Severity = "safe" } }, - { "self_harm", new() { Filtered = false, Severity = "safe" } }, - { "sexual", new() { Filtered = false, Severity = "safe" } }, - { "violence", new() { Filtered = false, Severity = "safe" } } - } - } - ], - Usage = new OpenAIResponseUsage - { - PromptTokens = response.PromptEvalCount, - CompletionTokens = response.EvalCount, - TotalTokens = response.PromptEvalCount + response.EvalCount - } - }; - } -} - -internal static class OllamaLanguageModelChatCompletionResponseExtensions -{ - public static OpenAIChatCompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelChatCompletionResponse response) - { - return new OpenAIChatCompletionResponse - { - Choices = [new OpenAIChatCompletionResponseChoice - { - ContentFilterResults = new Dictionary - { - { "hate", new() { Filtered = false, Severity = "safe" } }, - { "self_harm", new() { Filtered = false, Severity = "safe" } }, - { "sexual", new() { Filtered = false, Severity = "safe" } }, - { "violence", new() { Filtered = false, Severity = "safe" } } - }, - FinishReason = "stop", - Index = 0, - Message = new() - { - Content = response.Message.Content, - Role = response.Message.Role - } - }], - Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(), - Id = Guid.NewGuid().ToString(), - Model = response.Model, - Object = "chat.completion", - PromptFilterResults = - [ - new OpenAIResponsePromptFilterResult - { - PromptIndex = 0, - ContentFilterResults = new Dictionary - { - { "hate", new() { Filtered = false, Severity = "safe" } }, - { "self_harm", new() { Filtered = false, Severity = "safe" } }, - { "sexual", new() { Filtered = false, Severity = "safe" } }, - { "violence", new() { Filtered = false, Severity = "safe" } } - } - } - ], - Usage = new OpenAIResponseUsage - { - PromptTokens = response.PromptEvalCount, - CompletionTokens = response.EvalCount, - TotalTokens = response.PromptEvalCount + response.EvalCount - } - }; - } -} - -internal static class OpenAIChatMessageExtensions -{ - public static ILanguageModelChatCompletionMessage[] ConvertToLanguageModelChatCompletionMessage(this OpenAIChatMessage[] messages) - { - return messages.Select(m => new OllamaLanguageModelChatCompletionMessage - { - Content = m.Content, - Role = m.Role - }).ToArray(); - } -} - +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Abstractions.LanguageModel; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Titanium.Web.Proxy.Models; + +namespace Microsoft.DevProxy.Plugins.Mocks; + +public class OpenAIMockResponsePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(OpenAIMockResponsePlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + using var scope = Logger.BeginScope(Name); + + Logger.LogInformation("Checking language model availability..."); + if (!await Context.LanguageModelClient.IsEnabledAsync()) + { + Logger.LogError("Local language model is not enabled. The {plugin} will not be used.", Name); + return; + } + + PluginEvents.BeforeRequest += OnRequestAsync; + } + + private async Task OnRequestAsync(object sender, ProxyRequestArgs e) + { + var request = e.Session.HttpClient.Request; + if (request.Method is null || + !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || + !request.HasBody) + { + Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) + { + Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + if (openAiRequest is OpenAICompletionRequest completionRequest) + { + if ((await Context.LanguageModelClient.GenerateCompletionAsync(completionRequest.Prompt)) is not OllamaLanguageModelCompletionResponse ollamaResponse) + { + return; + } + if (ollamaResponse.Error is not null) + { + Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error); + return; + } + + var openAiResponse = ollamaResponse.ConvertToOpenAIResponse(); + SendMockResponse(openAiResponse, ollamaResponse.RequestUrl, e); + } + else if (openAiRequest is OpenAIChatCompletionRequest chatRequest) + { + if ((await Context.LanguageModelClient + .GenerateChatCompletionAsync(chatRequest.Messages.ConvertToLanguageModelChatCompletionMessage())) is not OllamaLanguageModelChatCompletionResponse ollamaResponse) + { + return; + } + if (ollamaResponse.Error is not null) + { + Logger.LogError("Error from Ollama language model: {error}", ollamaResponse.Error); + return; + } + + var openAiResponse = ollamaResponse.ConvertToOpenAIResponse(); + SendMockResponse(openAiResponse, ollamaResponse.RequestUrl, e); + } + else + { + Logger.LogError("Unknown OpenAI request type."); + } + } + + private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) + { + request = null; + + if (string.IsNullOrEmpty(content)) + { + return false; + } + + try + { + Logger.LogDebug("Checking if the request is an OpenAI request..."); + + var rawRequest = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); + + if (rawRequest.TryGetProperty("prompt", out _)) + { + Logger.LogDebug("Request is a completion request"); + request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); + return true; + } + + if (rawRequest.TryGetProperty("messages", out _)) + { + Logger.LogDebug("Request is a chat completion request"); + request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); + return true; + } + + Logger.LogDebug("Request is not an OpenAI request."); + return false; + } + catch (JsonException ex) + { + Logger.LogDebug(ex, "Failed to deserialize OpenAI request."); + return false; + } + } + + private void SendMockResponse(OpenAIResponse response, string localLmUrl, ProxyRequestArgs e) where TResponse : OpenAIResponse + { + e.Session.GenericResponse( + // we need this cast or else the JsonSerializer drops derived properties + JsonSerializer.Serialize((TResponse)response, ProxyUtils.JsonSerializerOptions), + HttpStatusCode.OK, + [ + new HttpHeader("content-type", "application/json"), + new HttpHeader("access-control-allow-origin", "*") + ] + ); + e.ResponseState.HasBeenSet = true; + Logger.LogRequest($"200 {localLmUrl}", MessageType.Mocked, new LoggingContext(e.Session)); + } +} + +#region models + +internal abstract class OpenAIRequest +{ + [JsonPropertyName("frequency_penalty")] + public long FrequencyPenalty { get; set; } + [JsonPropertyName("max_tokens")] + public long MaxTokens { get; set; } + [JsonPropertyName("presence_penalty")] + public long PresencePenalty { get; set; } + public object? Stop { get; set; } + public bool Stream { get; set; } + public long Temperature { get; set; } + [JsonPropertyName("top_p")] + public double TopP { get; set; } +} + +internal abstract class OpenAIResponse +{ + public long Created { get; set; } + public string Id { get; set; } = string.Empty; + public string Model { get; set; } = string.Empty; + public string Object { get; set; } = "text_completion"; + [JsonPropertyName("prompt_filter_results")] + public OpenAIResponsePromptFilterResult[] PromptFilterResults { get; set; } = []; + public OpenAIResponseUsage Usage { get; set; } = new(); +} + +internal abstract class OpenAIResponse : OpenAIResponse +{ + public TChoice[] Choices { get; set; } = []; +} + +internal class OpenAIResponseUsage +{ + [JsonPropertyName("completion_tokens")] + public long CompletionTokens { get; set; } + [JsonPropertyName("prompt_tokens")] + public long PromptTokens { get; set; } + [JsonPropertyName("total_tokens")] + public long TotalTokens { get; set; } +} + +internal abstract class OpenAIResponseChoice +{ + [JsonPropertyName("content_filter_results")] + public Dictionary ContentFilterResults { get; set; } = new(); + [JsonPropertyName("finish_reason")] + public string FinishReason { get; set; } = "length"; + public long Index { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.Never)] + public object? Logprobs { get; set; } +} + +internal class OpenAIResponsePromptFilterResult +{ + [JsonPropertyName("content_filter_results")] + public Dictionary ContentFilterResults { get; set; } = new(); + [JsonPropertyName("prompt_index")] + public long PromptIndex { get; set; } +} + +internal class OpenAIResponseContentFilterResult +{ + public bool Filtered { get; set; } + public string Severity { get; set; } = "safe"; +} + +internal class OpenAICompletionRequest : OpenAIRequest +{ + public string Prompt { get; set; } = string.Empty; +} + +internal class OpenAICompletionResponse : OpenAIResponse +{ +} + +internal class OpenAICompletionResponseChoice : OpenAIResponseChoice +{ + public string Text { get; set; } = string.Empty; +} + +internal class OpenAIChatCompletionRequest : OpenAIRequest +{ + public OpenAIChatMessage[] Messages { get; set; } = []; +} + +internal class OpenAIChatMessage +{ + public string Content { get; set; } = string.Empty; + public string Role { get; set; } = string.Empty; +} + +internal class OpenAIChatCompletionResponse : OpenAIResponse +{ +} + +internal class OpenAIChatCompletionResponseChoice : OpenAIResponseChoice +{ + public OpenAIChatCompletionResponseChoiceMessage Message { get; set; } = new(); +} + +internal class OpenAIChatCompletionResponseChoiceMessage +{ + public string Content { get; set; } = string.Empty; + public string Role { get; set; } = string.Empty; +} + +#endregion + +#region extensions + +internal static class OllamaLanguageModelCompletionResponseExtensions +{ + public static OpenAICompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelCompletionResponse response) + { + return new OpenAICompletionResponse + { + Id = Guid.NewGuid().ToString(), + Object = "text_completion", + Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(), + Model = response.Model, + PromptFilterResults = + [ + new OpenAIResponsePromptFilterResult + { + PromptIndex = 0, + ContentFilterResults = new Dictionary + { + { "hate", new() { Filtered = false, Severity = "safe" } }, + { "self_harm", new() { Filtered = false, Severity = "safe" } }, + { "sexual", new() { Filtered = false, Severity = "safe" } }, + { "violence", new() { Filtered = false, Severity = "safe" } } + } + } + ], + Choices = + [ + new OpenAICompletionResponseChoice + { + Text = response.Response ?? string.Empty, + Index = 0, + FinishReason = "length", + ContentFilterResults = new Dictionary + { + { "hate", new() { Filtered = false, Severity = "safe" } }, + { "self_harm", new() { Filtered = false, Severity = "safe" } }, + { "sexual", new() { Filtered = false, Severity = "safe" } }, + { "violence", new() { Filtered = false, Severity = "safe" } } + } + } + ], + Usage = new OpenAIResponseUsage + { + PromptTokens = response.PromptEvalCount, + CompletionTokens = response.EvalCount, + TotalTokens = response.PromptEvalCount + response.EvalCount + } + }; + } +} + +internal static class OllamaLanguageModelChatCompletionResponseExtensions +{ + public static OpenAIChatCompletionResponse ConvertToOpenAIResponse(this OllamaLanguageModelChatCompletionResponse response) + { + return new OpenAIChatCompletionResponse + { + Choices = [new OpenAIChatCompletionResponseChoice + { + ContentFilterResults = new Dictionary + { + { "hate", new() { Filtered = false, Severity = "safe" } }, + { "self_harm", new() { Filtered = false, Severity = "safe" } }, + { "sexual", new() { Filtered = false, Severity = "safe" } }, + { "violence", new() { Filtered = false, Severity = "safe" } } + }, + FinishReason = "stop", + Index = 0, + Message = new() + { + Content = response.Message.Content, + Role = response.Message.Role + } + }], + Created = ((DateTimeOffset)response.CreatedAt).ToUnixTimeSeconds(), + Id = Guid.NewGuid().ToString(), + Model = response.Model, + Object = "chat.completion", + PromptFilterResults = + [ + new OpenAIResponsePromptFilterResult + { + PromptIndex = 0, + ContentFilterResults = new Dictionary + { + { "hate", new() { Filtered = false, Severity = "safe" } }, + { "self_harm", new() { Filtered = false, Severity = "safe" } }, + { "sexual", new() { Filtered = false, Severity = "safe" } }, + { "violence", new() { Filtered = false, Severity = "safe" } } + } + } + ], + Usage = new OpenAIResponseUsage + { + PromptTokens = response.PromptEvalCount, + CompletionTokens = response.EvalCount, + TotalTokens = response.PromptEvalCount + response.EvalCount + } + }; + } +} + +internal static class OpenAIChatMessageExtensions +{ + public static ILanguageModelChatCompletionMessage[] ConvertToLanguageModelChatCompletionMessage(this OpenAIChatMessage[] messages) + { + return messages.Select(m => new OllamaLanguageModelChatCompletionMessage + { + Content = m.Content, + Role = m.Role + }).ToArray(); + } +} + #endregion \ No newline at end of file diff --git a/dev-proxy-plugins/OpenApi/OpenApiDocumentExtensions.cs b/dev-proxy-plugins/OpenApi/OpenApiDocumentExtensions.cs index 82d0a4de..0ce8f810 100644 --- a/dev-proxy-plugins/OpenApi/OpenApiDocumentExtensions.cs +++ b/dev-proxy-plugins/OpenApi/OpenApiDocumentExtensions.cs @@ -130,7 +130,7 @@ public static ApiPermissionsInfo CheckMinimalPermissions(this OpenApiDocument op foreach (var request in requests) { // get scopes from the token - var methodAndUrl = request.MessageLines.First(); + var methodAndUrl = request.Message; var methodAndUrlChunks = methodAndUrl.Split(' '); logger.LogDebug("Checking request {request}...", methodAndUrl); var (method, url) = (methodAndUrlChunks[0].ToUpper(), methodAndUrlChunks[1]); diff --git a/dev-proxy-plugins/RandomErrors/GenericRandomErrorPlugin.cs b/dev-proxy-plugins/RandomErrors/GenericRandomErrorPlugin.cs index 51de39c6..deb3fbf4 100644 --- a/dev-proxy-plugins/RandomErrors/GenericRandomErrorPlugin.cs +++ b/dev-proxy-plugins/RandomErrors/GenericRandomErrorPlugin.cs @@ -1,215 +1,227 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using System.Net; -using System.Text.Json; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; -using Microsoft.DevProxy.Plugins.Behavior; -using Microsoft.Extensions.Logging; -using System.Text.RegularExpressions; - -namespace Microsoft.DevProxy.Plugins.RandomErrors; -internal enum GenericRandomErrorFailMode -{ - Throttled, - Random, - PassThru -} - -public class GenericRandomErrorConfiguration -{ - public string? ErrorsFile { get; set; } - public int RetryAfterInSeconds { get; set; } = 5; - public IEnumerable Errors { get; set; } = []; -} - -public class GenericRandomErrorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private readonly GenericRandomErrorConfiguration _configuration = new(); - private GenericErrorResponsesLoader? _loader = null; - - public override string Name => nameof(GenericRandomErrorPlugin); - - private readonly Random _random = new(); - - // uses config to determine if a request should be failed - private GenericRandomErrorFailMode ShouldFail() => _random.Next(1, 100) <= Context.Configuration.Rate ? GenericRandomErrorFailMode.Random : GenericRandomErrorFailMode.PassThru; - - private void FailResponse(ProxyRequestArgs e) - { - var matchingResponse = GetMatchingErrorResponse(e.Session.HttpClient.Request); - if (matchingResponse is not null && - matchingResponse.Responses is not null) - { - // pick a random error response for the current request - var error = matchingResponse.Responses.ElementAt(_random.Next(0, matchingResponse.Responses.Length)); - UpdateProxyResponse(e, error); - } - } - - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) - { - var throttleKeyForRequest = BuildThrottleKey(request); - return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? _configuration.RetryAfterInSeconds : 0, "Retry-After"); - } - - private GenericErrorResponse? GetMatchingErrorResponse(Request request) - { - if (_configuration.Errors is null || - !_configuration.Errors.Any()) - { - return null; - } - - var errorResponse = _configuration.Errors.FirstOrDefault(errorResponse => - { - if (errorResponse.Request is null) return false; - if (errorResponse.Responses is null) return false; - - if (errorResponse.Request.Method != request.Method) return false; - if (errorResponse.Request.Url == request.Url && - HasMatchingBody(errorResponse, request)) - { - return true; - } - - // check if the URL contains a wildcard - // if it doesn't, it's not a match for the current request for sure - if (!errorResponse.Request.Url.Contains('*')) - { - return false; - } - - // turn mock URL with wildcard into a regex and match against the request URL - var errorResponseUrlRegex = Regex.Escape(errorResponse.Request.Url).Replace("\\*", ".*"); - return Regex.IsMatch(request.Url, $"^{errorResponseUrlRegex}$") && - HasMatchingBody(errorResponse, request); - }); - - return errorResponse; - } - - private static bool HasMatchingBody(GenericErrorResponse errorResponse, Request request) - { - if (request.Method == "GET") - { - // GET requests don't have a body so we can't match on it - return true; - } - - if (errorResponse.Request?.BodyFragment is null) - { - // no body fragment to match on - return true; - } - - if (!request.HasBody || string.IsNullOrEmpty(request.BodyString)) - { - // error response defines a body fragment but the request has no body - // so it can't match - return false; - } - - return request.BodyString.Contains(errorResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); - } - - private void UpdateProxyResponse(ProxyRequestArgs e, GenericErrorResponseResponse error) - { - SessionEventArgs session = e.Session; - Request request = session.HttpClient.Request; - var headers = new List(); - if (error.Headers is not null) - { - headers.AddRange(error.Headers); - } - - if (error.StatusCode == (int)HttpStatusCode.TooManyRequests && - error.Headers is not null && - error.Headers.FirstOrDefault(h => h.Name == "Retry-After" || h.Name == "retry-after")?.Value == "@dynamic") - { - var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); - if (!e.GlobalData.ContainsKey(RetryAfterPlugin.ThrottledRequestsKey)) - { - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, new List()); - } - var throttledRequests = e.GlobalData[RetryAfterPlugin.ThrottledRequestsKey] as List; - throttledRequests?.Add(new ThrottlerInfo(BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); - // replace the header with the @dynamic value with the actual value - var h = headers.First(h => h.Name == "Retry-After" || h.Name == "retry-after"); - headers.Remove(h); - headers.Add(new("Retry-After", _configuration.RetryAfterInSeconds.ToString())); - } - - var statusCode = (HttpStatusCode)(error.StatusCode ?? 400); - var body = error.Body is null ? string.Empty : JsonSerializer.Serialize(error.Body, ProxyUtils.JsonSerializerOptions); - // we get a JSON string so need to start with the opening quote - if (body.StartsWith("\"@")) - { - // we've got a mock body starting with @-token which means we're sending - // a response from a file on disk - // if we can read the file, we can immediately send the response and - // skip the rest of the logic in this method - // remove the surrounding quotes and the @-token - var filePath = Path.Combine(Path.GetDirectoryName(_configuration.ErrorsFile) ?? "", ProxyUtils.ReplacePathTokens(body.Trim('"').Substring(1))); - if (!File.Exists(filePath)) - { - Logger.LogError("File {filePath} not found. Serving file path in the mock response", (string?)filePath); - session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - else - { - var bodyBytes = File.ReadAllBytes(filePath); - session.GenericResponse(bodyBytes, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - } - else - { - session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - e.ResponseState.HasBeenSet = true; - Logger.LogRequest([$"{error.StatusCode} {statusCode}"], MessageType.Chaos, new LoggingContext(e.Session)); - } - - // throttle requests per host - private static string BuildThrottleKey(Request r) => r.RequestUri.Host; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - _configuration.ErrorsFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.ErrorsFile ?? string.Empty), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); - - _loader = new GenericErrorResponsesLoader(Logger, _configuration); - - PluginEvents.Init += OnInit; - PluginEvents.BeforeRequest += OnRequestAsync; - } - - private void OnInit(object? sender, InitArgs e) - { - _loader?.InitResponsesWatcher(); - } - - private Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - if (!e.ResponseState.HasBeenSet - && UrlsToWatch is not null - && e.ShouldExecute(UrlsToWatch)) - { - var failMode = ShouldFail(); - - if (failMode == GenericRandomErrorFailMode.PassThru && Context.Configuration?.Rate != 100) - { - return Task.CompletedTask; - } - FailResponse(e); - } - - return Task.CompletedTask; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using System.Net; +using System.Text.Json; +using Titanium.Web.Proxy.EventArguments; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; +using Microsoft.DevProxy.Plugins.Behavior; +using Microsoft.Extensions.Logging; +using System.Text.RegularExpressions; + +namespace Microsoft.DevProxy.Plugins.RandomErrors; +internal enum GenericRandomErrorFailMode +{ + Throttled, + Random, + PassThru +} + +public class GenericRandomErrorConfiguration +{ + public string? ErrorsFile { get; set; } + public int RetryAfterInSeconds { get; set; } = 5; + public IEnumerable Errors { get; set; } = []; +} + +public class GenericRandomErrorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private readonly GenericRandomErrorConfiguration _configuration = new(); + private GenericErrorResponsesLoader? _loader = null; + + public override string Name => nameof(GenericRandomErrorPlugin); + + private readonly Random _random = new(); + + // uses config to determine if a request should be failed + private GenericRandomErrorFailMode ShouldFail() => _random.Next(1, 100) <= Context.Configuration.Rate ? GenericRandomErrorFailMode.Random : GenericRandomErrorFailMode.PassThru; + + private void FailResponse(ProxyRequestArgs e) + { + var matchingResponse = GetMatchingErrorResponse(e.Session.HttpClient.Request); + if (matchingResponse is not null && + matchingResponse.Responses is not null) + { + // pick a random error response for the current request + var error = matchingResponse.Responses.ElementAt(_random.Next(0, matchingResponse.Responses.Length)); + UpdateProxyResponse(e, error); + } + else + { + Logger.LogRequest("No matching error response found", MessageType.Skipped, new LoggingContext(e.Session)); + } + } + + private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + { + var throttleKeyForRequest = BuildThrottleKey(request); + return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? _configuration.RetryAfterInSeconds : 0, "Retry-After"); + } + + private GenericErrorResponse? GetMatchingErrorResponse(Request request) + { + if (_configuration.Errors is null || + !_configuration.Errors.Any()) + { + return null; + } + + var errorResponse = _configuration.Errors.FirstOrDefault(errorResponse => + { + if (errorResponse.Request is null) return false; + if (errorResponse.Responses is null) return false; + + if (errorResponse.Request.Method != request.Method) return false; + if (errorResponse.Request.Url == request.Url && + HasMatchingBody(errorResponse, request)) + { + return true; + } + + // check if the URL contains a wildcard + // if it doesn't, it's not a match for the current request for sure + if (!errorResponse.Request.Url.Contains('*')) + { + return false; + } + + // turn mock URL with wildcard into a regex and match against the request URL + var errorResponseUrlRegex = Regex.Escape(errorResponse.Request.Url).Replace("\\*", ".*"); + return Regex.IsMatch(request.Url, $"^{errorResponseUrlRegex}$") && + HasMatchingBody(errorResponse, request); + }); + + return errorResponse; + } + + private static bool HasMatchingBody(GenericErrorResponse errorResponse, Request request) + { + if (request.Method == "GET") + { + // GET requests don't have a body so we can't match on it + return true; + } + + if (errorResponse.Request?.BodyFragment is null) + { + // no body fragment to match on + return true; + } + + if (!request.HasBody || string.IsNullOrEmpty(request.BodyString)) + { + // error response defines a body fragment but the request has no body + // so it can't match + return false; + } + + return request.BodyString.Contains(errorResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); + } + + private void UpdateProxyResponse(ProxyRequestArgs e, GenericErrorResponseResponse error) + { + SessionEventArgs session = e.Session; + Request request = session.HttpClient.Request; + var headers = new List(); + if (error.Headers is not null) + { + headers.AddRange(error.Headers); + } + + if (error.StatusCode == (int)HttpStatusCode.TooManyRequests && + error.Headers is not null && + error.Headers.FirstOrDefault(h => h.Name == "Retry-After" || h.Name == "retry-after")?.Value == "@dynamic") + { + var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); + if (!e.GlobalData.ContainsKey(RetryAfterPlugin.ThrottledRequestsKey)) + { + e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, new List()); + } + var throttledRequests = e.GlobalData[RetryAfterPlugin.ThrottledRequestsKey] as List; + throttledRequests?.Add(new ThrottlerInfo(BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); + // replace the header with the @dynamic value with the actual value + var h = headers.First(h => h.Name == "Retry-After" || h.Name == "retry-after"); + headers.Remove(h); + headers.Add(new("Retry-After", _configuration.RetryAfterInSeconds.ToString())); + } + + var statusCode = (HttpStatusCode)(error.StatusCode ?? 400); + var body = error.Body is null ? string.Empty : JsonSerializer.Serialize(error.Body, ProxyUtils.JsonSerializerOptions); + // we get a JSON string so need to start with the opening quote + if (body.StartsWith("\"@")) + { + // we've got a mock body starting with @-token which means we're sending + // a response from a file on disk + // if we can read the file, we can immediately send the response and + // skip the rest of the logic in this method + // remove the surrounding quotes and the @-token + var filePath = Path.Combine(Path.GetDirectoryName(_configuration.ErrorsFile) ?? "", ProxyUtils.ReplacePathTokens(body.Trim('"').Substring(1))); + if (!File.Exists(filePath)) + { + Logger.LogError("File {filePath} not found. Serving file path in the mock response", (string?)filePath); + session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + } + else + { + var bodyBytes = File.ReadAllBytes(filePath); + session.GenericResponse(bodyBytes, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + } + } + else + { + session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + } + e.ResponseState.HasBeenSet = true; + Logger.LogRequest($"{error.StatusCode} {statusCode}", MessageType.Chaos, new LoggingContext(e.Session)); + } + + // throttle requests per host + private static string BuildThrottleKey(Request r) => r.RequestUri.Host; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + _configuration.ErrorsFile = Path.GetFullPath(ProxyUtils.ReplacePathTokens(_configuration.ErrorsFile ?? string.Empty), Path.GetDirectoryName(Context.Configuration.ConfigFile ?? string.Empty) ?? string.Empty); + + _loader = new GenericErrorResponsesLoader(Logger, _configuration); + + PluginEvents.Init += OnInit; + PluginEvents.BeforeRequest += OnRequestAsync; + } + + private void OnInit(object? sender, InitArgs e) + { + _loader?.InitResponsesWatcher(); + } + + private Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + if (e.ResponseState.HasBeenSet) + { + Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + var failMode = ShouldFail(); + + if (failMode == GenericRandomErrorFailMode.PassThru && Context.Configuration?.Rate != 100) + { + Logger.LogRequest("Pass through", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + FailResponse(e); + + return Task.CompletedTask; + } +} diff --git a/dev-proxy-plugins/RandomErrors/GraphRandomErrorPlugin.cs b/dev-proxy-plugins/RandomErrors/GraphRandomErrorPlugin.cs index 3068e62d..3f804077 100644 --- a/dev-proxy-plugins/RandomErrors/GraphRandomErrorPlugin.cs +++ b/dev-proxy-plugins/RandomErrors/GraphRandomErrorPlugin.cs @@ -1,280 +1,288 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using System.CommandLine; -using System.CommandLine.Invocation; -using System.Net; -using System.Text.Json; -using System.Text.RegularExpressions; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; -using Microsoft.DevProxy.Plugins.Behavior; - -namespace Microsoft.DevProxy.Plugins.RandomErrors; -internal enum GraphRandomErrorFailMode -{ - Random, - PassThru -} - -public class GraphRandomErrorConfiguration -{ - public List AllowedErrors { get; set; } = []; - public int RetryAfterInSeconds { get; set; } = 5; -} - -public class GraphRandomErrorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private static readonly string _allowedErrorsOptionName = "--allowed-errors"; - private readonly GraphRandomErrorConfiguration _configuration = new(); - - public override string Name => nameof(GraphRandomErrorPlugin); - - private readonly Dictionary _methodStatusCode = new() - { - { - "GET", new[] { - HttpStatusCode.TooManyRequests, - HttpStatusCode.InternalServerError, - HttpStatusCode.BadGateway, - HttpStatusCode.ServiceUnavailable, - HttpStatusCode.GatewayTimeout - } - }, - { - "POST", new[] { - HttpStatusCode.TooManyRequests, - HttpStatusCode.InternalServerError, - HttpStatusCode.BadGateway, - HttpStatusCode.ServiceUnavailable, - HttpStatusCode.GatewayTimeout, - HttpStatusCode.InsufficientStorage - } - }, - { - "PUT", new[] { - HttpStatusCode.TooManyRequests, - HttpStatusCode.InternalServerError, - HttpStatusCode.BadGateway, - HttpStatusCode.ServiceUnavailable, - HttpStatusCode.GatewayTimeout, - HttpStatusCode.InsufficientStorage - } - }, - { - "PATCH", new[] { - HttpStatusCode.TooManyRequests, - HttpStatusCode.InternalServerError, - HttpStatusCode.BadGateway, - HttpStatusCode.ServiceUnavailable, - HttpStatusCode.GatewayTimeout - } - }, - { - "DELETE", new[] { - HttpStatusCode.TooManyRequests, - HttpStatusCode.InternalServerError, - HttpStatusCode.BadGateway, - HttpStatusCode.ServiceUnavailable, - HttpStatusCode.GatewayTimeout, - HttpStatusCode.InsufficientStorage - } - } - }; - private readonly Random _random = new(); - - // uses config to determine if a request should be failed - private GraphRandomErrorFailMode ShouldFail(ProxyRequestArgs e) => _random.Next(1, 100) <= Context.Configuration.Rate ? GraphRandomErrorFailMode.Random : GraphRandomErrorFailMode.PassThru; - - private void FailResponse(ProxyRequestArgs e) - { - // pick a random error response for the current request method - var methodStatusCodes = _methodStatusCode[e.Session.HttpClient.Request.Method ?? "GET"]; - var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; - UpdateProxyResponse(e, errorStatus); - } - - private void FailBatch(ProxyRequestArgs e) - { - var batchResponse = new GraphBatchResponsePayload(); - - var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); - if (batch == null) - { - UpdateProxyBatchResponse(e, batchResponse); - return; - } - - var responses = new List(); - foreach (var request in batch.Requests) - { - try - { - // pick a random error response for the current request method - var methodStatusCodes = _methodStatusCode[request.Method]; - var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; - - var response = new GraphBatchResponsePayloadResponse - { - Id = request.Id, - Status = (int)errorStatus, - Body = new GraphBatchResponsePayloadResponseBody - { - Error = new GraphBatchResponsePayloadResponseBodyError - { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = "Some error was generated by the proxy.", - } - } - }; - - if (errorStatus == HttpStatusCode.TooManyRequests) - { - var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); - var requestUrl = ProxyUtils.GetAbsoluteRequestUrlFromBatch(e.Session.HttpClient.Request.RequestUri, request.Url); - var throttledRequests = e.GlobalData[RetryAfterPlugin.ThrottledRequestsKey] as List; - throttledRequests?.Add(new ThrottlerInfo(GraphUtils.BuildThrottleKey(requestUrl), ShouldThrottle, retryAfterDate)); - response.Headers = new Dictionary { { "Retry-After", _configuration.RetryAfterInSeconds.ToString() } }; - } - - responses.Add(response); - } - catch { } - } - batchResponse.Responses = [.. responses]; - - UpdateProxyBatchResponse(e, batchResponse); - } - - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) - { - var throttleKeyForRequest = GraphUtils.BuildThrottleKey(request); - return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? _configuration.RetryAfterInSeconds : 0, "Retry-After"); - } - - private void UpdateProxyResponse(ProxyRequestArgs e, HttpStatusCode errorStatus) - { - SessionEventArgs session = e.Session; - string requestId = Guid.NewGuid().ToString(); - string requestDate = DateTime.Now.ToString(); - Request request = session.HttpClient.Request; - var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); - if (errorStatus == HttpStatusCode.TooManyRequests) - { - var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) - { - value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); - } - - var throttledRequests = value as List; - throttledRequests?.Add(new ThrottlerInfo(GraphUtils.BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); - headers.Add(new("Retry-After", _configuration.RetryAfterInSeconds.ToString())); - } - - string body = JsonSerializer.Serialize(new GraphErrorResponseBody( - new GraphErrorResponseError - { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = BuildApiErrorMessage(request), - InnerError = new GraphErrorResponseInnerError - { - RequestId = requestId, - Date = requestDate - } - }), - ProxyUtils.JsonSerializerOptions - ); - Logger.LogRequest([$"{(int)errorStatus} {errorStatus}"], MessageType.Chaos, new LoggingContext(e.Session)); - session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - - private void UpdateProxyBatchResponse(ProxyRequestArgs ev, GraphBatchResponsePayload response) - { - // failed batch uses a fixed 424 error status code - var errorStatus = HttpStatusCode.FailedDependency; - - SessionEventArgs session = ev.Session; - string requestId = Guid.NewGuid().ToString(); - string requestDate = DateTime.Now.ToString(); - Request request = session.HttpClient.Request; - var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); - - string body = JsonSerializer.Serialize(response, ProxyUtils.JsonSerializerOptions); - Logger.LogRequest([$"{(int)errorStatus} {errorStatus}"], MessageType.Chaos, new LoggingContext(ev.Session)); - session.GenericResponse(body, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - - private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : String.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; - - public override Option[] GetOptions() - { - var _allowedErrors = new Option>(_allowedErrorsOptionName, "List of errors that Dev Proxy may produce") - { - ArgumentHelpName = "allowed errors", - AllowMultipleArgumentsPerToken = true - }; - _allowedErrors.AddAlias("-a"); - - return [_allowedErrors]; - } - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - PluginEvents.OptionsLoaded += OnOptionsLoaded; - PluginEvents.BeforeRequest += OnRequestAsync; - } - - private void OnOptionsLoaded(object? sender, OptionsLoadedArgs e) - { - InvocationContext context = e.Context; - - // Configure the allowed errors - var allowedErrors = context.ParseResult.GetValueForOption?>(_allowedErrorsOptionName, e.Options); - if (allowedErrors?.Any() ?? false) - _configuration.AllowedErrors = allowedErrors.ToList(); - - if (_configuration.AllowedErrors.Count != 0) - { - foreach (string k in _methodStatusCode.Keys) - { - _methodStatusCode[k] = _methodStatusCode[k].Where(e => _configuration.AllowedErrors.Any(a => (int)e == a)).ToArray(); - } - } - } - - private Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - var state = e.ResponseState; - if (!e.ResponseState.HasBeenSet - && UrlsToWatch is not null - && e.ShouldExecute(UrlsToWatch)) - { - var failMode = ShouldFail(e); - - if (failMode == GraphRandomErrorFailMode.PassThru && Context.Configuration.Rate != 100) - { - return Task.CompletedTask; - } - if (ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) - { - FailBatch(e); - } - else - { - FailResponse(e); - } - state.HasBeenSet = true; - } - - return Task.CompletedTask; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Net; +using System.Text.Json; +using System.Text.RegularExpressions; +using Titanium.Web.Proxy.EventArguments; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; +using Microsoft.DevProxy.Plugins.Behavior; + +namespace Microsoft.DevProxy.Plugins.RandomErrors; +internal enum GraphRandomErrorFailMode +{ + Random, + PassThru +} + +public class GraphRandomErrorConfiguration +{ + public List AllowedErrors { get; set; } = []; + public int RetryAfterInSeconds { get; set; } = 5; +} + +public class GraphRandomErrorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private static readonly string _allowedErrorsOptionName = "--allowed-errors"; + private readonly GraphRandomErrorConfiguration _configuration = new(); + + public override string Name => nameof(GraphRandomErrorPlugin); + + private readonly Dictionary _methodStatusCode = new() + { + { + "GET", new[] { + HttpStatusCode.TooManyRequests, + HttpStatusCode.InternalServerError, + HttpStatusCode.BadGateway, + HttpStatusCode.ServiceUnavailable, + HttpStatusCode.GatewayTimeout + } + }, + { + "POST", new[] { + HttpStatusCode.TooManyRequests, + HttpStatusCode.InternalServerError, + HttpStatusCode.BadGateway, + HttpStatusCode.ServiceUnavailable, + HttpStatusCode.GatewayTimeout, + HttpStatusCode.InsufficientStorage + } + }, + { + "PUT", new[] { + HttpStatusCode.TooManyRequests, + HttpStatusCode.InternalServerError, + HttpStatusCode.BadGateway, + HttpStatusCode.ServiceUnavailable, + HttpStatusCode.GatewayTimeout, + HttpStatusCode.InsufficientStorage + } + }, + { + "PATCH", new[] { + HttpStatusCode.TooManyRequests, + HttpStatusCode.InternalServerError, + HttpStatusCode.BadGateway, + HttpStatusCode.ServiceUnavailable, + HttpStatusCode.GatewayTimeout + } + }, + { + "DELETE", new[] { + HttpStatusCode.TooManyRequests, + HttpStatusCode.InternalServerError, + HttpStatusCode.BadGateway, + HttpStatusCode.ServiceUnavailable, + HttpStatusCode.GatewayTimeout, + HttpStatusCode.InsufficientStorage + } + } + }; + private readonly Random _random = new(); + + // uses config to determine if a request should be failed + private GraphRandomErrorFailMode ShouldFail(ProxyRequestArgs e) => _random.Next(1, 100) <= Context.Configuration.Rate ? GraphRandomErrorFailMode.Random : GraphRandomErrorFailMode.PassThru; + + private void FailResponse(ProxyRequestArgs e) + { + // pick a random error response for the current request method + var methodStatusCodes = _methodStatusCode[e.Session.HttpClient.Request.Method ?? "GET"]; + var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; + UpdateProxyResponse(e, errorStatus); + } + + private void FailBatch(ProxyRequestArgs e) + { + var batchResponse = new GraphBatchResponsePayload(); + + var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); + if (batch == null) + { + UpdateProxyBatchResponse(e, batchResponse); + return; + } + + var responses = new List(); + foreach (var request in batch.Requests) + { + try + { + // pick a random error response for the current request method + var methodStatusCodes = _methodStatusCode[request.Method]; + var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; + + var response = new GraphBatchResponsePayloadResponse + { + Id = request.Id, + Status = (int)errorStatus, + Body = new GraphBatchResponsePayloadResponseBody + { + Error = new GraphBatchResponsePayloadResponseBodyError + { + Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = "Some error was generated by the proxy.", + } + } + }; + + if (errorStatus == HttpStatusCode.TooManyRequests) + { + var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); + var requestUrl = ProxyUtils.GetAbsoluteRequestUrlFromBatch(e.Session.HttpClient.Request.RequestUri, request.Url); + var throttledRequests = e.GlobalData[RetryAfterPlugin.ThrottledRequestsKey] as List; + throttledRequests?.Add(new ThrottlerInfo(GraphUtils.BuildThrottleKey(requestUrl), ShouldThrottle, retryAfterDate)); + response.Headers = new Dictionary { { "Retry-After", _configuration.RetryAfterInSeconds.ToString() } }; + } + + responses.Add(response); + } + catch { } + } + batchResponse.Responses = [.. responses]; + + UpdateProxyBatchResponse(e, batchResponse); + } + + private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + { + var throttleKeyForRequest = GraphUtils.BuildThrottleKey(request); + return new ThrottlingInfo(throttleKeyForRequest == throttlingKey ? _configuration.RetryAfterInSeconds : 0, "Retry-After"); + } + + private void UpdateProxyResponse(ProxyRequestArgs e, HttpStatusCode errorStatus) + { + SessionEventArgs session = e.Session; + string requestId = Guid.NewGuid().ToString(); + string requestDate = DateTime.Now.ToString(); + Request request = session.HttpClient.Request; + var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); + if (errorStatus == HttpStatusCode.TooManyRequests) + { + var retryAfterDate = DateTime.Now.AddSeconds(_configuration.RetryAfterInSeconds); + if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out object? value)) + { + value = new List(); + e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + } + + var throttledRequests = value as List; + throttledRequests?.Add(new ThrottlerInfo(GraphUtils.BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); + headers.Add(new("Retry-After", _configuration.RetryAfterInSeconds.ToString())); + } + + string body = JsonSerializer.Serialize(new GraphErrorResponseBody( + new GraphErrorResponseError + { + Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = BuildApiErrorMessage(request), + InnerError = new GraphErrorResponseInnerError + { + RequestId = requestId, + Date = requestDate + } + }), + ProxyUtils.JsonSerializerOptions + ); + Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new LoggingContext(e.Session)); + session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); + } + + private void UpdateProxyBatchResponse(ProxyRequestArgs ev, GraphBatchResponsePayload response) + { + // failed batch uses a fixed 424 error status code + var errorStatus = HttpStatusCode.FailedDependency; + + SessionEventArgs session = ev.Session; + string requestId = Guid.NewGuid().ToString(); + string requestDate = DateTime.Now.ToString(); + Request request = session.HttpClient.Request; + var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); + + string body = JsonSerializer.Serialize(response, ProxyUtils.JsonSerializerOptions); + Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new LoggingContext(ev.Session)); + session.GenericResponse(body, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); + } + + private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : String.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage(r)) : "")}"; + + public override Option[] GetOptions() + { + var _allowedErrors = new Option>(_allowedErrorsOptionName, "List of errors that Dev Proxy may produce") + { + ArgumentHelpName = "allowed errors", + AllowMultipleArgumentsPerToken = true + }; + _allowedErrors.AddAlias("-a"); + + return [_allowedErrors]; + } + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + PluginEvents.OptionsLoaded += OnOptionsLoaded; + PluginEvents.BeforeRequest += OnRequestAsync; + } + + private void OnOptionsLoaded(object? sender, OptionsLoadedArgs e) + { + InvocationContext context = e.Context; + + // Configure the allowed errors + var allowedErrors = context.ParseResult.GetValueForOption?>(_allowedErrorsOptionName, e.Options); + if (allowedErrors?.Any() ?? false) + _configuration.AllowedErrors = allowedErrors.ToList(); + + if (_configuration.AllowedErrors.Count != 0) + { + foreach (string k in _methodStatusCode.Keys) + { + _methodStatusCode[k] = _methodStatusCode[k].Where(e => _configuration.AllowedErrors.Any(a => (int)e == a)).ToArray(); + } + } + } + + private Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + var state = e.ResponseState; + if (state.HasBeenSet) + { + Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + + + var failMode = ShouldFail(e); + if (failMode == GraphRandomErrorFailMode.PassThru && Context.Configuration.Rate != 100) + { + Logger.LogRequest("Pass through", MessageType.Skipped, new LoggingContext(e.Session)); + return Task.CompletedTask; + } + if (ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) + { + FailBatch(e); + } + else + { + FailResponse(e); + } + state.HasBeenSet = true; + + return Task.CompletedTask; + } +} diff --git a/dev-proxy-plugins/RandomErrors/LatencyPlugin.cs b/dev-proxy-plugins/RandomErrors/LatencyPlugin.cs index 91de4cff..74526220 100644 --- a/dev-proxy-plugins/RandomErrors/LatencyPlugin.cs +++ b/dev-proxy-plugins/RandomErrors/LatencyPlugin.cs @@ -1,41 +1,44 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; - -namespace Microsoft.DevProxy.Plugins.RandomErrors; - -public class LatencyConfiguration -{ - public int MinMs { get; set; } = 0; - public int MaxMs { get; set; } = 5000; -} - -public class LatencyPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private readonly LatencyConfiguration _configuration = new(); - - public override string Name => nameof(LatencyPlugin); - private readonly Random _random = new(); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - PluginEvents.BeforeRequest += OnRequestAsync; - } - - private async Task OnRequestAsync(object? sender, ProxyRequestArgs e) - { - if (UrlsToWatch is not null - && e.ShouldExecute(UrlsToWatch)) - { - var delay = _random.Next(_configuration.MinMs, _configuration.MaxMs); - Logger.LogRequest([$"Delaying request for {delay}ms"], MessageType.Chaos, new LoggingContext(e.Session)); - await Task.Delay(delay); - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; + +namespace Microsoft.DevProxy.Plugins.RandomErrors; + +public class LatencyConfiguration +{ + public int MinMs { get; set; } = 0; + public int MaxMs { get; set; } = 5000; +} + +public class LatencyPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseProxyPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private readonly LatencyConfiguration _configuration = new(); + + public override string Name => nameof(LatencyPlugin); + private readonly Random _random = new(); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + PluginEvents.BeforeRequest += OnRequestAsync; + } + + private async Task OnRequestAsync(object? sender, ProxyRequestArgs e) + { + if (UrlsToWatch is null || + !e.ShouldExecute(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); + return; + } + + var delay = _random.Next(_configuration.MinMs, _configuration.MaxMs); + Logger.LogRequest($"Delaying request for {delay}ms", MessageType.Chaos, new LoggingContext(e.Session)); + await Task.Delay(delay); + } +} diff --git a/dev-proxy-plugins/RequestLogs/ApiCenterMinimalPermissionsPlugin.cs b/dev-proxy-plugins/RequestLogs/ApiCenterMinimalPermissionsPlugin.cs index b2628f91..3d295a40 100644 --- a/dev-proxy-plugins/RequestLogs/ApiCenterMinimalPermissionsPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/ApiCenterMinimalPermissionsPlugin.cs @@ -1,239 +1,239 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.ApiCenter; -using Microsoft.DevProxy.Plugins.MinimalPermissions; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.OpenApi.Models; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public class ApiCenterMinimalPermissionsPluginReportApiResult -{ - public required string ApiId { get; init; } - public required string ApiName { get; init; } - public required string ApiDefinitionId { get; init; } - public required string[] Requests { get; init; } - public required string[] TokenPermissions { get; init; } - public required string[] MinimalPermissions { get; init; } - public required string[] ExcessivePermissions { get; init; } - public required bool UsesMinimalPermissions { get; init; } -} - -public class ApiCenterMinimalPermissionsPluginReport -{ - public required ApiCenterMinimalPermissionsPluginReportApiResult[] Results { get; init; } - public required string[] UnmatchedRequests { get; init; } - public required ApiPermissionError[] Errors { get; init; } -} - -internal class ApiCenterMinimalPermissionsPluginConfiguration -{ - public string SubscriptionId { get; set; } = ""; - public string ResourceGroupName { get; set; } = ""; - public string ServiceName { get; set; } = ""; - public string WorkspaceName { get; set; } = "default"; -} - -public class ApiCenterMinimalPermissionsPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private readonly ApiCenterProductionVersionPluginConfiguration _configuration = new(); - private ApiCenterClient? _apiCenterClient; - private Api[]? _apis; - private Dictionary? _apiDefinitionsByUrl; - - public override string Name => nameof(ApiCenterMinimalPermissionsPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - try - { - _apiCenterClient = new( - new() - { - SubscriptionId = _configuration.SubscriptionId, - ResourceGroupName = _configuration.ResourceGroupName, - ServiceName = _configuration.ServiceName, - WorkspaceName = _configuration.WorkspaceName - }, - Logger - ); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); - return; - } - - Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); - try - { - _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); - return; - } - Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) - { - var interceptedRequests = e.RequestLogs - .Where(l => - l.MessageType == MessageType.InterceptedRequest && - !l.MessageLines.First().StartsWith("OPTIONS") && - l.Context?.Session is not null && - l.Context.Session.HttpClient.Request.Headers.Any(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)) - ); - if (!interceptedRequests.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - Logger.LogInformation("Checking if recorded API requests use minimal permissions as defined in API Center..."); - - Debug.Assert(_apiCenterClient is not null); - - _apis ??= await _apiCenterClient.GetApisAsync(); - if (_apis is null || _apis.Length == 0) - { - Logger.LogInformation("No APIs found in API Center"); - return; - } - - // get all API definitions by URL so that we can easily match - // API requests to API definitions, for permissions lookup - _apiDefinitionsByUrl ??= await _apis.GetApiDefinitionsByUrlAsync(_apiCenterClient, Logger); - - var (requestsByApiDefinition, unmatchedApicRequests) = GetRequestsByApiDefinition(interceptedRequests, _apiDefinitionsByUrl); - - var errors = new List(); - var results = new List(); - var unmatchedRequests = new List( - unmatchedApicRequests.Select(r => r.MessageLines.First()) - ); - - foreach (var (apiDefinition, requests) in requestsByApiDefinition) - { - var minimalPermissions = CheckMinimalPermissions(requests, apiDefinition); - - var api = _apis.FindApiByDefinition(apiDefinition, Logger); - var result = new ApiCenterMinimalPermissionsPluginReportApiResult - { - ApiId = api?.Id ?? "unknown", - ApiName = api?.Properties?.Title ?? "unknown", - ApiDefinitionId = apiDefinition.Id!, - Requests = minimalPermissions.OperationsFromRequests - .Select(o => $"{o.Method} {o.OriginalUrl}") - .Distinct() - .ToArray(), - TokenPermissions = minimalPermissions.TokenPermissions.Distinct().ToArray(), - MinimalPermissions = minimalPermissions.MinimalScopes, - ExcessivePermissions = minimalPermissions.TokenPermissions.Except(minimalPermissions.MinimalScopes).ToArray(), - UsesMinimalPermissions = !minimalPermissions.TokenPermissions.Except(minimalPermissions.MinimalScopes).Any() - }; - results.Add(result); - - var unmatchedApiRequests = minimalPermissions.OperationsFromRequests - .Where(o => minimalPermissions.UnmatchedOperations.Contains($"{o.Method} {o.TokenizedUrl}")) - .Select(o => $"{o.Method} {o.OriginalUrl}"); - unmatchedRequests.AddRange(unmatchedApiRequests); - errors.AddRange(minimalPermissions.Errors); - - if (result.UsesMinimalPermissions) - { - Logger.LogInformation( - "API {apiName} is called with minimal permissions: {minimalPermissions}", - result.ApiName, - string.Join(", ", result.MinimalPermissions) - ); - } - else - { - Logger.LogWarning( - "Calling API {apiName} with excessive permissions: {excessivePermissions}. Minimal permissions are: {minimalPermissions}", - result.ApiName, - string.Join(", ", result.ExcessivePermissions), - string.Join(", ", result.MinimalPermissions) - ); - } - - if (unmatchedApiRequests.Any()) - { - Logger.LogWarning( - "Unmatched requests for API {apiName}:{newLine}- {unmatchedRequests}", - result.ApiName, - Environment.NewLine, - string.Join($"{Environment.NewLine}- ", unmatchedApiRequests) - ); - } - - if (minimalPermissions.Errors.Count != 0) - { - Logger.LogWarning( - "Errors for API {apiName}:{newLine}- {errors}", - result.ApiName, - Environment.NewLine, - string.Join($"{Environment.NewLine}- ", minimalPermissions.Errors.Select(e => $"{e.Request}: {e.Error}")) - ); - } - } - - var report = new ApiCenterMinimalPermissionsPluginReport() - { - Results = [.. results], - UnmatchedRequests = [.. unmatchedRequests], - Errors = [.. errors] - }; - - StoreReport(report, e); - } - - private ApiPermissionsInfo CheckMinimalPermissions(IEnumerable requests, ApiDefinition apiDefinition) - { - Logger.LogInformation("Checking minimal permissions for API {apiName}...", apiDefinition.Definition!.Servers.First().Url); - - return apiDefinition.Definition.CheckMinimalPermissions(requests, Logger); - } - - private (Dictionary> RequestsByApiDefinition, IEnumerable UnmatchedRequests) GetRequestsByApiDefinition(IEnumerable interceptedRequests, Dictionary apiDefinitionsByUrl) - { - var unmatchedRequests = new List(); - var requestsByApiDefinition = new Dictionary>(); - foreach (var request in interceptedRequests) - { - var url = request.MessageLines.First().Split(' ')[1]; - Logger.LogDebug("Matching request {requestUrl} to API definitions...", url); - - var matchingKey = apiDefinitionsByUrl.Keys.FirstOrDefault(url.StartsWith); - if (matchingKey is null) - { - Logger.LogDebug("No matching API definition found for {requestUrl}", url); - unmatchedRequests.Add(request); - continue; - } - - if (!requestsByApiDefinition.TryGetValue(apiDefinitionsByUrl[matchingKey], out List? value)) - { - value = []; - requestsByApiDefinition[apiDefinitionsByUrl[matchingKey]] = value; - } - - value.Add(request); - } - - return (requestsByApiDefinition, unmatchedRequests); - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.ApiCenter; +using Microsoft.DevProxy.Plugins.MinimalPermissions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.OpenApi.Models; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public class ApiCenterMinimalPermissionsPluginReportApiResult +{ + public required string ApiId { get; init; } + public required string ApiName { get; init; } + public required string ApiDefinitionId { get; init; } + public required string[] Requests { get; init; } + public required string[] TokenPermissions { get; init; } + public required string[] MinimalPermissions { get; init; } + public required string[] ExcessivePermissions { get; init; } + public required bool UsesMinimalPermissions { get; init; } +} + +public class ApiCenterMinimalPermissionsPluginReport +{ + public required ApiCenterMinimalPermissionsPluginReportApiResult[] Results { get; init; } + public required string[] UnmatchedRequests { get; init; } + public required ApiPermissionError[] Errors { get; init; } +} + +internal class ApiCenterMinimalPermissionsPluginConfiguration +{ + public string SubscriptionId { get; set; } = ""; + public string ResourceGroupName { get; set; } = ""; + public string ServiceName { get; set; } = ""; + public string WorkspaceName { get; set; } = "default"; +} + +public class ApiCenterMinimalPermissionsPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private readonly ApiCenterProductionVersionPluginConfiguration _configuration = new(); + private ApiCenterClient? _apiCenterClient; + private Api[]? _apis; + private Dictionary? _apiDefinitionsByUrl; + + public override string Name => nameof(ApiCenterMinimalPermissionsPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + try + { + _apiCenterClient = new( + new() + { + SubscriptionId = _configuration.SubscriptionId, + ResourceGroupName = _configuration.ResourceGroupName, + ServiceName = _configuration.ServiceName, + WorkspaceName = _configuration.WorkspaceName + }, + Logger + ); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); + return; + } + + Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); + try + { + _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); + return; + } + Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) + { + var interceptedRequests = e.RequestLogs + .Where(l => + l.MessageType == MessageType.InterceptedRequest && + !l.Message.StartsWith("OPTIONS") && + l.Context?.Session is not null && + l.Context.Session.HttpClient.Request.Headers.Any(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)) + ); + if (!interceptedRequests.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + Logger.LogInformation("Checking if recorded API requests use minimal permissions as defined in API Center..."); + + Debug.Assert(_apiCenterClient is not null); + + _apis ??= await _apiCenterClient.GetApisAsync(); + if (_apis is null || _apis.Length == 0) + { + Logger.LogInformation("No APIs found in API Center"); + return; + } + + // get all API definitions by URL so that we can easily match + // API requests to API definitions, for permissions lookup + _apiDefinitionsByUrl ??= await _apis.GetApiDefinitionsByUrlAsync(_apiCenterClient, Logger); + + var (requestsByApiDefinition, unmatchedApicRequests) = GetRequestsByApiDefinition(interceptedRequests, _apiDefinitionsByUrl); + + var errors = new List(); + var results = new List(); + var unmatchedRequests = new List( + unmatchedApicRequests.Select(r => r.Message) + ); + + foreach (var (apiDefinition, requests) in requestsByApiDefinition) + { + var minimalPermissions = CheckMinimalPermissions(requests, apiDefinition); + + var api = _apis.FindApiByDefinition(apiDefinition, Logger); + var result = new ApiCenterMinimalPermissionsPluginReportApiResult + { + ApiId = api?.Id ?? "unknown", + ApiName = api?.Properties?.Title ?? "unknown", + ApiDefinitionId = apiDefinition.Id!, + Requests = minimalPermissions.OperationsFromRequests + .Select(o => $"{o.Method} {o.OriginalUrl}") + .Distinct() + .ToArray(), + TokenPermissions = minimalPermissions.TokenPermissions.Distinct().ToArray(), + MinimalPermissions = minimalPermissions.MinimalScopes, + ExcessivePermissions = minimalPermissions.TokenPermissions.Except(minimalPermissions.MinimalScopes).ToArray(), + UsesMinimalPermissions = !minimalPermissions.TokenPermissions.Except(minimalPermissions.MinimalScopes).Any() + }; + results.Add(result); + + var unmatchedApiRequests = minimalPermissions.OperationsFromRequests + .Where(o => minimalPermissions.UnmatchedOperations.Contains($"{o.Method} {o.TokenizedUrl}")) + .Select(o => $"{o.Method} {o.OriginalUrl}"); + unmatchedRequests.AddRange(unmatchedApiRequests); + errors.AddRange(minimalPermissions.Errors); + + if (result.UsesMinimalPermissions) + { + Logger.LogInformation( + "API {apiName} is called with minimal permissions: {minimalPermissions}", + result.ApiName, + string.Join(", ", result.MinimalPermissions) + ); + } + else + { + Logger.LogWarning( + "Calling API {apiName} with excessive permissions: {excessivePermissions}. Minimal permissions are: {minimalPermissions}", + result.ApiName, + string.Join(", ", result.ExcessivePermissions), + string.Join(", ", result.MinimalPermissions) + ); + } + + if (unmatchedApiRequests.Any()) + { + Logger.LogWarning( + "Unmatched requests for API {apiName}:{newLine}- {unmatchedRequests}", + result.ApiName, + Environment.NewLine, + string.Join($"{Environment.NewLine}- ", unmatchedApiRequests) + ); + } + + if (minimalPermissions.Errors.Count != 0) + { + Logger.LogWarning( + "Errors for API {apiName}:{newLine}- {errors}", + result.ApiName, + Environment.NewLine, + string.Join($"{Environment.NewLine}- ", minimalPermissions.Errors.Select(e => $"{e.Request}: {e.Error}")) + ); + } + } + + var report = new ApiCenterMinimalPermissionsPluginReport() + { + Results = [.. results], + UnmatchedRequests = [.. unmatchedRequests], + Errors = [.. errors] + }; + + StoreReport(report, e); + } + + private ApiPermissionsInfo CheckMinimalPermissions(IEnumerable requests, ApiDefinition apiDefinition) + { + Logger.LogInformation("Checking minimal permissions for API {apiName}...", apiDefinition.Definition!.Servers.First().Url); + + return apiDefinition.Definition.CheckMinimalPermissions(requests, Logger); + } + + private (Dictionary> RequestsByApiDefinition, IEnumerable UnmatchedRequests) GetRequestsByApiDefinition(IEnumerable interceptedRequests, Dictionary apiDefinitionsByUrl) + { + var unmatchedRequests = new List(); + var requestsByApiDefinition = new Dictionary>(); + foreach (var request in interceptedRequests) + { + var url = request.Message.Split(' ')[1]; + Logger.LogDebug("Matching request {requestUrl} to API definitions...", url); + + var matchingKey = apiDefinitionsByUrl.Keys.FirstOrDefault(url.StartsWith); + if (matchingKey is null) + { + Logger.LogDebug("No matching API definition found for {requestUrl}", url); + unmatchedRequests.Add(request); + continue; + } + + if (!requestsByApiDefinition.TryGetValue(apiDefinitionsByUrl[matchingKey], out List? value)) + { + value = []; + requestsByApiDefinition[apiDefinitionsByUrl[matchingKey]] = value; + } + + value.Add(request); + } + + return (requestsByApiDefinition, unmatchedRequests); + } } \ No newline at end of file diff --git a/dev-proxy-plugins/RequestLogs/ApiCenterOnboardingPlugin.cs b/dev-proxy-plugins/RequestLogs/ApiCenterOnboardingPlugin.cs index ea3f5eb3..e6cbaf5c 100644 --- a/dev-proxy-plugins/RequestLogs/ApiCenterOnboardingPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/ApiCenterOnboardingPlugin.cs @@ -1,389 +1,389 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.ApiCenter; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.OpenApi.Models; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public class ApiCenterOnboardingPluginReportExistingApiInfo -{ - public required string MethodAndUrl { get; init; } - public required string ApiDefinitionId { get; init; } - public required string OperationId { get; init; } -} - -public class ApiCenterOnboardingPluginReportNewApiInfo -{ - public required string Method { get; init; } - public required string Url { get; init; } -} - -public class ApiCenterOnboardingPluginReport -{ - public required ApiCenterOnboardingPluginReportExistingApiInfo[] ExistingApis { get; init; } - public required ApiCenterOnboardingPluginReportNewApiInfo[] NewApis { get; init; } -} - -internal class ApiCenterOnboardingPluginConfiguration -{ - public string SubscriptionId { get; set; } = ""; - public string ResourceGroupName { get; set; } = ""; - public string ServiceName { get; set; } = ""; - public string WorkspaceName { get; set; } = "default"; - public bool CreateApicEntryForNewApis { get; set; } = true; -} - -public class ApiCenterOnboardingPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private readonly ApiCenterOnboardingPluginConfiguration _configuration = new(); - private ApiCenterClient? _apiCenterClient; - private Api[]? _apis; - private Dictionary? _apiDefinitionsByUrl; - - public override string Name => nameof(ApiCenterOnboardingPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - try - { - _apiCenterClient = new( - new() - { - SubscriptionId = _configuration.SubscriptionId, - ResourceGroupName = _configuration.ResourceGroupName, - ServiceName = _configuration.ServiceName, - WorkspaceName = _configuration.WorkspaceName - }, - Logger - ); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); - return; - } - - Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); - try - { - _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); - return; - } - Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) - { - if (!e.RequestLogs.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - Logger.LogInformation("Checking if recorded API requests belong to APIs in API Center..."); - - Debug.Assert(_apiCenterClient is not null); - - _apis ??= await _apiCenterClient.GetApisAsync(); - - if (_apis == null || _apis.Length == 0) - { - Logger.LogInformation("No APIs found in API Center"); - return; - } - - _apiDefinitionsByUrl ??= await _apis.GetApiDefinitionsByUrlAsync(_apiCenterClient, Logger); - - var newApis = new List<(string method, string url)>(); - var interceptedRequests = e.RequestLogs - .Where(l => l.MessageType == MessageType.InterceptedRequest) - .Select(request => - { - var methodAndUrl = request.MessageLines.First().Split(' '); - return (method: methodAndUrl[0], url: methodAndUrl[1]); - }) - .Where(r => !r.method.Equals("OPTIONS", StringComparison.OrdinalIgnoreCase)) - .Distinct(); - - var existingApis = new List(); - - foreach (var request in interceptedRequests) - { - var (method, url) = request; - - Logger.LogDebug("Processing request {method} {url}...", method, url); - - var apiDefinition = _apiDefinitionsByUrl.FirstOrDefault(x => - url.StartsWith(x.Key, StringComparison.OrdinalIgnoreCase)).Value; - if (apiDefinition is null || - apiDefinition.Id is null) - { - Logger.LogDebug("No matching API definition not found for {url}. Adding new API...", url); - newApis.Add((method, url)); - continue; - } - - await apiDefinition.LoadOpenApiDefinitionAsync(_apiCenterClient, Logger); - - if (apiDefinition.Definition is null) - { - Logger.LogDebug("API definition not found for {url} so nothing to compare to. Adding new API...", url); - newApis.Add(new(method, url)); - continue; - } - - var pathItem = apiDefinition.Definition.FindMatchingPathItem(url, Logger); - if (pathItem is null) - { - Logger.LogDebug("No matching path found for {url}. Adding new API...", url); - newApis.Add(new(method, url)); - continue; - } - - var operation = pathItem.Value.Value.Operations.FirstOrDefault(x => - x.Key.ToString().Equals(method, StringComparison.OrdinalIgnoreCase)).Value; - if (operation is null) - { - Logger.LogDebug("No matching operation found for {method} {url}. Adding new API...", method, url); - newApis.Add(new(method, url)); - continue; - } - - existingApis.Add(new() - { - MethodAndUrl = $"{method} {url}", - ApiDefinitionId = apiDefinition.Id, - OperationId = operation.OperationId - }); - } - - if (newApis.Count == 0) - { - Logger.LogInformation("No new APIs found"); - StoreReport(new ApiCenterOnboardingPluginReport - { - ExistingApis = existingApis.ToArray(), - NewApis = [] - }, e); - return; - } - - // dedupe newApis - newApis = newApis.Distinct().ToList(); - - StoreReport(new ApiCenterOnboardingPluginReport - { - ExistingApis = [.. existingApis], - NewApis = newApis.Select(a => new ApiCenterOnboardingPluginReportNewApiInfo - { - Method = a.method, - Url = a.url - }).ToArray() - }, e); - - var apisPerSchemeAndHost = newApis.GroupBy(x => - { - var u = new Uri(x.url); - return u.GetLeftPart(UriPartial.Authority); - }); - - var newApisMessageChunks = new List(["New APIs that aren't registered in Azure API Center:", ""]); - foreach (var apiPerHost in apisPerSchemeAndHost) - { - newApisMessageChunks.Add($"{apiPerHost.Key}:"); - newApisMessageChunks.AddRange(apiPerHost.Select(a => $" {a.method} {a.url}")); - } - - Logger.LogInformation(string.Join(Environment.NewLine, newApisMessageChunks)); - - if (!_configuration.CreateApicEntryForNewApis) - { - return; - } - - var generatedOpenApiSpecs = e.GlobalData.TryGetValue(OpenApiSpecGeneratorPlugin.GeneratedOpenApiSpecsKey, out var specs) ? specs as Dictionary : new(); - await CreateApisInApiCenterAsync(apisPerSchemeAndHost, generatedOpenApiSpecs!); - } - - async Task CreateApisInApiCenterAsync(IEnumerable> apisPerHost, Dictionary generatedOpenApiSpecs) - { - Logger.LogInformation("Creating new API entries in API Center..."); - - foreach (var apiPerHost in apisPerHost) - { - var schemeAndHost = apiPerHost.Key; - - var api = await CreateApiAsync(schemeAndHost, apiPerHost); - if (api is null) - { - continue; - } - - Debug.Assert(api.Id is not null); - - if (!generatedOpenApiSpecs.TryGetValue(schemeAndHost, out var openApiSpecFilePath)) - { - Logger.LogDebug("No OpenAPI spec found for {host}", schemeAndHost); - continue; - } - - var apiVersion = await CreateApiVersionAsync(api.Id); - if (apiVersion is null) - { - continue; - } - - Debug.Assert(apiVersion.Id is not null); - - var apiDefinition = await CreateApiDefinitionAsync(apiVersion.Id); - if (apiDefinition is null) - { - continue; - } - - Debug.Assert(apiDefinition.Id is not null); - - await ImportApiDefinitionAsync(apiDefinition.Id, openApiSpecFilePath); - } - } - - async Task CreateApiAsync(string schemeAndHost, IEnumerable<(string method, string url)> apiRequests) - { - Debug.Assert(_apiCenterClient is not null); - - // trim to 50 chars which is max length for API name - var apiName = $"new-{schemeAndHost.Replace(".", "-").Replace("http://", "").Replace("https://", "")}-{DateTimeOffset.UtcNow.ToUnixTimeSeconds()}".MaxLength(50); - Logger.LogInformation(" Creating API {apiName} for {host}...", apiName, schemeAndHost); - - var title = $"New APIs: {schemeAndHost}"; - var description = new List(["New APIs discovered by Dev Proxy", ""]); - description.AddRange(apiRequests.Select(a => $" {a.method} {a.url}").ToArray()); - var api = new Api - { - Properties = new() - { - Title = title, - Description = string.Join(Environment.NewLine, description), - Kind = ApiKind.REST - } - }; - - var newApi = await _apiCenterClient.PutApiAsync(api, apiName); - if (newApi is not null) - { - Logger.LogDebug("API created successfully"); - } - else - { - Logger.LogError("Failed to create API {apiName} for {host}", apiName, schemeAndHost); - } - - return newApi; - } - - async Task CreateApiVersionAsync(string apiId) - { - Debug.Assert(_apiCenterClient is not null); - - Logger.LogDebug(" Creating API version for {api}...", apiId); - - var apiVersion = new ApiVersion - { - Properties = new() - { - Title = "v1.0", - LifecycleStage = ApiLifecycleStage.Production - } - }; - - var newApiVersion = await _apiCenterClient.PutVersionAsync(apiVersion, apiId, "v1-0"); - if (newApiVersion is not null) - { - Logger.LogDebug("API version created successfully"); - } - else - { - Logger.LogError("Failed to create API version for {api}", apiId.Substring(apiId.LastIndexOf('/'))); - } - - return newApiVersion; - } - - async Task CreateApiDefinitionAsync(string apiVersionId) - { - Debug.Assert(_apiCenterClient is not null); - - Logger.LogDebug(" Creating API definition for {api}...", apiVersionId); - - var apiDefinition = new ApiDefinition - { - Properties = new() - { - Title = "OpenAPI" - } - }; - var newApiDefinition = await _apiCenterClient.PutDefinitionAsync(apiDefinition, apiVersionId, "openapi"); - if (newApiDefinition is not null) - { - Logger.LogDebug("API definition created successfully"); - } - else - { - Logger.LogError("Failed to create API definition for {apiVersion}", apiVersionId); - } - - return newApiDefinition; - } - - async Task ImportApiDefinitionAsync(string apiDefinitionId, string openApiSpecFilePath) - { - Debug.Assert(_apiCenterClient is not null); - - Logger.LogDebug(" Importing API definition for {api}...", apiDefinitionId); - - var openApiSpec = File.ReadAllText(openApiSpecFilePath); - var apiSpecImport = new ApiSpecImport - { - Format = ApiSpecImportResultFormat.Inline, - Value = openApiSpec, - Specification = new() - { - Name = "openapi", - Version = "3.0.1" - } - }; - var res = await _apiCenterClient.PostImportSpecificationAsync(apiSpecImport, apiDefinitionId); - if (res.IsSuccessStatusCode) - { - Logger.LogDebug("API definition imported successfully"); - } - else - { - var resContent = res.ReasonPhrase; - try - { - resContent = await res.Content.ReadAsStringAsync(); - } - catch - { - } - - Logger.LogError("Failed to import API definition for {apiDefinition}. Status: {status}, reason: {reason}", apiDefinitionId, res.StatusCode, resContent); - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.ApiCenter; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.OpenApi.Models; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public class ApiCenterOnboardingPluginReportExistingApiInfo +{ + public required string MethodAndUrl { get; init; } + public required string ApiDefinitionId { get; init; } + public required string OperationId { get; init; } +} + +public class ApiCenterOnboardingPluginReportNewApiInfo +{ + public required string Method { get; init; } + public required string Url { get; init; } +} + +public class ApiCenterOnboardingPluginReport +{ + public required ApiCenterOnboardingPluginReportExistingApiInfo[] ExistingApis { get; init; } + public required ApiCenterOnboardingPluginReportNewApiInfo[] NewApis { get; init; } +} + +internal class ApiCenterOnboardingPluginConfiguration +{ + public string SubscriptionId { get; set; } = ""; + public string ResourceGroupName { get; set; } = ""; + public string ServiceName { get; set; } = ""; + public string WorkspaceName { get; set; } = "default"; + public bool CreateApicEntryForNewApis { get; set; } = true; +} + +public class ApiCenterOnboardingPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private readonly ApiCenterOnboardingPluginConfiguration _configuration = new(); + private ApiCenterClient? _apiCenterClient; + private Api[]? _apis; + private Dictionary? _apiDefinitionsByUrl; + + public override string Name => nameof(ApiCenterOnboardingPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + try + { + _apiCenterClient = new( + new() + { + SubscriptionId = _configuration.SubscriptionId, + ResourceGroupName = _configuration.ResourceGroupName, + ServiceName = _configuration.ServiceName, + WorkspaceName = _configuration.WorkspaceName + }, + Logger + ); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); + return; + } + + Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); + try + { + _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); + return; + } + Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) + { + if (!e.RequestLogs.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + Logger.LogInformation("Checking if recorded API requests belong to APIs in API Center..."); + + Debug.Assert(_apiCenterClient is not null); + + _apis ??= await _apiCenterClient.GetApisAsync(); + + if (_apis == null || _apis.Length == 0) + { + Logger.LogInformation("No APIs found in API Center"); + return; + } + + _apiDefinitionsByUrl ??= await _apis.GetApiDefinitionsByUrlAsync(_apiCenterClient, Logger); + + var newApis = new List<(string method, string url)>(); + var interceptedRequests = e.RequestLogs + .Where(l => l.MessageType == MessageType.InterceptedRequest) + .Select(request => + { + var methodAndUrl = request.Message.Split(' '); + return (method: methodAndUrl[0], url: methodAndUrl[1]); + }) + .Where(r => !r.method.Equals("OPTIONS", StringComparison.OrdinalIgnoreCase)) + .Distinct(); + + var existingApis = new List(); + + foreach (var request in interceptedRequests) + { + var (method, url) = request; + + Logger.LogDebug("Processing request {method} {url}...", method, url); + + var apiDefinition = _apiDefinitionsByUrl.FirstOrDefault(x => + url.StartsWith(x.Key, StringComparison.OrdinalIgnoreCase)).Value; + if (apiDefinition is null || + apiDefinition.Id is null) + { + Logger.LogDebug("No matching API definition not found for {url}. Adding new API...", url); + newApis.Add((method, url)); + continue; + } + + await apiDefinition.LoadOpenApiDefinitionAsync(_apiCenterClient, Logger); + + if (apiDefinition.Definition is null) + { + Logger.LogDebug("API definition not found for {url} so nothing to compare to. Adding new API...", url); + newApis.Add(new(method, url)); + continue; + } + + var pathItem = apiDefinition.Definition.FindMatchingPathItem(url, Logger); + if (pathItem is null) + { + Logger.LogDebug("No matching path found for {url}. Adding new API...", url); + newApis.Add(new(method, url)); + continue; + } + + var operation = pathItem.Value.Value.Operations.FirstOrDefault(x => + x.Key.ToString().Equals(method, StringComparison.OrdinalIgnoreCase)).Value; + if (operation is null) + { + Logger.LogDebug("No matching operation found for {method} {url}. Adding new API...", method, url); + newApis.Add(new(method, url)); + continue; + } + + existingApis.Add(new() + { + MethodAndUrl = $"{method} {url}", + ApiDefinitionId = apiDefinition.Id, + OperationId = operation.OperationId + }); + } + + if (newApis.Count == 0) + { + Logger.LogInformation("No new APIs found"); + StoreReport(new ApiCenterOnboardingPluginReport + { + ExistingApis = existingApis.ToArray(), + NewApis = [] + }, e); + return; + } + + // dedupe newApis + newApis = newApis.Distinct().ToList(); + + StoreReport(new ApiCenterOnboardingPluginReport + { + ExistingApis = [.. existingApis], + NewApis = newApis.Select(a => new ApiCenterOnboardingPluginReportNewApiInfo + { + Method = a.method, + Url = a.url + }).ToArray() + }, e); + + var apisPerSchemeAndHost = newApis.GroupBy(x => + { + var u = new Uri(x.url); + return u.GetLeftPart(UriPartial.Authority); + }); + + var newApisMessageChunks = new List(["New APIs that aren't registered in Azure API Center:", ""]); + foreach (var apiPerHost in apisPerSchemeAndHost) + { + newApisMessageChunks.Add($"{apiPerHost.Key}:"); + newApisMessageChunks.AddRange(apiPerHost.Select(a => $" {a.method} {a.url}")); + } + + Logger.LogInformation(string.Join(Environment.NewLine, newApisMessageChunks)); + + if (!_configuration.CreateApicEntryForNewApis) + { + return; + } + + var generatedOpenApiSpecs = e.GlobalData.TryGetValue(OpenApiSpecGeneratorPlugin.GeneratedOpenApiSpecsKey, out var specs) ? specs as Dictionary : new(); + await CreateApisInApiCenterAsync(apisPerSchemeAndHost, generatedOpenApiSpecs!); + } + + async Task CreateApisInApiCenterAsync(IEnumerable> apisPerHost, Dictionary generatedOpenApiSpecs) + { + Logger.LogInformation("Creating new API entries in API Center..."); + + foreach (var apiPerHost in apisPerHost) + { + var schemeAndHost = apiPerHost.Key; + + var api = await CreateApiAsync(schemeAndHost, apiPerHost); + if (api is null) + { + continue; + } + + Debug.Assert(api.Id is not null); + + if (!generatedOpenApiSpecs.TryGetValue(schemeAndHost, out var openApiSpecFilePath)) + { + Logger.LogDebug("No OpenAPI spec found for {host}", schemeAndHost); + continue; + } + + var apiVersion = await CreateApiVersionAsync(api.Id); + if (apiVersion is null) + { + continue; + } + + Debug.Assert(apiVersion.Id is not null); + + var apiDefinition = await CreateApiDefinitionAsync(apiVersion.Id); + if (apiDefinition is null) + { + continue; + } + + Debug.Assert(apiDefinition.Id is not null); + + await ImportApiDefinitionAsync(apiDefinition.Id, openApiSpecFilePath); + } + } + + async Task CreateApiAsync(string schemeAndHost, IEnumerable<(string method, string url)> apiRequests) + { + Debug.Assert(_apiCenterClient is not null); + + // trim to 50 chars which is max length for API name + var apiName = $"new-{schemeAndHost.Replace(".", "-").Replace("http://", "").Replace("https://", "")}-{DateTimeOffset.UtcNow.ToUnixTimeSeconds()}".MaxLength(50); + Logger.LogInformation(" Creating API {apiName} for {host}...", apiName, schemeAndHost); + + var title = $"New APIs: {schemeAndHost}"; + var description = new List(["New APIs discovered by Dev Proxy", ""]); + description.AddRange(apiRequests.Select(a => $" {a.method} {a.url}").ToArray()); + var api = new Api + { + Properties = new() + { + Title = title, + Description = string.Join(Environment.NewLine, description), + Kind = ApiKind.REST + } + }; + + var newApi = await _apiCenterClient.PutApiAsync(api, apiName); + if (newApi is not null) + { + Logger.LogDebug("API created successfully"); + } + else + { + Logger.LogError("Failed to create API {apiName} for {host}", apiName, schemeAndHost); + } + + return newApi; + } + + async Task CreateApiVersionAsync(string apiId) + { + Debug.Assert(_apiCenterClient is not null); + + Logger.LogDebug(" Creating API version for {api}...", apiId); + + var apiVersion = new ApiVersion + { + Properties = new() + { + Title = "v1.0", + LifecycleStage = ApiLifecycleStage.Production + } + }; + + var newApiVersion = await _apiCenterClient.PutVersionAsync(apiVersion, apiId, "v1-0"); + if (newApiVersion is not null) + { + Logger.LogDebug("API version created successfully"); + } + else + { + Logger.LogError("Failed to create API version for {api}", apiId.Substring(apiId.LastIndexOf('/'))); + } + + return newApiVersion; + } + + async Task CreateApiDefinitionAsync(string apiVersionId) + { + Debug.Assert(_apiCenterClient is not null); + + Logger.LogDebug(" Creating API definition for {api}...", apiVersionId); + + var apiDefinition = new ApiDefinition + { + Properties = new() + { + Title = "OpenAPI" + } + }; + var newApiDefinition = await _apiCenterClient.PutDefinitionAsync(apiDefinition, apiVersionId, "openapi"); + if (newApiDefinition is not null) + { + Logger.LogDebug("API definition created successfully"); + } + else + { + Logger.LogError("Failed to create API definition for {apiVersion}", apiVersionId); + } + + return newApiDefinition; + } + + async Task ImportApiDefinitionAsync(string apiDefinitionId, string openApiSpecFilePath) + { + Debug.Assert(_apiCenterClient is not null); + + Logger.LogDebug(" Importing API definition for {api}...", apiDefinitionId); + + var openApiSpec = File.ReadAllText(openApiSpecFilePath); + var apiSpecImport = new ApiSpecImport + { + Format = ApiSpecImportResultFormat.Inline, + Value = openApiSpec, + Specification = new() + { + Name = "openapi", + Version = "3.0.1" + } + }; + var res = await _apiCenterClient.PostImportSpecificationAsync(apiSpecImport, apiDefinitionId); + if (res.IsSuccessStatusCode) + { + Logger.LogDebug("API definition imported successfully"); + } + else + { + var resContent = res.ReasonPhrase; + try + { + resContent = await res.Content.ReadAsStringAsync(); + } + catch + { + } + + Logger.LogError("Failed to import API definition for {apiDefinition}. Status: {status}, reason: {reason}", apiDefinitionId, res.StatusCode, resContent); + } + } +} diff --git a/dev-proxy-plugins/RequestLogs/ApiCenterProductionVersionPlugin.cs b/dev-proxy-plugins/RequestLogs/ApiCenterProductionVersionPlugin.cs index cf267d46..a8a4fe05 100644 --- a/dev-proxy-plugins/RequestLogs/ApiCenterProductionVersionPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/ApiCenterProductionVersionPlugin.cs @@ -1,236 +1,236 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using System.Text.Json.Serialization; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.ApiCenter; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public enum ApiCenterProductionVersionPluginReportItemStatus -{ - NotRegistered, - NonProduction, - Production -} - -public class ApiCenterProductionVersionPluginReportItem -{ - public required string Method { get; init; } - public required string Url { get; init; } - [JsonConverter(typeof(JsonStringEnumConverter))] - public required ApiCenterProductionVersionPluginReportItemStatus Status { get; init; } - public string? Recommendation { get; init; } -} - -public class ApiCenterProductionVersionPluginReport : List; - -internal class ApiCenterProductionVersionPluginConfiguration -{ - public string SubscriptionId { get; set; } = ""; - public string ResourceGroupName { get; set; } = ""; - public string ServiceName { get; set; } = ""; - public string WorkspaceName { get; set; } = "default"; -} - -public class ApiCenterProductionVersionPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - private readonly ApiCenterProductionVersionPluginConfiguration _configuration = new(); - private ApiCenterClient? _apiCenterClient; - private Api[]? _apis; - - public override string Name => nameof(ApiCenterProductionVersionPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - try - { - _apiCenterClient = new( - new() - { - SubscriptionId = _configuration.SubscriptionId, - ResourceGroupName = _configuration.ResourceGroupName, - ServiceName = _configuration.ServiceName, - WorkspaceName = _configuration.WorkspaceName - }, - Logger - ); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); - return; - } - - Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); - try - { - _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); - return; - } - Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) - { - var interceptedRequests = e.RequestLogs - .Where( - l => l.MessageType == MessageType.InterceptedRequest && - l.Context?.Session is not null - ); - if (!interceptedRequests.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - Logger.LogInformation("Checking if recorded API requests use production APIs as defined in API Center..."); - - Debug.Assert(_apiCenterClient is not null); - - _apis ??= await _apiCenterClient.GetApisAsync(); - - if (_apis == null || _apis.Length == 0) - { - Logger.LogInformation("No APIs found in API Center"); - return; - } - - foreach (var api in _apis) - { - Debug.Assert(api.Id is not null); - - await api.LoadVersionsAsync(_apiCenterClient); - if (api.Versions == null || api.Versions.Length == 0) - { - Logger.LogInformation("No versions found for {api}", api.Properties?.Title); - continue; - } - - foreach (var versionFromApiCenter in api.Versions) - { - Debug.Assert(versionFromApiCenter.Id is not null); - - await versionFromApiCenter.LoadDefinitionsAsync(_apiCenterClient); - if (versionFromApiCenter.Definitions == null || - versionFromApiCenter.Definitions.Length == 0) - { - Logger.LogDebug("No definitions found for version {versionId}", versionFromApiCenter.Id); - continue; - } - - var definitions = new List(); - foreach (var definitionFromApiCenter in versionFromApiCenter.Definitions) - { - Debug.Assert(definitionFromApiCenter.Id is not null); - - await definitionFromApiCenter.LoadOpenApiDefinitionAsync(_apiCenterClient, Logger); - - if (definitionFromApiCenter.Definition is null) - { - Logger.LogDebug("API definition not found for {definitionId}", definitionFromApiCenter.Id); - continue; - } - - if (!definitionFromApiCenter.Definition.Servers.Any()) - { - Logger.LogDebug("No servers found for API definition {definitionId}", definitionFromApiCenter.Id); - continue; - } - - definitions.Add(definitionFromApiCenter); - } - - versionFromApiCenter.Definitions = [.. definitions]; - } - } - - Logger.LogInformation("Analyzing recorded requests..."); - - var report = new ApiCenterProductionVersionPluginReport(); - - foreach (var request in interceptedRequests) - { - var methodAndUrlString = request.MessageLines.First(); - var methodAndUrl = methodAndUrlString.Split(' '); - var (method, url) = (methodAndUrl[0], methodAndUrl[1]); - if (method == "OPTIONS") - { - continue; - } - - var api = _apis.FindApiByUrl(url, Logger); - if (api == null) - { - report.Add(new() - { - Method = method, - Url = url, - Status = ApiCenterProductionVersionPluginReportItemStatus.NotRegistered - }); - continue; - } - - var version = api.GetVersion(request, url, Logger); - if (version is null) - { - report.Add(new() - { - Method = method, - Url = url, - Status = ApiCenterProductionVersionPluginReportItemStatus.NotRegistered - }); - continue; - } - - Debug.Assert(version.Properties is not null); - var lifecycleStage = version.Properties.LifecycleStage; - - if (lifecycleStage != ApiLifecycleStage.Production) - { - Debug.Assert(api.Versions is not null); - - var productionVersions = api.Versions - .Where(v => v.Properties?.LifecycleStage == ApiLifecycleStage.Production) - .Select(v => v.Properties?.Title); - - var recommendation = productionVersions.Any() ? - string.Format("Request {0} uses API version {1} which is defined as {2}. Upgrade to a production version of the API. Recommended versions: {3}", methodAndUrlString, api.Versions.First(v => v.Properties?.LifecycleStage == lifecycleStage).Properties?.Title, lifecycleStage, string.Join(", ", productionVersions)) : - string.Format("Request {0} uses API version {1} which is defined as {2}.", methodAndUrlString, api.Versions.First(v => v.Properties?.LifecycleStage == lifecycleStage).Properties?.Title, lifecycleStage); - - Logger.LogWarning(recommendation); - report.Add(new() - { - Method = method, - Url = url, - Status = ApiCenterProductionVersionPluginReportItemStatus.NonProduction, - Recommendation = recommendation - }); - } - else - { - report.Add(new() - { - Method = method, - Url = url, - Status = ApiCenterProductionVersionPluginReportItemStatus.Production - }); - } - } - - StoreReport(report, e); - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Text.Json.Serialization; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.ApiCenter; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public enum ApiCenterProductionVersionPluginReportItemStatus +{ + NotRegistered, + NonProduction, + Production +} + +public class ApiCenterProductionVersionPluginReportItem +{ + public required string Method { get; init; } + public required string Url { get; init; } + [JsonConverter(typeof(JsonStringEnumConverter))] + public required ApiCenterProductionVersionPluginReportItemStatus Status { get; init; } + public string? Recommendation { get; init; } +} + +public class ApiCenterProductionVersionPluginReport : List; + +internal class ApiCenterProductionVersionPluginConfiguration +{ + public string SubscriptionId { get; set; } = ""; + public string ResourceGroupName { get; set; } = ""; + public string ServiceName { get; set; } = ""; + public string WorkspaceName { get; set; } = "default"; +} + +public class ApiCenterProductionVersionPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + private readonly ApiCenterProductionVersionPluginConfiguration _configuration = new(); + private ApiCenterClient? _apiCenterClient; + private Api[]? _apis; + + public override string Name => nameof(ApiCenterProductionVersionPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + try + { + _apiCenterClient = new( + new() + { + SubscriptionId = _configuration.SubscriptionId, + ResourceGroupName = _configuration.ResourceGroupName, + ServiceName = _configuration.ServiceName, + WorkspaceName = _configuration.WorkspaceName + }, + Logger + ); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to create API Center client. The {plugin} will not be used.", Name); + return; + } + + Logger.LogInformation("Plugin {plugin} connecting to Azure...", Name); + try + { + _ = await _apiCenterClient.GetAccessTokenAsync(CancellationToken.None); + } + catch (Exception ex) + { + Logger.LogError(ex, "Failed to authenticate with Azure. The {plugin} will not be used.", Name); + return; + } + Logger.LogDebug("Plugin {plugin} auth confirmed...", Name); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) + { + var interceptedRequests = e.RequestLogs + .Where( + l => l.MessageType == MessageType.InterceptedRequest && + l.Context?.Session is not null + ); + if (!interceptedRequests.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + Logger.LogInformation("Checking if recorded API requests use production APIs as defined in API Center..."); + + Debug.Assert(_apiCenterClient is not null); + + _apis ??= await _apiCenterClient.GetApisAsync(); + + if (_apis == null || _apis.Length == 0) + { + Logger.LogInformation("No APIs found in API Center"); + return; + } + + foreach (var api in _apis) + { + Debug.Assert(api.Id is not null); + + await api.LoadVersionsAsync(_apiCenterClient); + if (api.Versions == null || api.Versions.Length == 0) + { + Logger.LogInformation("No versions found for {api}", api.Properties?.Title); + continue; + } + + foreach (var versionFromApiCenter in api.Versions) + { + Debug.Assert(versionFromApiCenter.Id is not null); + + await versionFromApiCenter.LoadDefinitionsAsync(_apiCenterClient); + if (versionFromApiCenter.Definitions == null || + versionFromApiCenter.Definitions.Length == 0) + { + Logger.LogDebug("No definitions found for version {versionId}", versionFromApiCenter.Id); + continue; + } + + var definitions = new List(); + foreach (var definitionFromApiCenter in versionFromApiCenter.Definitions) + { + Debug.Assert(definitionFromApiCenter.Id is not null); + + await definitionFromApiCenter.LoadOpenApiDefinitionAsync(_apiCenterClient, Logger); + + if (definitionFromApiCenter.Definition is null) + { + Logger.LogDebug("API definition not found for {definitionId}", definitionFromApiCenter.Id); + continue; + } + + if (!definitionFromApiCenter.Definition.Servers.Any()) + { + Logger.LogDebug("No servers found for API definition {definitionId}", definitionFromApiCenter.Id); + continue; + } + + definitions.Add(definitionFromApiCenter); + } + + versionFromApiCenter.Definitions = [.. definitions]; + } + } + + Logger.LogInformation("Analyzing recorded requests..."); + + var report = new ApiCenterProductionVersionPluginReport(); + + foreach (var request in interceptedRequests) + { + var methodAndUrlString = request.Message; + var methodAndUrl = methodAndUrlString.Split(' '); + var (method, url) = (methodAndUrl[0], methodAndUrl[1]); + if (method == "OPTIONS") + { + continue; + } + + var api = _apis.FindApiByUrl(url, Logger); + if (api == null) + { + report.Add(new() + { + Method = method, + Url = url, + Status = ApiCenterProductionVersionPluginReportItemStatus.NotRegistered + }); + continue; + } + + var version = api.GetVersion(request, url, Logger); + if (version is null) + { + report.Add(new() + { + Method = method, + Url = url, + Status = ApiCenterProductionVersionPluginReportItemStatus.NotRegistered + }); + continue; + } + + Debug.Assert(version.Properties is not null); + var lifecycleStage = version.Properties.LifecycleStage; + + if (lifecycleStage != ApiLifecycleStage.Production) + { + Debug.Assert(api.Versions is not null); + + var productionVersions = api.Versions + .Where(v => v.Properties?.LifecycleStage == ApiLifecycleStage.Production) + .Select(v => v.Properties?.Title); + + var recommendation = productionVersions.Any() ? + string.Format("Request {0} uses API version {1} which is defined as {2}. Upgrade to a production version of the API. Recommended versions: {3}", methodAndUrlString, api.Versions.First(v => v.Properties?.LifecycleStage == lifecycleStage).Properties?.Title, lifecycleStage, string.Join(", ", productionVersions)) : + string.Format("Request {0} uses API version {1} which is defined as {2}.", methodAndUrlString, api.Versions.First(v => v.Properties?.LifecycleStage == lifecycleStage).Properties?.Title, lifecycleStage); + + Logger.LogWarning(recommendation); + report.Add(new() + { + Method = method, + Url = url, + Status = ApiCenterProductionVersionPluginReportItemStatus.NonProduction, + Recommendation = recommendation + }); + } + else + { + report.Add(new() + { + Method = method, + Url = url, + Status = ApiCenterProductionVersionPluginReportItemStatus.Production + }); + } + } + + StoreReport(report, e); + } } \ No newline at end of file diff --git a/dev-proxy-plugins/RequestLogs/ExecutionSummaryPlugin.cs b/dev-proxy-plugins/RequestLogs/ExecutionSummaryPlugin.cs index b9101ede..aa1cfc94 100644 --- a/dev-proxy-plugins/RequestLogs/ExecutionSummaryPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/ExecutionSummaryPlugin.cs @@ -1,244 +1,244 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using System.CommandLine; -using System.CommandLine.Invocation; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public abstract class ExecutionSummaryPluginReportBase -{ - public required Dictionary>> Data { get; init; } - public required IEnumerable Logs { get; init; } -} - -public class ExecutionSummaryPluginReportByUrl : ExecutionSummaryPluginReportBase; -public class ExecutionSummaryPluginReportByMessageType : ExecutionSummaryPluginReportBase; - -internal enum SummaryGroupBy -{ - Url, - MessageType -} - -internal class ExecutionSummaryPluginConfiguration -{ - public SummaryGroupBy GroupBy { get; set; } = SummaryGroupBy.Url; -} - -public class ExecutionSummaryPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(ExecutionSummaryPlugin); - private readonly ExecutionSummaryPluginConfiguration _configuration = new(); - private static readonly string _groupByOptionName = "--summary-group-by"; - private const string _requestsInterceptedMessage = "Requests intercepted"; - private const string _requestsPassedThroughMessage = "Requests passed through"; - - public override Option[] GetOptions() - { - var groupBy = new Option(_groupByOptionName, "Specifies how the information should be grouped in the summary. Available options: `url` (default), `messageType`.") - { - ArgumentHelpName = "summary-group-by" - }; - groupBy.AddValidator(input => - { - if (!Enum.TryParse(input.Tokens[0].Value, true, out var groupBy)) - { - input.ErrorMessage = $"{input.Tokens[0].Value} is not a valid option to group by. Allowed values are: {string.Join(", ", Enum.GetNames(typeof(SummaryGroupBy)))}"; - } - }); - - return [groupBy]; - } - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - PluginEvents.OptionsLoaded += OnOptionsLoaded; - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private void OnOptionsLoaded(object? sender, OptionsLoadedArgs e) - { - InvocationContext context = e.Context; - - var groupBy = context.ParseResult.GetValueForOption(_groupByOptionName, e.Options); - if (groupBy is not null) - { - _configuration.GroupBy = groupBy.Value; - } - } - - private Task AfterRecordingStopAsync(object? sender, RecordingArgs e) - { - if (!e.RequestLogs.Any()) - { - return Task.CompletedTask; - } - - ExecutionSummaryPluginReportBase report = _configuration.GroupBy switch - { - SummaryGroupBy.Url => new ExecutionSummaryPluginReportByUrl { Data = GetGroupedByUrlData(e.RequestLogs), Logs = e.RequestLogs }, - SummaryGroupBy.MessageType => new ExecutionSummaryPluginReportByMessageType { Data = GetGroupedByMessageTypeData(e.RequestLogs), Logs = e.RequestLogs }, - _ => throw new NotImplementedException() - }; - - StoreReport(report, e); - - return Task.CompletedTask; - } - - // in this method we're producing the follow data structure - // request > message type > (count) message - private Dictionary>> GetGroupedByUrlData(IEnumerable requestLogs) - { - var data = new Dictionary>>(); - - foreach (var log in requestLogs) - { - var message = GetRequestMessage(log); - if (log.MessageType == MessageType.InterceptedResponse) - { - // ignore intercepted response messages - continue; - } - - if (log.MessageType == MessageType.InterceptedRequest) - { - var request = GetMethodAndUrl(log); - if (!data.ContainsKey(request)) - { - data.Add(request, []); - } - - continue; - } - - // last line of the message is the method and URL of the request - var methodAndUrl = GetMethodAndUrl(log); - var readableMessageType = GetReadableMessageTypeForSummary(log.MessageType); - if (!data[methodAndUrl].TryGetValue(readableMessageType, out Dictionary? value)) - { - value = ([]); - data[methodAndUrl].Add(readableMessageType, value); - } - - if (value.TryGetValue(message, out int val)) - { - value[message] = ++val; - } - else - { - value.Add(message, 1); - } - } - - return data; - } - - // in this method we're producing the follow data structure - // message type > message > (count) request - private Dictionary>> GetGroupedByMessageTypeData(IEnumerable requestLogs) - { - var data = new Dictionary>>(); - - foreach (var log in requestLogs) - { - if (log.MessageType == MessageType.InterceptedResponse) - { - // ignore intercepted response messages - continue; - } - - var readableMessageType = GetReadableMessageTypeForSummary(log.MessageType); - if (!data.TryGetValue(readableMessageType, out Dictionary>? value)) - { - value = []; - data.Add(readableMessageType, value); - - if (log.MessageType == MessageType.InterceptedRequest || - log.MessageType == MessageType.PassedThrough) - { - // intercepted and passed through requests don't have - // a sub-grouping so let's repeat the message type - // to keep the same data shape - data[readableMessageType].Add(readableMessageType, []); - } - } - - var message = GetRequestMessage(log); - if (log.MessageType == MessageType.InterceptedRequest || - log.MessageType == MessageType.PassedThrough) - { - // for passed through requests we need to log the URL rather than the - // fixed message - if (log.MessageType == MessageType.PassedThrough) - { - message = GetMethodAndUrl(log); - } - - if (!value[readableMessageType].ContainsKey(message)) - { - value[readableMessageType].Add(message, 1); - } - else - { - value[readableMessageType][message]++; - } - continue; - } - - if (!value.TryGetValue(message, out Dictionary? val)) - { - val = ([]); - value.Add(message, val); - } - var methodAndUrl = GetMethodAndUrl(log); - if (value[message].ContainsKey(methodAndUrl)) - { - value[message][methodAndUrl]++; - } - else - { - value[message].Add(methodAndUrl, 1); - } - } - - return data; - } - - private static string GetRequestMessage(RequestLog requestLog) - { - return string.Join(' ', requestLog.MessageLines); - } - - private static string GetMethodAndUrl(RequestLog requestLog) - { - if (requestLog.Context is not null) - { - return $"{requestLog.Context.Session.HttpClient.Request.Method} {requestLog.Context.Session.HttpClient.Request.RequestUri}"; - } - else - { - return "Undefined"; - } - } - - private static string GetReadableMessageTypeForSummary(MessageType messageType) => messageType switch - { - MessageType.Chaos => "Requests with chaos", - MessageType.Failed => "Failures", - MessageType.InterceptedRequest => _requestsInterceptedMessage, - MessageType.Mocked => "Requests mocked", - MessageType.PassedThrough => _requestsPassedThroughMessage, - MessageType.Tip => "Tips", - MessageType.Warning => "Warnings", - _ => "Unknown" - }; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using System.CommandLine; +using System.CommandLine.Invocation; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public abstract class ExecutionSummaryPluginReportBase +{ + public required Dictionary>> Data { get; init; } + public required IEnumerable Logs { get; init; } +} + +public class ExecutionSummaryPluginReportByUrl : ExecutionSummaryPluginReportBase; +public class ExecutionSummaryPluginReportByMessageType : ExecutionSummaryPluginReportBase; + +internal enum SummaryGroupBy +{ + Url, + MessageType +} + +internal class ExecutionSummaryPluginConfiguration +{ + public SummaryGroupBy GroupBy { get; set; } = SummaryGroupBy.Url; +} + +public class ExecutionSummaryPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(ExecutionSummaryPlugin); + private readonly ExecutionSummaryPluginConfiguration _configuration = new(); + private static readonly string _groupByOptionName = "--summary-group-by"; + private const string _requestsInterceptedMessage = "Requests intercepted"; + private const string _requestsPassedThroughMessage = "Requests passed through"; + + public override Option[] GetOptions() + { + var groupBy = new Option(_groupByOptionName, "Specifies how the information should be grouped in the summary. Available options: `url` (default), `messageType`.") + { + ArgumentHelpName = "summary-group-by" + }; + groupBy.AddValidator(input => + { + if (!Enum.TryParse(input.Tokens[0].Value, true, out var groupBy)) + { + input.ErrorMessage = $"{input.Tokens[0].Value} is not a valid option to group by. Allowed values are: {string.Join(", ", Enum.GetNames(typeof(SummaryGroupBy)))}"; + } + }); + + return [groupBy]; + } + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + PluginEvents.OptionsLoaded += OnOptionsLoaded; + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private void OnOptionsLoaded(object? sender, OptionsLoadedArgs e) + { + InvocationContext context = e.Context; + + var groupBy = context.ParseResult.GetValueForOption(_groupByOptionName, e.Options); + if (groupBy is not null) + { + _configuration.GroupBy = groupBy.Value; + } + } + + private Task AfterRecordingStopAsync(object? sender, RecordingArgs e) + { + if (!e.RequestLogs.Any()) + { + return Task.CompletedTask; + } + + ExecutionSummaryPluginReportBase report = _configuration.GroupBy switch + { + SummaryGroupBy.Url => new ExecutionSummaryPluginReportByUrl { Data = GetGroupedByUrlData(e.RequestLogs), Logs = e.RequestLogs }, + SummaryGroupBy.MessageType => new ExecutionSummaryPluginReportByMessageType { Data = GetGroupedByMessageTypeData(e.RequestLogs), Logs = e.RequestLogs }, + _ => throw new NotImplementedException() + }; + + StoreReport(report, e); + + return Task.CompletedTask; + } + + // in this method we're producing the follow data structure + // request > message type > (count) message + private Dictionary>> GetGroupedByUrlData(IEnumerable requestLogs) + { + var data = new Dictionary>>(); + + foreach (var log in requestLogs) + { + var message = GetRequestMessage(log); + if (log.MessageType == MessageType.InterceptedResponse) + { + // ignore intercepted response messages + continue; + } + + if (log.MessageType == MessageType.InterceptedRequest) + { + var request = GetMethodAndUrl(log); + if (!data.ContainsKey(request)) + { + data.Add(request, []); + } + + continue; + } + + // last line of the message is the method and URL of the request + var methodAndUrl = GetMethodAndUrl(log); + var readableMessageType = GetReadableMessageTypeForSummary(log.MessageType); + if (!data[methodAndUrl].TryGetValue(readableMessageType, out Dictionary? value)) + { + value = ([]); + data[methodAndUrl].Add(readableMessageType, value); + } + + if (value.TryGetValue(message, out int val)) + { + value[message] = ++val; + } + else + { + value.Add(message, 1); + } + } + + return data; + } + + // in this method we're producing the follow data structure + // message type > message > (count) request + private Dictionary>> GetGroupedByMessageTypeData(IEnumerable requestLogs) + { + var data = new Dictionary>>(); + + foreach (var log in requestLogs) + { + if (log.MessageType == MessageType.InterceptedResponse) + { + // ignore intercepted response messages + continue; + } + + var readableMessageType = GetReadableMessageTypeForSummary(log.MessageType); + if (!data.TryGetValue(readableMessageType, out Dictionary>? value)) + { + value = []; + data.Add(readableMessageType, value); + + if (log.MessageType == MessageType.InterceptedRequest || + log.MessageType == MessageType.PassedThrough) + { + // intercepted and passed through requests don't have + // a sub-grouping so let's repeat the message type + // to keep the same data shape + data[readableMessageType].Add(readableMessageType, []); + } + } + + var message = GetRequestMessage(log); + if (log.MessageType == MessageType.InterceptedRequest || + log.MessageType == MessageType.PassedThrough) + { + // for passed through requests we need to log the URL rather than the + // fixed message + if (log.MessageType == MessageType.PassedThrough) + { + message = GetMethodAndUrl(log); + } + + if (!value[readableMessageType].ContainsKey(message)) + { + value[readableMessageType].Add(message, 1); + } + else + { + value[readableMessageType][message]++; + } + continue; + } + + if (!value.TryGetValue(message, out Dictionary? val)) + { + val = ([]); + value.Add(message, val); + } + var methodAndUrl = GetMethodAndUrl(log); + if (value[message].ContainsKey(methodAndUrl)) + { + value[message][methodAndUrl]++; + } + else + { + value[message].Add(methodAndUrl, 1); + } + } + + return data; + } + + private static string GetRequestMessage(RequestLog requestLog) + { + return string.Join(' ', requestLog.Message); + } + + private static string GetMethodAndUrl(RequestLog requestLog) + { + if (requestLog.Context is not null) + { + return $"{requestLog.Context.Session.HttpClient.Request.Method} {requestLog.Context.Session.HttpClient.Request.RequestUri}"; + } + else + { + return "Undefined"; + } + } + + private static string GetReadableMessageTypeForSummary(MessageType messageType) => messageType switch + { + MessageType.Chaos => "Requests with chaos", + MessageType.Failed => "Failures", + MessageType.InterceptedRequest => _requestsInterceptedMessage, + MessageType.Mocked => "Requests mocked", + MessageType.PassedThrough => _requestsPassedThroughMessage, + MessageType.Tip => "Tips", + MessageType.Warning => "Warnings", + _ => "Unknown" + }; +} diff --git a/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsGuidancePlugin.cs b/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsGuidancePlugin.cs index 8dd00a78..239757f2 100644 --- a/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsGuidancePlugin.cs +++ b/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsGuidancePlugin.cs @@ -1,364 +1,364 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions; -using System.IdentityModel.Tokens.Jwt; -using System.Net.Http.Json; -using System.Text.Json; -using Microsoft.DevProxy.Plugins.MinimalPermissions; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public class GraphMinimalPermissionsGuidancePluginReport -{ - public GraphMinimalPermissionsInfo? DelegatedPermissions { get; set; } - public GraphMinimalPermissionsInfo? ApplicationPermissions { get; set; } - public IEnumerable? ExcludedPermissions { get; set; } -} - -public class GraphMinimalPermissionsOperationInfo -{ - public string Method { get; set; } = string.Empty; - public string Endpoint { get; set; } = string.Empty; -} - -public class GraphMinimalPermissionsInfo -{ - public IEnumerable MinimalPermissions { get; set; } = []; - public IEnumerable PermissionsFromTheToken { get; set; } = []; - public IEnumerable ExcessPermissions { get; set; } = []; - public GraphMinimalPermissionsOperationInfo[] Operations { get; set; } = []; -} - -internal class GraphMinimalPermissionsGuidancePluginConfiguration -{ - public IEnumerable? PermissionsToExclude { get; set; } -} - -public class GraphMinimalPermissionsGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(GraphMinimalPermissionsGuidancePlugin); - private readonly GraphMinimalPermissionsGuidancePluginConfiguration _configuration = new(); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - // we need to do it this way because .NET doesn't distinguish between - // an empty array and a null value and we want to be able to tell - // if the user hasn't specified a value and we should use the default - // set or if they have specified an empty array and we shouldn't exclude - // any permissions - if (_configuration.PermissionsToExclude is null) - { - _configuration.PermissionsToExclude = ["profile", "openid", "offline_access", "email"]; - } - else { - // remove empty strings - _configuration.PermissionsToExclude = _configuration.PermissionsToExclude.Where(p => !string.IsNullOrEmpty(p)); - } - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) - { - if (!e.RequestLogs.Any()) - { - return; - } - - var methodAndUrlComparer = new MethodAndUrlComparer(); - var delegatedEndpoints = new List<(string method, string url)>(); - var applicationEndpoints = new List<(string method, string url)>(); - - // scope for delegated permissions - IEnumerable scopesToEvaluate = []; - // roles for application permissions - IEnumerable rolesToEvaluate = []; - - foreach (var request in e.RequestLogs) - { - if (request.MessageType != MessageType.InterceptedRequest) - { - continue; - } - - var methodAndUrlString = request.MessageLines.First(); - var methodAndUrl = GetMethodAndUrl(methodAndUrlString); - if (methodAndUrl.method.Equals("OPTIONS", StringComparison.OrdinalIgnoreCase)) - { - continue; - } - - var requestsFromBatch = Array.Empty<(string method, string url)>(); - - var uri = new Uri(methodAndUrl.url); - if (!ProxyUtils.IsGraphUrl(uri)) - { - continue; - } - - if (ProxyUtils.IsGraphBatchUrl(uri)) - { - var graphVersion = ProxyUtils.IsGraphBetaUrl(uri) ? "beta" : "v1.0"; - requestsFromBatch = GetRequestsFromBatch(request.Context?.Session.HttpClient.Request.BodyString!, graphVersion, uri.Host); - } - else - { - methodAndUrl = (methodAndUrl.method, GetTokenizedUrl(methodAndUrl.url)); - } - - var scopesAndType = GetPermissionsAndType(request); - if (scopesAndType.type == GraphPermissionsType.Delegated) - { - // use the scopes from the last request in case the app is using incremental consent - scopesToEvaluate = scopesAndType.permissions; - - if (ProxyUtils.IsGraphBatchUrl(uri)) - { - delegatedEndpoints.AddRange(requestsFromBatch); - } - else - { - delegatedEndpoints.Add(methodAndUrl); - } - } - else - { - // skip empty roles which are returned in case we couldn't get permissions information - // - // application permissions are always the same because they come from app reg - // so we can just use the first request that has them - if (scopesAndType.permissions.Any() && !rolesToEvaluate.Any()) - { - rolesToEvaluate = scopesAndType.permissions; - - if (ProxyUtils.IsGraphBatchUrl(uri)) - { - applicationEndpoints.AddRange(requestsFromBatch); - } - else - { - applicationEndpoints.Add(methodAndUrl); - } - } - } - } - - // Remove duplicates - delegatedEndpoints = delegatedEndpoints.Distinct(methodAndUrlComparer).ToList(); - applicationEndpoints = applicationEndpoints.Distinct(methodAndUrlComparer).ToList(); - - if (delegatedEndpoints.Count == 0 && applicationEndpoints.Count == 0) - { - return; - } - - var report = new GraphMinimalPermissionsGuidancePluginReport - { - ExcludedPermissions = _configuration.PermissionsToExclude - }; - - Logger.LogWarning("This plugin is in preview and may not return the correct results.\r\nPlease review the permissions and test your app before using them in production.\r\nIf you have any feedback, please open an issue at https://aka.ms/devproxy/issue.\r\n"); - - if (_configuration.PermissionsToExclude is not null && - _configuration.PermissionsToExclude.Any()) - { - Logger.LogInformation("Excluding the following permissions: {permissions}", string.Join(", ", _configuration.PermissionsToExclude)); - } - - if (delegatedEndpoints.Count > 0) - { - var delegatedPermissionsInfo = new GraphMinimalPermissionsInfo(); - report.DelegatedPermissions = delegatedPermissionsInfo; - - Logger.LogInformation("Evaluating delegated permissions for: {endpoints}", string.Join(", ", delegatedEndpoints.Select(e => $"{e.method} {e.url}"))); - - await EvaluateMinimalScopesAsync(delegatedEndpoints, scopesToEvaluate, GraphPermissionsType.Delegated, delegatedPermissionsInfo); - } - - if (applicationEndpoints.Count > 0) - { - var applicationPermissionsInfo = new GraphMinimalPermissionsInfo(); - report.ApplicationPermissions = applicationPermissionsInfo; - - Logger.LogInformation("Evaluating application permissions for: {endpoints}", string.Join(", ", applicationEndpoints.Select(e => $"{e.method} {e.url}"))); - - await EvaluateMinimalScopesAsync(applicationEndpoints, rolesToEvaluate, GraphPermissionsType.Application, applicationPermissionsInfo); - } - - StoreReport(report, e); - } - - private static (string method, string url)[] GetRequestsFromBatch(string batchBody, string graphVersion, string graphHostName) - { - var requests = new List<(string method, string url)>(); - - if (string.IsNullOrEmpty(batchBody)) - { - return [.. requests]; - } - - try - { - var batch = JsonSerializer.Deserialize(batchBody, ProxyUtils.JsonSerializerOptions); - if (batch == null) - { - return [.. requests]; - } - - foreach (var request in batch.Requests) - { - try - { - var method = request.Method; - var url = request.Url; - var absoluteUrl = $"https://{graphHostName}/{graphVersion}{url}"; - requests.Add((method, GetTokenizedUrl(absoluteUrl))); - } - catch { } - } - } - catch { } - - return [.. requests]; - } - - /// - /// Returns permissions and type (delegated or application) from the access token - /// used on the request. - /// If it can't get the permissions, returns PermissionType.Application - /// and an empty array - /// - private static (GraphPermissionsType type, IEnumerable permissions) GetPermissionsAndType(RequestLog request) - { - var authHeader = request.Context?.Session.HttpClient.Request.Headers.GetFirstHeader("Authorization"); - if (authHeader == null) - { - return (GraphPermissionsType.Application, []); - } - - var token = authHeader.Value.Replace("Bearer ", string.Empty); - var tokenChunks = token.Split('.'); - if (tokenChunks.Length != 3) - { - return (GraphPermissionsType.Application, []); - } - - try - { - var handler = new JwtSecurityTokenHandler(); - var jwtSecurityToken = handler.ReadJwtToken(token); - - var scopeClaim = jwtSecurityToken.Claims.FirstOrDefault(c => c.Type == "scp"); - if (scopeClaim == null) - { - // possibly an application token - // roles is an array so we need to handle it differently - var roles = jwtSecurityToken.Claims - .Where(c => c.Type == "roles") - .Select(c => c.Value); - if (!roles.Any()) - { - return (GraphPermissionsType.Application, []); - } - else - { - return (GraphPermissionsType.Application, roles); - } - } - else - { - return (GraphPermissionsType.Delegated, scopeClaim.Value.Split(' ')); - } - } - catch - { - return (GraphPermissionsType.Application, []); - } - } - - private async Task EvaluateMinimalScopesAsync(IEnumerable<(string method, string url)> endpoints, IEnumerable permissionsFromAccessToken, GraphPermissionsType scopeType, GraphMinimalPermissionsInfo permissionsInfo) - { - var payload = endpoints.Select(e => new GraphRequestInfo { Method = e.method, Url = e.url }); - - permissionsInfo.Operations = endpoints.Select(e => new GraphMinimalPermissionsOperationInfo - { - Method = e.method, - Endpoint = e.url - }).ToArray(); - permissionsInfo.PermissionsFromTheToken = permissionsFromAccessToken; - - try - { - var url = $"https://graphexplorerapi.azurewebsites.net/permissions?scopeType={GraphUtils.GetScopeTypeString(scopeType)}"; - using var client = new HttpClient(); - var stringPayload = JsonSerializer.Serialize(payload, ProxyUtils.JsonSerializerOptions); - Logger.LogDebug(string.Format("Calling {0} with payload{1}{2}", url, Environment.NewLine, stringPayload)); - - var response = await client.PostAsJsonAsync(url, payload); - var content = await response.Content.ReadAsStringAsync(); - - Logger.LogDebug(string.Format("Response:{0}{1}", Environment.NewLine, content)); - - var resultsAndErrors = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - var minimalPermissions = resultsAndErrors?.Results?.Select(p => p.Value) ?? []; - var errors = resultsAndErrors?.Errors?.Select(e => $"- {e.Url} ({e.Message})") ?? []; - - if (scopeType == GraphPermissionsType.Delegated) - { - minimalPermissions = await GraphUtils.UpdateUserScopesAsync(minimalPermissions, endpoints, scopeType, Logger); - } - - if (minimalPermissions.Any()) - { - var excessPermissions = permissionsFromAccessToken - .Except(_configuration.PermissionsToExclude ?? []) - .Where(p => !minimalPermissions.Contains(p)); - - permissionsInfo.MinimalPermissions = minimalPermissions; - permissionsInfo.ExcessPermissions = excessPermissions; - - Logger.LogInformation("Minimal permissions: {minimalPermissions}", string.Join(", ", minimalPermissions)); - Logger.LogInformation("Permissions on the token: {tokenPermissions}", string.Join(", ", permissionsFromAccessToken)); - - if (excessPermissions.Any()) - { - Logger.LogWarning("The following permissions are unnecessary: {permissions}", string.Join(", ", excessPermissions)); - } - else - { - Logger.LogInformation("The token has the minimal permissions required."); - } - } - if (errors.Any()) - { - Logger.LogError("Couldn't determine minimal permissions for the following URLs: {errors}", string.Join(", ", errors)); - } - } - catch (Exception ex) - { - Logger.LogError(ex, "An error has occurred while retrieving minimal permissions: {message}", ex.Message); - } - } - - private static (string method, string url) GetMethodAndUrl(string message) - { - var info = message.Split(" "); - if (info.Length > 2) - { - info = [info[0], string.Join(" ", info.Skip(1))]; - } - return (method: info[0], url: info[1]); - } - - private static string GetTokenizedUrl(string absoluteUrl) - { - var sanitizedUrl = ProxyUtils.SanitizeUrl(absoluteUrl); - return "/" + string.Join("", new Uri(sanitizedUrl).Segments.Skip(2).Select(Uri.UnescapeDataString)); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions; +using System.IdentityModel.Tokens.Jwt; +using System.Net.Http.Json; +using System.Text.Json; +using Microsoft.DevProxy.Plugins.MinimalPermissions; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public class GraphMinimalPermissionsGuidancePluginReport +{ + public GraphMinimalPermissionsInfo? DelegatedPermissions { get; set; } + public GraphMinimalPermissionsInfo? ApplicationPermissions { get; set; } + public IEnumerable? ExcludedPermissions { get; set; } +} + +public class GraphMinimalPermissionsOperationInfo +{ + public string Method { get; set; } = string.Empty; + public string Endpoint { get; set; } = string.Empty; +} + +public class GraphMinimalPermissionsInfo +{ + public IEnumerable MinimalPermissions { get; set; } = []; + public IEnumerable PermissionsFromTheToken { get; set; } = []; + public IEnumerable ExcessPermissions { get; set; } = []; + public GraphMinimalPermissionsOperationInfo[] Operations { get; set; } = []; +} + +internal class GraphMinimalPermissionsGuidancePluginConfiguration +{ + public IEnumerable? PermissionsToExclude { get; set; } +} + +public class GraphMinimalPermissionsGuidancePlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(GraphMinimalPermissionsGuidancePlugin); + private readonly GraphMinimalPermissionsGuidancePluginConfiguration _configuration = new(); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + // we need to do it this way because .NET doesn't distinguish between + // an empty array and a null value and we want to be able to tell + // if the user hasn't specified a value and we should use the default + // set or if they have specified an empty array and we shouldn't exclude + // any permissions + if (_configuration.PermissionsToExclude is null) + { + _configuration.PermissionsToExclude = ["profile", "openid", "offline_access", "email"]; + } + else { + // remove empty strings + _configuration.PermissionsToExclude = _configuration.PermissionsToExclude.Where(p => !string.IsNullOrEmpty(p)); + } + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) + { + if (!e.RequestLogs.Any()) + { + return; + } + + var methodAndUrlComparer = new MethodAndUrlComparer(); + var delegatedEndpoints = new List<(string method, string url)>(); + var applicationEndpoints = new List<(string method, string url)>(); + + // scope for delegated permissions + IEnumerable scopesToEvaluate = []; + // roles for application permissions + IEnumerable rolesToEvaluate = []; + + foreach (var request in e.RequestLogs) + { + if (request.MessageType != MessageType.InterceptedRequest) + { + continue; + } + + var methodAndUrlString = request.Message; + var methodAndUrl = GetMethodAndUrl(methodAndUrlString); + if (methodAndUrl.method.Equals("OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + var requestsFromBatch = Array.Empty<(string method, string url)>(); + + var uri = new Uri(methodAndUrl.url); + if (!ProxyUtils.IsGraphUrl(uri)) + { + continue; + } + + if (ProxyUtils.IsGraphBatchUrl(uri)) + { + var graphVersion = ProxyUtils.IsGraphBetaUrl(uri) ? "beta" : "v1.0"; + requestsFromBatch = GetRequestsFromBatch(request.Context?.Session.HttpClient.Request.BodyString!, graphVersion, uri.Host); + } + else + { + methodAndUrl = (methodAndUrl.method, GetTokenizedUrl(methodAndUrl.url)); + } + + var scopesAndType = GetPermissionsAndType(request); + if (scopesAndType.type == GraphPermissionsType.Delegated) + { + // use the scopes from the last request in case the app is using incremental consent + scopesToEvaluate = scopesAndType.permissions; + + if (ProxyUtils.IsGraphBatchUrl(uri)) + { + delegatedEndpoints.AddRange(requestsFromBatch); + } + else + { + delegatedEndpoints.Add(methodAndUrl); + } + } + else + { + // skip empty roles which are returned in case we couldn't get permissions information + // + // application permissions are always the same because they come from app reg + // so we can just use the first request that has them + if (scopesAndType.permissions.Any() && !rolesToEvaluate.Any()) + { + rolesToEvaluate = scopesAndType.permissions; + + if (ProxyUtils.IsGraphBatchUrl(uri)) + { + applicationEndpoints.AddRange(requestsFromBatch); + } + else + { + applicationEndpoints.Add(methodAndUrl); + } + } + } + } + + // Remove duplicates + delegatedEndpoints = delegatedEndpoints.Distinct(methodAndUrlComparer).ToList(); + applicationEndpoints = applicationEndpoints.Distinct(methodAndUrlComparer).ToList(); + + if (delegatedEndpoints.Count == 0 && applicationEndpoints.Count == 0) + { + return; + } + + var report = new GraphMinimalPermissionsGuidancePluginReport + { + ExcludedPermissions = _configuration.PermissionsToExclude + }; + + Logger.LogWarning("This plugin is in preview and may not return the correct results.\r\nPlease review the permissions and test your app before using them in production.\r\nIf you have any feedback, please open an issue at https://aka.ms/devproxy/issue.\r\n"); + + if (_configuration.PermissionsToExclude is not null && + _configuration.PermissionsToExclude.Any()) + { + Logger.LogInformation("Excluding the following permissions: {permissions}", string.Join(", ", _configuration.PermissionsToExclude)); + } + + if (delegatedEndpoints.Count > 0) + { + var delegatedPermissionsInfo = new GraphMinimalPermissionsInfo(); + report.DelegatedPermissions = delegatedPermissionsInfo; + + Logger.LogInformation("Evaluating delegated permissions for: {endpoints}", string.Join(", ", delegatedEndpoints.Select(e => $"{e.method} {e.url}"))); + + await EvaluateMinimalScopesAsync(delegatedEndpoints, scopesToEvaluate, GraphPermissionsType.Delegated, delegatedPermissionsInfo); + } + + if (applicationEndpoints.Count > 0) + { + var applicationPermissionsInfo = new GraphMinimalPermissionsInfo(); + report.ApplicationPermissions = applicationPermissionsInfo; + + Logger.LogInformation("Evaluating application permissions for: {endpoints}", string.Join(", ", applicationEndpoints.Select(e => $"{e.method} {e.url}"))); + + await EvaluateMinimalScopesAsync(applicationEndpoints, rolesToEvaluate, GraphPermissionsType.Application, applicationPermissionsInfo); + } + + StoreReport(report, e); + } + + private static (string method, string url)[] GetRequestsFromBatch(string batchBody, string graphVersion, string graphHostName) + { + var requests = new List<(string method, string url)>(); + + if (string.IsNullOrEmpty(batchBody)) + { + return [.. requests]; + } + + try + { + var batch = JsonSerializer.Deserialize(batchBody, ProxyUtils.JsonSerializerOptions); + if (batch == null) + { + return [.. requests]; + } + + foreach (var request in batch.Requests) + { + try + { + var method = request.Method; + var url = request.Url; + var absoluteUrl = $"https://{graphHostName}/{graphVersion}{url}"; + requests.Add((method, GetTokenizedUrl(absoluteUrl))); + } + catch { } + } + } + catch { } + + return [.. requests]; + } + + /// + /// Returns permissions and type (delegated or application) from the access token + /// used on the request. + /// If it can't get the permissions, returns PermissionType.Application + /// and an empty array + /// + private static (GraphPermissionsType type, IEnumerable permissions) GetPermissionsAndType(RequestLog request) + { + var authHeader = request.Context?.Session.HttpClient.Request.Headers.GetFirstHeader("Authorization"); + if (authHeader == null) + { + return (GraphPermissionsType.Application, []); + } + + var token = authHeader.Value.Replace("Bearer ", string.Empty); + var tokenChunks = token.Split('.'); + if (tokenChunks.Length != 3) + { + return (GraphPermissionsType.Application, []); + } + + try + { + var handler = new JwtSecurityTokenHandler(); + var jwtSecurityToken = handler.ReadJwtToken(token); + + var scopeClaim = jwtSecurityToken.Claims.FirstOrDefault(c => c.Type == "scp"); + if (scopeClaim == null) + { + // possibly an application token + // roles is an array so we need to handle it differently + var roles = jwtSecurityToken.Claims + .Where(c => c.Type == "roles") + .Select(c => c.Value); + if (!roles.Any()) + { + return (GraphPermissionsType.Application, []); + } + else + { + return (GraphPermissionsType.Application, roles); + } + } + else + { + return (GraphPermissionsType.Delegated, scopeClaim.Value.Split(' ')); + } + } + catch + { + return (GraphPermissionsType.Application, []); + } + } + + private async Task EvaluateMinimalScopesAsync(IEnumerable<(string method, string url)> endpoints, IEnumerable permissionsFromAccessToken, GraphPermissionsType scopeType, GraphMinimalPermissionsInfo permissionsInfo) + { + var payload = endpoints.Select(e => new GraphRequestInfo { Method = e.method, Url = e.url }); + + permissionsInfo.Operations = endpoints.Select(e => new GraphMinimalPermissionsOperationInfo + { + Method = e.method, + Endpoint = e.url + }).ToArray(); + permissionsInfo.PermissionsFromTheToken = permissionsFromAccessToken; + + try + { + var url = $"https://graphexplorerapi.azurewebsites.net/permissions?scopeType={GraphUtils.GetScopeTypeString(scopeType)}"; + using var client = new HttpClient(); + var stringPayload = JsonSerializer.Serialize(payload, ProxyUtils.JsonSerializerOptions); + Logger.LogDebug(string.Format("Calling {0} with payload{1}{2}", url, Environment.NewLine, stringPayload)); + + var response = await client.PostAsJsonAsync(url, payload); + var content = await response.Content.ReadAsStringAsync(); + + Logger.LogDebug(string.Format("Response:{0}{1}", Environment.NewLine, content)); + + var resultsAndErrors = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); + var minimalPermissions = resultsAndErrors?.Results?.Select(p => p.Value) ?? []; + var errors = resultsAndErrors?.Errors?.Select(e => $"- {e.Url} ({e.Message})") ?? []; + + if (scopeType == GraphPermissionsType.Delegated) + { + minimalPermissions = await GraphUtils.UpdateUserScopesAsync(minimalPermissions, endpoints, scopeType, Logger); + } + + if (minimalPermissions.Any()) + { + var excessPermissions = permissionsFromAccessToken + .Except(_configuration.PermissionsToExclude ?? []) + .Where(p => !minimalPermissions.Contains(p)); + + permissionsInfo.MinimalPermissions = minimalPermissions; + permissionsInfo.ExcessPermissions = excessPermissions; + + Logger.LogInformation("Minimal permissions: {minimalPermissions}", string.Join(", ", minimalPermissions)); + Logger.LogInformation("Permissions on the token: {tokenPermissions}", string.Join(", ", permissionsFromAccessToken)); + + if (excessPermissions.Any()) + { + Logger.LogWarning("The following permissions are unnecessary: {permissions}", string.Join(", ", excessPermissions)); + } + else + { + Logger.LogInformation("The token has the minimal permissions required."); + } + } + if (errors.Any()) + { + Logger.LogError("Couldn't determine minimal permissions for the following URLs: {errors}", string.Join(", ", errors)); + } + } + catch (Exception ex) + { + Logger.LogError(ex, "An error has occurred while retrieving minimal permissions: {message}", ex.Message); + } + } + + private static (string method, string url) GetMethodAndUrl(string message) + { + var info = message.Split(" "); + if (info.Length > 2) + { + info = [info[0], string.Join(" ", info.Skip(1))]; + } + return (method: info[0], url: info[1]); + } + + private static string GetTokenizedUrl(string absoluteUrl) + { + var sanitizedUrl = ProxyUtils.SanitizeUrl(absoluteUrl); + return "/" + string.Join("", new Uri(sanitizedUrl).Segments.Skip(2).Select(Uri.UnescapeDataString)); + } +} diff --git a/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsPlugin.cs b/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsPlugin.cs index 2ffa8386..a19df868 100644 --- a/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/GraphMinimalPermissionsPlugin.cs @@ -56,7 +56,7 @@ private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) continue; } - var methodAndUrlString = request.MessageLines.First(); + var methodAndUrlString = request.Message; var methodAndUrl = GetMethodAndUrl(methodAndUrlString); if (methodAndUrl.method.Equals("OPTIONS", StringComparison.OrdinalIgnoreCase)) { diff --git a/dev-proxy-plugins/RequestLogs/HttpFileGeneratorPlugin.cs b/dev-proxy-plugins/RequestLogs/HttpFileGeneratorPlugin.cs index 67d61462..d50374b3 100644 --- a/dev-proxy-plugins/RequestLogs/HttpFileGeneratorPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/HttpFileGeneratorPlugin.cs @@ -1,284 +1,284 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - - -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Text; -using System.Web; -using Microsoft.DevProxy.Abstractions; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -internal class HttpFile -{ - public Dictionary Variables { get; set; } = []; - public List Requests { get; set; } = []; - - public string Serialize() - { - var sb = new StringBuilder(); - - foreach (var variable in Variables) - { - sb.AppendLine($"@{variable.Key} = {variable.Value}"); - } - - foreach (var request in Requests) - { - sb.AppendLine(); - sb.AppendLine("###"); - sb.AppendLine(); - sb.AppendLine($"# @name {GetRequestName(request)}"); - sb.AppendLine(); - - sb.AppendLine($"{request.Method} {request.Url}"); - - foreach (var header in request.Headers) - { - sb.AppendLine($"{header.Name}: {header.Value}"); - } - - if (!string.IsNullOrEmpty(request.Body)) - { - sb.AppendLine(); - sb.AppendLine(request.Body); - } - } - - return sb.ToString(); - } - - private static string GetRequestName(HttpFileRequest request) - { - var url = new Uri(request.Url); - return $"{request.Method.ToLower()}{url.Segments.Last().Replace("/", "").ToPascalCase()}"; - } -} - -internal class HttpFileRequest -{ - public string Method { get; set; } = string.Empty; - public string Url { get; set; } = string.Empty; - public string? Body { get; set; } - public List Headers { get; set; } = []; -} - -internal class HttpFileRequestHeader -{ - public string Name { get; set; } = string.Empty; - public string Value { get; set; } = string.Empty; -} - -public class HttpFileGeneratorPluginReport : List -{ - public HttpFileGeneratorPluginReport() : base() { } - - public HttpFileGeneratorPluginReport(IEnumerable collection) : base(collection) { } -} - -internal class HttpFileGeneratorPluginConfiguration -{ - public bool IncludeOptionsRequests { get; set; } = false; -} - -public class HttpFileGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(HttpFileGeneratorPlugin); - public static readonly string GeneratedHttpFilesKey = "GeneratedHttpFiles"; - private HttpFileGeneratorPluginConfiguration _configuration = new(); - private readonly string[] headersToExtract = ["authorization", "key"]; - private readonly string[] queryParametersToExtract = ["key"]; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) - { - Logger.LogInformation("Creating HTTP file from recorded requests..."); - - if (!e.RequestLogs.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - var httpFile = await GetHttpRequestsAsync(e.RequestLogs); - DeduplicateRequests(httpFile); - ExtractVariables(httpFile); - - var fileName = $"requests_{DateTime.Now:yyyyMMddHHmmss}.http"; - Logger.LogDebug("Writing HTTP file to {fileName}...", fileName); - File.WriteAllText(fileName, httpFile.Serialize()); - Logger.LogInformation("Created HTTP file {fileName}", fileName); - - var generatedHttpFiles = new[] { fileName }; - StoreReport(new HttpFileGeneratorPluginReport(generatedHttpFiles), e); - - // store the generated HTTP files in the global data - // for use by other plugins - e.GlobalData[GeneratedHttpFilesKey] = generatedHttpFiles; - } - - private async Task GetHttpRequestsAsync(IEnumerable requestLogs) - { - var httpFile = new HttpFile(); - - foreach (var request in requestLogs) - { - if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null) - { - continue; - } - - if (!_configuration.IncludeOptionsRequests && - string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) - { - Logger.LogDebug("Skipping OPTIONS request {url}...", request.Context.Session.HttpClient.Request.RequestUri); - continue; - } - - var methodAndUrlString = request.MessageLines.First(); - Logger.LogDebug("Adding request {methodAndUrl}...", methodAndUrlString); - - var methodAndUrl = methodAndUrlString.Split(' '); - httpFile.Requests.Add(new HttpFileRequest - { - Method = methodAndUrl[0], - Url = methodAndUrl[1], - Body = request.Context.Session.HttpClient.Request.HasBody ? await request.Context.Session.GetRequestBodyAsString() : null, - Headers = request.Context.Session.HttpClient.Request.Headers - .Select(h => new HttpFileRequestHeader { Name = h.Name, Value = h.Value }) - .ToList() - }); - } - - return httpFile; - } - - private void DeduplicateRequests(HttpFile httpFile) - { - Logger.LogDebug("Deduplicating requests..."); - - // remove duplicate requests - // if the request doesn't have a body, dedupe on method + URL - // if it has a body, dedupe on method + URL + body - var uniqueRequests = new List(); - foreach (var request in httpFile.Requests) - { - Logger.LogDebug(" Checking request {method} {url}...", request.Method, request.Url); - - var existingRequest = uniqueRequests.FirstOrDefault(r => - { - if (r.Method != request.Method || r.Url != request.Url) - { - return false; - } - - if (r.Body is null && request.Body is null) - { - return true; - } - - if (r.Body is not null && request.Body is not null) - { - return r.Body == request.Body; - } - - return false; - }); - - if (existingRequest is null) - { - Logger.LogDebug(" Keeping request {method} {url}...", request.Method, request.Url); - uniqueRequests.Add(request); - } - else - { - Logger.LogDebug(" Skipping duplicate request {method} {url}...", request.Method, request.Url); - } - } - - httpFile.Requests = uniqueRequests; - } - - private void ExtractVariables(HttpFile httpFile) - { - Logger.LogDebug("Extracting variables..."); - - foreach (var request in httpFile.Requests) - { - Logger.LogDebug(" Processing request {method} {url}...", request.Method, request.Url); - - foreach (var headerName in headersToExtract) - { - Logger.LogDebug(" Extracting header {headerName}...", headerName); - - var headers = request.Headers.Where(h => h.Name.Contains(headerName, StringComparison.OrdinalIgnoreCase)); - if (headers is not null) - { - Logger.LogDebug(" Found {numHeaders} matching headers...", headers.Count()); - - foreach (var header in headers) - { - var variableName = GetVariableName(request, headerName); - Logger.LogDebug(" Extracting variable {variableName}...", variableName); - httpFile.Variables[variableName] = header.Value; - header.Value = $"{{{{{variableName}}}}}"; - } - } - } - - var url = new Uri(request.Url); - var query = HttpUtility.ParseQueryString(url.Query); - if (query.Count > 0) - { - Logger.LogDebug(" Processing query parameters..."); - - foreach (var queryParameterName in queryParametersToExtract) - { - Logger.LogDebug(" Extracting query parameter {queryParameterName}...", queryParameterName); - - var queryParams = query.AllKeys.Where(k => k is not null && k.Contains(queryParameterName, StringComparison.OrdinalIgnoreCase)); - if (queryParams is not null) - { - Logger.LogDebug(" Found {numQueryParams} matching query parameters...", queryParams.Count()); - - foreach (var queryParam in queryParams) - { - var variableName = GetVariableName(request, queryParam!); - Logger.LogDebug(" Extracting variable {variableName}...", variableName); - httpFile.Variables[variableName] = queryParam!; - query[queryParam] = $"{{{{{variableName}}}}}"; - } - } - } - request.Url = $"{url.GetLeftPart(UriPartial.Path)}?{query}" - .Replace("%7b", "{") - .Replace("%7d", "}"); - Logger.LogDebug(" Updated URL to {url}...", request.Url); - } - else - { - Logger.LogDebug(" No query parameters to process..."); - } - } - } - - private static string GetVariableName(HttpFileRequest request, string variableName) - { - var url = new Uri(request.Url); - return $"{url.Host.Replace(".", "_").Replace("-", "_")}_{variableName.Replace("-", "_")}"; - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text; +using System.Web; +using Microsoft.DevProxy.Abstractions; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +internal class HttpFile +{ + public Dictionary Variables { get; set; } = []; + public List Requests { get; set; } = []; + + public string Serialize() + { + var sb = new StringBuilder(); + + foreach (var variable in Variables) + { + sb.AppendLine($"@{variable.Key} = {variable.Value}"); + } + + foreach (var request in Requests) + { + sb.AppendLine(); + sb.AppendLine("###"); + sb.AppendLine(); + sb.AppendLine($"# @name {GetRequestName(request)}"); + sb.AppendLine(); + + sb.AppendLine($"{request.Method} {request.Url}"); + + foreach (var header in request.Headers) + { + sb.AppendLine($"{header.Name}: {header.Value}"); + } + + if (!string.IsNullOrEmpty(request.Body)) + { + sb.AppendLine(); + sb.AppendLine(request.Body); + } + } + + return sb.ToString(); + } + + private static string GetRequestName(HttpFileRequest request) + { + var url = new Uri(request.Url); + return $"{request.Method.ToLower()}{url.Segments.Last().Replace("/", "").ToPascalCase()}"; + } +} + +internal class HttpFileRequest +{ + public string Method { get; set; } = string.Empty; + public string Url { get; set; } = string.Empty; + public string? Body { get; set; } + public List Headers { get; set; } = []; +} + +internal class HttpFileRequestHeader +{ + public string Name { get; set; } = string.Empty; + public string Value { get; set; } = string.Empty; +} + +public class HttpFileGeneratorPluginReport : List +{ + public HttpFileGeneratorPluginReport() : base() { } + + public HttpFileGeneratorPluginReport(IEnumerable collection) : base(collection) { } +} + +internal class HttpFileGeneratorPluginConfiguration +{ + public bool IncludeOptionsRequests { get; set; } = false; +} + +public class HttpFileGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(HttpFileGeneratorPlugin); + public static readonly string GeneratedHttpFilesKey = "GeneratedHttpFiles"; + private HttpFileGeneratorPluginConfiguration _configuration = new(); + private readonly string[] headersToExtract = ["authorization", "key"]; + private readonly string[] queryParametersToExtract = ["key"]; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) + { + Logger.LogInformation("Creating HTTP file from recorded requests..."); + + if (!e.RequestLogs.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + var httpFile = await GetHttpRequestsAsync(e.RequestLogs); + DeduplicateRequests(httpFile); + ExtractVariables(httpFile); + + var fileName = $"requests_{DateTime.Now:yyyyMMddHHmmss}.http"; + Logger.LogDebug("Writing HTTP file to {fileName}...", fileName); + File.WriteAllText(fileName, httpFile.Serialize()); + Logger.LogInformation("Created HTTP file {fileName}", fileName); + + var generatedHttpFiles = new[] { fileName }; + StoreReport(new HttpFileGeneratorPluginReport(generatedHttpFiles), e); + + // store the generated HTTP files in the global data + // for use by other plugins + e.GlobalData[GeneratedHttpFilesKey] = generatedHttpFiles; + } + + private async Task GetHttpRequestsAsync(IEnumerable requestLogs) + { + var httpFile = new HttpFile(); + + foreach (var request in requestLogs) + { + if (request.MessageType != MessageType.InterceptedResponse || + request.Context is null || + request.Context.Session is null) + { + continue; + } + + if (!_configuration.IncludeOptionsRequests && + string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogDebug("Skipping OPTIONS request {url}...", request.Context.Session.HttpClient.Request.RequestUri); + continue; + } + + var methodAndUrlString = request.Message; + Logger.LogDebug("Adding request {methodAndUrl}...", methodAndUrlString); + + var methodAndUrl = methodAndUrlString.Split(' '); + httpFile.Requests.Add(new HttpFileRequest + { + Method = methodAndUrl[0], + Url = methodAndUrl[1], + Body = request.Context.Session.HttpClient.Request.HasBody ? await request.Context.Session.GetRequestBodyAsString() : null, + Headers = request.Context.Session.HttpClient.Request.Headers + .Select(h => new HttpFileRequestHeader { Name = h.Name, Value = h.Value }) + .ToList() + }); + } + + return httpFile; + } + + private void DeduplicateRequests(HttpFile httpFile) + { + Logger.LogDebug("Deduplicating requests..."); + + // remove duplicate requests + // if the request doesn't have a body, dedupe on method + URL + // if it has a body, dedupe on method + URL + body + var uniqueRequests = new List(); + foreach (var request in httpFile.Requests) + { + Logger.LogDebug(" Checking request {method} {url}...", request.Method, request.Url); + + var existingRequest = uniqueRequests.FirstOrDefault(r => + { + if (r.Method != request.Method || r.Url != request.Url) + { + return false; + } + + if (r.Body is null && request.Body is null) + { + return true; + } + + if (r.Body is not null && request.Body is not null) + { + return r.Body == request.Body; + } + + return false; + }); + + if (existingRequest is null) + { + Logger.LogDebug(" Keeping request {method} {url}...", request.Method, request.Url); + uniqueRequests.Add(request); + } + else + { + Logger.LogDebug(" Skipping duplicate request {method} {url}...", request.Method, request.Url); + } + } + + httpFile.Requests = uniqueRequests; + } + + private void ExtractVariables(HttpFile httpFile) + { + Logger.LogDebug("Extracting variables..."); + + foreach (var request in httpFile.Requests) + { + Logger.LogDebug(" Processing request {method} {url}...", request.Method, request.Url); + + foreach (var headerName in headersToExtract) + { + Logger.LogDebug(" Extracting header {headerName}...", headerName); + + var headers = request.Headers.Where(h => h.Name.Contains(headerName, StringComparison.OrdinalIgnoreCase)); + if (headers is not null) + { + Logger.LogDebug(" Found {numHeaders} matching headers...", headers.Count()); + + foreach (var header in headers) + { + var variableName = GetVariableName(request, headerName); + Logger.LogDebug(" Extracting variable {variableName}...", variableName); + httpFile.Variables[variableName] = header.Value; + header.Value = $"{{{{{variableName}}}}}"; + } + } + } + + var url = new Uri(request.Url); + var query = HttpUtility.ParseQueryString(url.Query); + if (query.Count > 0) + { + Logger.LogDebug(" Processing query parameters..."); + + foreach (var queryParameterName in queryParametersToExtract) + { + Logger.LogDebug(" Extracting query parameter {queryParameterName}...", queryParameterName); + + var queryParams = query.AllKeys.Where(k => k is not null && k.Contains(queryParameterName, StringComparison.OrdinalIgnoreCase)); + if (queryParams is not null) + { + Logger.LogDebug(" Found {numQueryParams} matching query parameters...", queryParams.Count()); + + foreach (var queryParam in queryParams) + { + var variableName = GetVariableName(request, queryParam!); + Logger.LogDebug(" Extracting variable {variableName}...", variableName); + httpFile.Variables[variableName] = queryParam!; + query[queryParam] = $"{{{{{variableName}}}}}"; + } + } + } + request.Url = $"{url.GetLeftPart(UriPartial.Path)}?{query}" + .Replace("%7b", "{") + .Replace("%7d", "}"); + Logger.LogDebug(" Updated URL to {url}...", request.Url); + } + else + { + Logger.LogDebug(" No query parameters to process..."); + } + } + } + + private static string GetVariableName(HttpFileRequest request, string variableName) + { + var url = new Uri(request.Url); + return $"{url.Host.Replace(".", "_").Replace("-", "_")}_{variableName.Replace("-", "_")}"; + } } \ No newline at end of file diff --git a/dev-proxy-plugins/RequestLogs/MinimalPermissionsPlugin.cs b/dev-proxy-plugins/RequestLogs/MinimalPermissionsPlugin.cs index 51f073fe..e5e57970 100644 --- a/dev-proxy-plugins/RequestLogs/MinimalPermissionsPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/MinimalPermissionsPlugin.cs @@ -63,7 +63,7 @@ private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) var interceptedRequests = e.RequestLogs .Where(l => l.MessageType == MessageType.InterceptedRequest && - !l.MessageLines.First().StartsWith("OPTIONS") && + !l.Message.StartsWith("OPTIONS") && l.Context?.Session is not null && l.Context.Session.HttpClient.Request.Headers.Any(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)) ); @@ -87,7 +87,7 @@ private async Task AfterRecordingStopAsync(object sender, RecordingArgs e) var errors = new List(); var results = new List(); var unmatchedRequests = new List( - unmatchedApiSpecRequests.Select(r => r.MessageLines.First()) + unmatchedApiSpecRequests.Select(r => r.Message) ); foreach (var (apiSpec, requests) in requestsByApiSpec) @@ -214,7 +214,7 @@ private Dictionary LoadApiSpecs(string apiSpecsFolderPa var requestsByApiSpec = new Dictionary>(); foreach (var request in interceptedRequests) { - var url = request.MessageLines.First().Split(' ')[1]; + var url = request.Message.Split(' ')[1]; Logger.LogDebug("Matching request {requestUrl} to API specs...", url); var matchingKey = apiSpecsByUrl.Keys.FirstOrDefault(url.StartsWith); diff --git a/dev-proxy-plugins/RequestLogs/MockGeneratorPlugin.cs b/dev-proxy-plugins/RequestLogs/MockGeneratorPlugin.cs index b3ee8e6a..77c12464 100644 --- a/dev-proxy-plugins/RequestLogs/MockGeneratorPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/MockGeneratorPlugin.cs @@ -1,160 +1,160 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Text.Json; -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Plugins.Mocks; -using Titanium.Web.Proxy.EventArguments; -using Microsoft.Extensions.Logging; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public class MockGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - public override string Name => nameof(MockGeneratorPlugin); - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) - { - Logger.LogInformation("Creating mocks from recorded requests..."); - - if (!e.RequestLogs.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - var methodAndUrlComparer = new MethodAndUrlComparer(); - var mocks = new List(); - - foreach (var request in e.RequestLogs) - { - if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null) - { - continue; - } - - var methodAndUrlString = request.MessageLines.First(); - Logger.LogDebug("Processing request {methodAndUrlString}...", methodAndUrlString); - - var (method, url) = GetMethodAndUrl(methodAndUrlString); - var response = request.Context.Session.HttpClient.Response; - - var newHeaders = new List(); - newHeaders.AddRange(response.Headers.Select(h => new MockResponseHeader(h.Name, h.Value))); - var mock = new MockResponse - { - Request = new() - { - Method = method, - Url = url, - }, - Response = new() - { - StatusCode = response.StatusCode, - Headers = newHeaders, - Body = await GetResponseBodyAsync(request.Context.Session) - } - }; - // skip mock if it's 200 but has no body - if (mock.Response.StatusCode == 200 && mock.Response.Body is null) - { - Logger.LogDebug("Skipping mock with 200 response code and no body"); - continue; - } - - mocks.Add(mock); - Logger.LogDebug("Added mock for {method} {url}", mock.Request.Method, mock.Request.Url); - } - - Logger.LogDebug("Sorting mocks..."); - // sort mocks descending by url length so that most specific mocks are first - mocks.Sort((a, b) => b.Request!.Url.CompareTo(a.Request!.Url)); - - var mocksFile = new MockResponseConfiguration { Mocks = mocks }; - - Logger.LogDebug("Serializing mocks..."); - var mocksFileJson = JsonSerializer.Serialize(mocksFile, ProxyUtils.JsonSerializerOptions); - var fileName = $"mocks-{DateTime.Now:yyyyMMddHHmmss}.json"; - - Logger.LogDebug("Writing mocks to {fileName}...", fileName); - File.WriteAllText(fileName, mocksFileJson); - - Logger.LogInformation("Created mock file {fileName} with {mocksCount} mocks", fileName, mocks.Count); - - StoreReport(fileName, e); - } - - /// - /// Returns the body of the response. For binary responses, - /// saves the binary response as a file on disk and returns @filename - /// - /// Request session - /// Response body or @filename for binary responses - private async Task GetResponseBodyAsync(SessionEventArgs session) - { - Logger.LogDebug("Getting response body..."); - - var response = session.HttpClient.Response; - if (response.ContentType is null || !response.HasBody) - { - Logger.LogDebug("Response has no content-type set or has no body. Skipping"); - return null; - } - - if (response.ContentType.Contains("application/json")) - { - Logger.LogDebug("Response is JSON"); - - try - { - Logger.LogDebug("Reading response body as string..."); - var body = response.IsBodyRead ? response.BodyString : await session.GetResponseBodyAsString(); - Logger.LogDebug("Body: {body}", body); - Logger.LogDebug("Deserializing response body..."); - return JsonSerializer.Deserialize(body, ProxyUtils.JsonSerializerOptions); - } - catch (Exception ex) - { - Logger.LogError(ex, "Error reading response body"); - return null; - } - } - - Logger.LogDebug("Response is binary"); - // assume body is binary - try - { - var filename = $"response-{Guid.NewGuid()}.bin"; - Logger.LogDebug("Reading response body as bytes..."); - var body = await session.GetResponseBody(); - Logger.LogDebug("Writing response body to {filename}...", filename); - File.WriteAllBytes(filename, body); - return $"@{filename}"; - } - catch (Exception ex) - { - Logger.LogError(ex, "Error reading response body"); - return null; - } - } - - private static (string method, string url) GetMethodAndUrl(string message) - { - var info = message.Split(" "); - if (info.Length > 2) - { - info = [info[0], string.Join(" ", info.Skip(1))]; - } - return (info[0], info[1]); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Plugins.Mocks; +using Titanium.Web.Proxy.EventArguments; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public class MockGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + public override string Name => nameof(MockGeneratorPlugin); + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) + { + Logger.LogInformation("Creating mocks from recorded requests..."); + + if (!e.RequestLogs.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + var methodAndUrlComparer = new MethodAndUrlComparer(); + var mocks = new List(); + + foreach (var request in e.RequestLogs) + { + if (request.MessageType != MessageType.InterceptedResponse || + request.Context is null || + request.Context.Session is null) + { + continue; + } + + var methodAndUrlString = request.Message; + Logger.LogDebug("Processing request {methodAndUrlString}...", methodAndUrlString); + + var (method, url) = GetMethodAndUrl(methodAndUrlString); + var response = request.Context.Session.HttpClient.Response; + + var newHeaders = new List(); + newHeaders.AddRange(response.Headers.Select(h => new MockResponseHeader(h.Name, h.Value))); + var mock = new MockResponse + { + Request = new() + { + Method = method, + Url = url, + }, + Response = new() + { + StatusCode = response.StatusCode, + Headers = newHeaders, + Body = await GetResponseBodyAsync(request.Context.Session) + } + }; + // skip mock if it's 200 but has no body + if (mock.Response.StatusCode == 200 && mock.Response.Body is null) + { + Logger.LogDebug("Skipping mock with 200 response code and no body"); + continue; + } + + mocks.Add(mock); + Logger.LogDebug("Added mock for {method} {url}", mock.Request.Method, mock.Request.Url); + } + + Logger.LogDebug("Sorting mocks..."); + // sort mocks descending by url length so that most specific mocks are first + mocks.Sort((a, b) => b.Request!.Url.CompareTo(a.Request!.Url)); + + var mocksFile = new MockResponseConfiguration { Mocks = mocks }; + + Logger.LogDebug("Serializing mocks..."); + var mocksFileJson = JsonSerializer.Serialize(mocksFile, ProxyUtils.JsonSerializerOptions); + var fileName = $"mocks-{DateTime.Now:yyyyMMddHHmmss}.json"; + + Logger.LogDebug("Writing mocks to {fileName}...", fileName); + File.WriteAllText(fileName, mocksFileJson); + + Logger.LogInformation("Created mock file {fileName} with {mocksCount} mocks", fileName, mocks.Count); + + StoreReport(fileName, e); + } + + /// + /// Returns the body of the response. For binary responses, + /// saves the binary response as a file on disk and returns @filename + /// + /// Request session + /// Response body or @filename for binary responses + private async Task GetResponseBodyAsync(SessionEventArgs session) + { + Logger.LogDebug("Getting response body..."); + + var response = session.HttpClient.Response; + if (response.ContentType is null || !response.HasBody) + { + Logger.LogDebug("Response has no content-type set or has no body. Skipping"); + return null; + } + + if (response.ContentType.Contains("application/json")) + { + Logger.LogDebug("Response is JSON"); + + try + { + Logger.LogDebug("Reading response body as string..."); + var body = response.IsBodyRead ? response.BodyString : await session.GetResponseBodyAsString(); + Logger.LogDebug("Body: {body}", body); + Logger.LogDebug("Deserializing response body..."); + return JsonSerializer.Deserialize(body, ProxyUtils.JsonSerializerOptions); + } + catch (Exception ex) + { + Logger.LogError(ex, "Error reading response body"); + return null; + } + } + + Logger.LogDebug("Response is binary"); + // assume body is binary + try + { + var filename = $"response-{Guid.NewGuid()}.bin"; + Logger.LogDebug("Reading response body as bytes..."); + var body = await session.GetResponseBody(); + Logger.LogDebug("Writing response body to {filename}...", filename); + File.WriteAllBytes(filename, body); + return $"@{filename}"; + } + catch (Exception ex) + { + Logger.LogError(ex, "Error reading response body"); + return null; + } + } + + private static (string method, string url) GetMethodAndUrl(string message) + { + var info = message.Split(" "); + if (info.Length > 2) + { + info = [info[0], string.Join(" ", info.Skip(1))]; + } + return (info[0], info[1]); + } +} diff --git a/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs b/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs index 1a4e089a..508eb2b8 100644 --- a/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs +++ b/dev-proxy-plugins/RequestLogs/OpenApiSpecGeneratorPlugin.cs @@ -1,962 +1,962 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.Extensions.Configuration; -using Microsoft.DevProxy.Abstractions; -using Titanium.Web.Proxy.EventArguments; -using Microsoft.OpenApi.Models; -using Microsoft.OpenApi.Extensions; -using System.Text.Json; -using Microsoft.OpenApi.Interfaces; -using Microsoft.OpenApi.Writers; -using Microsoft.OpenApi; -using Titanium.Web.Proxy.Http; -using System.Web; -using System.Collections.Specialized; -using Microsoft.Extensions.Logging; -using Microsoft.DevProxy.Abstractions.LanguageModel; - -namespace Microsoft.DevProxy.Plugins.RequestLogs; - -public class OpenApiSpecGeneratorPluginReportItem -{ - public required string ServerUrl { get; init; } - public required string FileName { get; init; } -} - -public class OpenApiSpecGeneratorPluginReport : List -{ - public OpenApiSpecGeneratorPluginReport() : base() { } - - public OpenApiSpecGeneratorPluginReport(IEnumerable collection) : base(collection) { } -} - -class GeneratedByOpenApiExtension : IOpenApiExtension -{ - public void Write(IOpenApiWriter writer, OpenApiSpecVersion specVersion) - { - writer.WriteStartObject(); - writer.WriteProperty("toolName", "Dev Proxy"); - writer.WriteProperty("toolVersion", ProxyUtils.ProductVersion); - writer.WriteEndObject(); - } -} - -internal class OpenApiSpecGeneratorPluginConfiguration -{ - public bool IncludeOptionsRequests { get; set; } = false; -} - -public class OpenApiSpecGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) -{ - // from: https://github.com/jonluca/har-to-openapi/blob/0d44409162c0a127cdaccd60b0a270ecd361b829/src/utils/headers.ts - private static readonly string[] standardHeaders = - [ - ":authority", - ":method", - ":path", - ":scheme", - ":status", - "a-im", - "accept", - "accept-additions", - "accept-ch", - "accept-ch-lifetime", - "accept-charset", - "accept-datetime", - "accept-encoding", - "accept-features", - "accept-language", - "accept-patch", - "accept-post", - "accept-ranges", - "access-control-allow-credentials", - "access-control-allow-headers", - "access-control-allow-methods", - "access-control-allow-origin", - "access-control-expose-headers", - "access-control-max-age", - "access-control-request-headers", - "access-control-request-method", - "age", - "allow", - "alpn", - "alt-svc", - "alternate-protocol", - "alternates", - "amp-access-control-allow-source-origin", - "apply-to-redirect-ref", - "authentication-info", - "authorization", - "c-ext", - "c-man", - "c-opt", - "c-pep", - "c-pep-info", - "cache-control", - "ch", - "connection", - "content-base", - "content-disposition", - "content-dpr", - "content-encoding", - "content-id", - "content-language", - "content-length", - "content-location", - "content-md5", - "content-range", - "content-script-type", - "content-security-policy", - "content-security-policy-report-only", - "content-style-type", - "content-type", - "content-version", - "cookie", - "cookie2", - "cross-origin-resource-policy", - "dasl", - "date", - "dav", - "default-style", - "delta-base", - "depth", - "derived-from", - "destination", - "differential-id", - "digest", - "dnt", - "dpr", - "encryption", - "encryption-key", - "etag", - "expect", - "expect-ct", - "expires", - "ext", - "forwarded", - "from", - "front-end-https", - "getprofile", - "host", - "http2-settings", - "if", - "if-match", - "if-modified-since", - "if-none-match", - "if-range", - "if-schedule-tag-match", - "if-unmodified-since", - "im", - "keep-alive", - "key", - "label", - "last-event-id", - "last-modified", - "link", - "link-template", - "location", - "lock-token", - "man", - "max-forwards", - "md", - "meter", - "mime-version", - "negotiate", - "nice", - "opt", - "ordering-type", - "origin", - "origin-trial", - "overwrite", - "p3p", - "pep", - "pep-info", - "pics-label", - "poe", - "poe-links", - "position", - "pragma", - "prefer", - "preference-applied", - "profileobject", - "protocol", - "protocol-info", - "protocol-query", - "protocol-request", - "proxy-authenticate", - "proxy-authentication-info", - "proxy-authorization", - "proxy-connection", - "proxy-features", - "proxy-instruction", - "public", - "range", - "redirect-ref", - "referer", - "referrer-policy", - "report-to", - "retry-after", - "rw", - "safe", - "save-data", - "schedule-reply", - "schedule-tag", - "sec-ch-ua", - "sec-ch-ua-mobile", - "sec-ch-ua-platform", - "sec-fetch-dest", - "sec-fetch-mode", - "sec-fetch-site", - "sec-fetch-user", - "sec-websocket-accept", - "sec-websocket-extensions", - "sec-websocket-key", - "sec-websocket-protocol", - "sec-websocket-version", - "security-scheme", - "server", - "server-timing", - "set-cookie", - "set-cookie2", - "setprofile", - "slug", - "soapaction", - "status-uri", - "strict-transport-security", - "sunset", - "surrogate-capability", - "surrogate-control", - "tcn", - "te", - "timeout", - "timing-allow-origin", - "tk", - "trailer", - "transfer-encoding", - "upgrade", - "upgrade-insecure-requests", - "uri", - "user-agent", - "variant-vary", - "vary", - "via", - "want-digest", - "warning", - "www-authenticate", - "x-att-deviceid", - "x-csrf-token", - "x-forwarded-for", - "x-forwarded-host", - "x-forwarded-proto", - "x-frame-options", - "x-frontend", - "x-http-method-override", - "x-powered-by", - "x-request-id", - "x-requested-with", - "x-uidh", - "x-wap-profile", - "x-xss-protection" - ]; - private static readonly string[] authHeaders = - [ - "access-token", - "api-key", - "auth-token", - "authorization", - "authorization-token", - "cookie", - "key", - "token", - "x-access-token", - "x-access-token", - "x-api-key", - "x-auth", - "x-auth-token", - "x-csrf-token", - "secret", - "x-secret", - "access-key", - "api-key", - "apikey" - ]; - - public override string Name => nameof(OpenApiSpecGeneratorPlugin); - private readonly OpenApiSpecGeneratorPluginConfiguration _configuration = new(); - public static readonly string GeneratedOpenApiSpecsKey = "GeneratedOpenApiSpecs"; - - public override async Task RegisterAsync() - { - await base.RegisterAsync(); - - ConfigSection?.Bind(_configuration); - - PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; - } - - private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) - { - Logger.LogInformation("Creating OpenAPI spec from recorded requests..."); - - if (!e.RequestLogs.Any()) - { - Logger.LogDebug("No requests to process"); - return; - } - - var openApiDocs = new List(); - - foreach (var request in e.RequestLogs) - { - if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null) - { - continue; - } - - if (!_configuration.IncludeOptionsRequests && - string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) - { - Logger.LogDebug("Skipping OPTIONS request {url}...", request.Context.Session.HttpClient.Request.RequestUri); - continue; - } - - var methodAndUrlString = request.MessageLines.First(); - Logger.LogDebug("Processing request {methodAndUrlString}...", methodAndUrlString); - - try - { - var pathItem = GetOpenApiPathItem(request.Context.Session); - var parametrizedPath = ParametrizePath(pathItem, request.Context.Session.HttpClient.Request.RequestUri); - var operationInfo = pathItem.Operations.First(); - operationInfo.Value.OperationId = await GetOperationIdAsync( - operationInfo.Key.ToString(), - request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), - parametrizedPath - ); - operationInfo.Value.Description = await GetOperationDescriptionAsync( - operationInfo.Key.ToString(), - request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), - parametrizedPath - ); - AddOrMergePathItem(openApiDocs, pathItem, request.Context.Session.HttpClient.Request.RequestUri, parametrizedPath); - } - catch (Exception ex) - { - Logger.LogError(ex, "Error processing request {methodAndUrl}", methodAndUrlString); - } - } - - Logger.LogDebug("Serializing OpenAPI docs..."); - var generatedOpenApiSpecs = new Dictionary(); - foreach (var openApiDoc in openApiDocs) - { - var server = openApiDoc.Servers.First(); - var fileName = GetFileNameFromServerUrl(server.Url); - var docString = openApiDoc.SerializeAsJson(OpenApiSpecVersion.OpenApi3_0); - - Logger.LogDebug(" Writing OpenAPI spec to {fileName}...", fileName); - File.WriteAllText(fileName, docString); - - generatedOpenApiSpecs.Add(server.Url, fileName); - - Logger.LogInformation("Created OpenAPI spec file {fileName}", fileName); - } - - StoreReport(new OpenApiSpecGeneratorPluginReport( - generatedOpenApiSpecs - .Select(kvp => new OpenApiSpecGeneratorPluginReportItem - { - ServerUrl = kvp.Key, - FileName = kvp.Value - })), e); - - // store the generated OpenAPI specs in the global data - // for use by other plugins - e.GlobalData[GeneratedOpenApiSpecsKey] = generatedOpenApiSpecs; - } - - /** - * Replaces segments in the request URI, that match predefined patters, - * with parameters and adds them to the OpenAPI PathItem. - * @param pathItem The OpenAPI PathItem to parametrize. - * @param requestUri The request URI. - * @returns The parametrized server-relative URL - */ - private static string ParametrizePath(OpenApiPathItem pathItem, Uri requestUri) - { - var segments = requestUri.Segments; - var previousSegment = "item"; - - for (var i = 0; i < segments.Length; i++) - { - var segment = requestUri.Segments[i].Trim('/'); - if (string.IsNullOrEmpty(segment)) - { - continue; - } - - if (IsParametrizable(segment)) - { - var parameterName = $"{previousSegment}-id"; - segments[i] = $"{{{parameterName}}}{(requestUri.Segments[i].EndsWith('/') ? "/" : "")}"; - - pathItem.Parameters.Add(new OpenApiParameter - { - Name = parameterName, - In = ParameterLocation.Path, - Required = true, - Schema = new OpenApiSchema { Type = "string" } - }); - } - else - { - previousSegment = segment; - } - } - - return string.Join(string.Empty, segments); - } - - private static bool IsParametrizable(string segment) - { - return Guid.TryParse(segment.Trim('/'), out _) || - int.TryParse(segment.Trim('/'), out _); - } - - private static string GetLastNonTokenSegment(string[] segments) - { - for (var i = segments.Length - 1; i >= 0; i--) - { - var segment = segments[i].Trim('/'); - if (string.IsNullOrEmpty(segment)) - { - continue; - } - - if (!IsParametrizable(segment)) - { - return segment; - } - } - - return "item"; - } - - private async Task GetOperationIdAsync(string method, string serverUrl, string parametrizedPath) - { - var prompt = $"For the specified request, generate an operation ID, compatible with an OpenAPI spec. Respond with just the ID in plain-text format. For example, for request such as `GET https://api.contoso.com/books/{{books-id}}` you return `getBookById`. For a request like `GET https://api.contoso.com/books/{{books-id}}/authors` you return `getAuthorsForBookById`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}"; - ILanguageModelCompletionResponse? id = null; - if (await Context.LanguageModelClient.IsEnabledAsync()) { - id = await Context.LanguageModelClient.GenerateCompletionAsync(prompt); - } - return id?.Response ?? $"{method}{parametrizedPath.Replace('/', '.')}"; - } - - private async Task GetOperationDescriptionAsync(string method, string serverUrl, string parametrizedPath) - { - var prompt = $"You're an expert in OpenAPI. You help developers build great OpenAPI specs for use with LLMs. For the specified request, generate a one-sentence description. Respond with just the description. For example, for a request such as `GET https://api.contoso.com/books/{{books-id}}` you return `Get a book by ID`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}"; - ILanguageModelCompletionResponse? description = null; - if (await Context.LanguageModelClient.IsEnabledAsync()) { - description = await Context.LanguageModelClient.GenerateCompletionAsync(prompt); - } - return description?.Response ?? $"{method} {parametrizedPath}"; - } - - /** - * Creates an OpenAPI PathItem from an intercepted request and response pair. - * @param session The intercepted session. - */ - private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session) - { - var request = session.HttpClient.Request; - var response = session.HttpClient.Response; - - var resource = GetLastNonTokenSegment(request.RequestUri.Segments); - var path = new OpenApiPathItem(); - - var method = request.Method?.ToUpperInvariant() switch - { - "DELETE" => OperationType.Delete, - "GET" => OperationType.Get, - "HEAD" => OperationType.Head, - "OPTIONS" => OperationType.Options, - "PATCH" => OperationType.Patch, - "POST" => OperationType.Post, - "PUT" => OperationType.Put, - "TRACE" => OperationType.Trace, - _ => throw new NotSupportedException($"Method {request.Method} is not supported") - }; - var operation = new OpenApiOperation - { - // will be replaced later after the path has been parametrized - Description = $"{method} {resource}", - // will be replaced later after the path has been parametrized - OperationId = $"{method}.{resource}" - }; - SetParametersFromQueryString(operation, HttpUtility.ParseQueryString(request.RequestUri.Query)); - SetParametersFromRequestHeaders(operation, request.Headers); - SetRequestBody(operation, request); - SetResponseFromSession(operation, response); - - path.Operations.Add(method, operation); - - return path; - } - - private void SetRequestBody(OpenApiOperation operation, Request request) - { - if (!request.HasBody) - { - Logger.LogDebug(" Request has no body"); - return; - } - - if (request.ContentType is null) - { - Logger.LogDebug(" Request has no content type"); - return; - } - - Logger.LogDebug(" Processing request body..."); - operation.RequestBody = new OpenApiRequestBody - { - Content = new Dictionary - { - { - GetMediaType(request.ContentType), - new OpenApiMediaType - { - Schema = GetSchemaFromBody(GetMediaType(request.ContentType), request.BodyString) - } - } - } - }; - } - - private void SetParametersFromRequestHeaders(OpenApiOperation operation, HeaderCollection headers) - { - if (headers is null || - !headers.Any()) - { - Logger.LogDebug(" Request has no headers"); - return; - } - - Logger.LogDebug(" Processing request headers..."); - foreach (var header in headers) - { - var lowerCaseHeaderName = header.Name.ToLowerInvariant(); - if (standardHeaders.Contains(lowerCaseHeaderName)) - { - Logger.LogDebug(" Skipping standard header {headerName}", header.Name); - continue; - } - - if (authHeaders.Contains(lowerCaseHeaderName)) - { - Logger.LogDebug(" Skipping auth header {headerName}", header.Name); - continue; - } - - operation.Parameters.Add(new OpenApiParameter - { - Name = header.Name, - In = ParameterLocation.Header, - Required = false, - Schema = new OpenApiSchema { Type = "string" } - }); - Logger.LogDebug(" Added header {headerName}", header.Name); - } - } - - private void SetParametersFromQueryString(OpenApiOperation operation, NameValueCollection queryParams) - { - if (queryParams.AllKeys is null || - queryParams.AllKeys.Length == 0) - { - Logger.LogDebug(" Request has no query string parameters"); - return; - } - - Logger.LogDebug(" Processing query string parameters..."); - var dictionary = (queryParams.AllKeys as string[]).ToDictionary(k => k, k => queryParams[k] as object); - - foreach (var parameter in dictionary) - { - operation.Parameters.Add(new OpenApiParameter - { - Name = parameter.Key, - In = ParameterLocation.Query, - Required = false, - Schema = new OpenApiSchema { Type = "string" } - }); - Logger.LogDebug(" Added query string parameter {parameterKey}", parameter.Key); - } - } - - private void SetResponseFromSession(OpenApiOperation operation, Response response) - { - if (response is null) - { - Logger.LogDebug(" No response to process"); - return; - } - - Logger.LogDebug(" Processing response..."); - - var openApiResponse = new OpenApiResponse - { - Description = response.StatusDescription - }; - var responseCode = response.StatusCode.ToString(); - if (response.HasBody) - { - Logger.LogDebug(" Response has body"); - - openApiResponse.Content.Add(GetMediaType(response.ContentType), new OpenApiMediaType - { - Schema = GetSchemaFromBody(GetMediaType(response.ContentType), response.BodyString) - }); - } - else - { - Logger.LogDebug(" Response doesn't have body"); - } - - if (response.Headers is not null && response.Headers.Any()) - { - Logger.LogDebug(" Response has headers"); - - foreach (var header in response.Headers) - { - var lowerCaseHeaderName = header.Name.ToLowerInvariant(); - if (standardHeaders.Contains(lowerCaseHeaderName)) - { - Logger.LogDebug(" Skipping standard header {headerName}", header.Name); - continue; - } - - if (authHeaders.Contains(lowerCaseHeaderName)) - { - Logger.LogDebug(" Skipping auth header {headerName}", header.Name); - continue; - } - - if (openApiResponse.Headers.ContainsKey(header.Name)) - { - Logger.LogDebug(" Header {headerName} already exists in response", header.Name); - continue; - } - - openApiResponse.Headers.Add(header.Name, new OpenApiHeader - { - Schema = new OpenApiSchema { Type = "string" } - }); - Logger.LogDebug(" Added header {headerName}", header.Name); - } - } - else - { - Logger.LogDebug(" Response doesn't have headers"); - } - - operation.Responses.Add(responseCode, openApiResponse); - } - - private static string GetMediaType(string? contentType) - { - if (string.IsNullOrEmpty(contentType)) - { - return contentType ?? ""; - } - - var mediaType = contentType.Split(';').First().Trim(); - return mediaType; - } - - private OpenApiSchema? GetSchemaFromBody(string? contentType, string body) - { - if (contentType is null) - { - Logger.LogDebug(" No content type to process"); - return null; - } - - if (contentType.StartsWith("application/json")) - { - Logger.LogDebug(" Processing JSON body..."); - return GetSchemaFromJsonString(body); - } - - return null; - } - - private void AddOrMergePathItem(IList openApiDocs, OpenApiPathItem pathItem, Uri requestUri, string parametrizedPath) - { - var serverUrl = requestUri.GetLeftPart(UriPartial.Authority); - var openApiDoc = openApiDocs.FirstOrDefault(d => d.Servers.Any(s => s.Url == serverUrl)); - - if (openApiDoc is null) - { - Logger.LogDebug(" Creating OpenAPI spec for {serverUrl}...", serverUrl); - - openApiDoc = new OpenApiDocument - { - Info = new OpenApiInfo - { - Version = "v1.0", - Title = $"{serverUrl} API", - Description = $"{serverUrl} API", - }, - Servers = - [ - new OpenApiServer { Url = serverUrl } - ], - Paths = [], - Extensions = new Dictionary - { - { "x-ms-generated-by", new GeneratedByOpenApiExtension() } - } - }; - openApiDocs.Add(openApiDoc); - } - else - { - Logger.LogDebug(" Found OpenAPI spec for {serverUrl}...", serverUrl); - } - - if (!openApiDoc.Paths.TryGetValue(parametrizedPath, out OpenApiPathItem? value)) - { - Logger.LogDebug(" Adding path {parametrizedPath} to OpenAPI spec...", parametrizedPath); - value = pathItem; - openApiDoc.Paths.Add(parametrizedPath, value); - // since we've just added the path, we're done - return; - } - - Logger.LogDebug(" Merging path {parametrizedPath} into OpenAPI spec...", parametrizedPath); - var operation = pathItem.Operations.First(); - AddOrMergeOperation(value, operation.Key, operation.Value); - } - - private void AddOrMergeOperation(OpenApiPathItem pathItem, OperationType operationType, OpenApiOperation apiOperation) - { - if (!pathItem.Operations.TryGetValue(operationType, out OpenApiOperation? value)) - { - Logger.LogDebug(" Adding operation {operationType} to path...", operationType); - - pathItem.AddOperation(operationType, apiOperation); - // since we've just added the operation, we're done - return; - } - - Logger.LogDebug(" Merging operation {operationType} into path...", operationType); - - var operation = value; - - AddOrMergeParameters(operation, apiOperation.Parameters); - AddOrMergeRequestBody(operation, apiOperation.RequestBody); - AddOrMergeResponse(operation, apiOperation.Responses); - } - - private void AddOrMergeParameters(OpenApiOperation operation, IList parameters) - { - if (parameters is null || !parameters.Any()) - { - Logger.LogDebug(" No parameters to process"); - return; - } - - Logger.LogDebug(" Processing parameters for operation..."); - - foreach (var parameter in parameters) - { - var paramFromOperation = operation.Parameters.FirstOrDefault(p => p.Name == parameter.Name && p.In == parameter.In); - if (paramFromOperation is null) - { - Logger.LogDebug(" Adding parameter {parameterName} to operation...", parameter.Name); - operation.Parameters.Add(parameter); - continue; - } - - Logger.LogDebug(" Merging parameter {parameterName}...", parameter.Name); - MergeSchema(parameter?.Schema, paramFromOperation?.Schema); - } - } - - private void MergeSchema(OpenApiSchema? source, OpenApiSchema? target) - { - if (source is null || target is null) - { - Logger.LogDebug(" Source or target is null. Skipping..."); - return; - } - - if (source.Type != "object" || target.Type != "object") - { - Logger.LogDebug(" Source or target schema is not an object. Skipping..."); - return; - } - - if (source.Properties is null || !source.Properties.Any()) - { - Logger.LogDebug(" Source has no properties. Skipping..."); - return; - } - - if (target.Properties is null || !target.Properties.Any()) - { - Logger.LogDebug(" Target has no properties. Skipping..."); - return; - } - - foreach (var property in source.Properties) - { - var propertyFromTarget = target.Properties.FirstOrDefault(p => p.Key == property.Key); - if (propertyFromTarget.Value is null) - { - Logger.LogDebug(" Adding property {propertyKey} to schema...", property.Key); - target.Properties.Add(property); - continue; - } - - if (property.Value.Type != "object") - { - Logger.LogDebug(" Property already found but is not an object. Skipping..."); - continue; - } - - Logger.LogDebug(" Merging property {propertyKey}...", property.Key); - MergeSchema(property.Value, propertyFromTarget.Value); - } - } - - private void AddOrMergeRequestBody(OpenApiOperation operation, OpenApiRequestBody requestBody) - { - if (requestBody is null || !requestBody.Content.Any()) - { - Logger.LogDebug(" No request body to process"); - return; - } - - var requestBodyType = requestBody.Content.FirstOrDefault().Key; - operation.RequestBody.Content.TryGetValue(requestBodyType, out OpenApiMediaType? bodyFromOperation); - - if (bodyFromOperation is null) - { - Logger.LogDebug(" Adding request body to operation..."); - - operation.RequestBody.Content.Add(requestBody.Content.FirstOrDefault()); - // since we've just added the request body, we're done - return; - } - - Logger.LogDebug(" Merging request body into operation..."); - MergeSchema(bodyFromOperation.Schema, requestBody.Content.FirstOrDefault().Value.Schema); - } - - private void AddOrMergeResponse(OpenApiOperation operation, OpenApiResponses apiResponses) - { - if (apiResponses is null) - { - Logger.LogDebug(" No response to process"); - return; - } - - var apiResponseInfo = apiResponses.FirstOrDefault(); - var apiResponseStatusCode = apiResponseInfo.Key; - var apiResponse = apiResponseInfo.Value; - operation.Responses.TryGetValue(apiResponseStatusCode, out OpenApiResponse? responseFromOperation); - - if (responseFromOperation is null) - { - Logger.LogDebug(" Adding response {apiResponseStatusCode} to operation...", apiResponseStatusCode); - - operation.Responses.Add(apiResponseStatusCode, apiResponse); - // since we've just added the response, we're done - return; - } - - if (!apiResponse.Content.Any()) - { - Logger.LogDebug(" No response content to process"); - return; - } - - var apiResponseContentType = apiResponse.Content.First().Key; - responseFromOperation.Content.TryGetValue(apiResponseContentType, out OpenApiMediaType? contentFromOperation); - - if (contentFromOperation is null) - { - Logger.LogDebug(" Adding response {apiResponseContentType} to {apiResponseStatusCode} to response...", apiResponseContentType, apiResponseStatusCode); - - responseFromOperation.Content.Add(apiResponse.Content.First()); - // since we've just added the content, we're done - return; - } - - Logger.LogDebug(" Merging response {apiResponseStatusCode}/{apiResponseContentType} into operation...", apiResponseStatusCode, apiResponseContentType); - MergeSchema(contentFromOperation.Schema, apiResponse.Content.First().Value.Schema); - } - - private static string GetFileNameFromServerUrl(string serverUrl) - { - var uri = new Uri(serverUrl); - var fileName = $"{uri.Host}-{DateTime.Now:yyyyMMddHHmmss}.json"; - return fileName; - } - - private static OpenApiSchema GetSchemaFromJsonString(string jsonString) - { - try - { - using var doc = JsonDocument.Parse(jsonString); - JsonElement root = doc.RootElement; - var schema = GetSchemaFromJsonElement(root); - return schema; - } - catch - { - return new OpenApiSchema - { - Type = "object" - }; - } - } - - private static OpenApiSchema GetSchemaFromJsonElement(JsonElement jsonElement) - { - var schema = new OpenApiSchema(); - - switch (jsonElement.ValueKind) - { - case JsonValueKind.String: - schema.Type = "string"; - break; - case JsonValueKind.Number: - schema.Type = "number"; - break; - case JsonValueKind.True: - case JsonValueKind.False: - schema.Type = "boolean"; - break; - case JsonValueKind.Object: - schema.Type = "object"; - schema.Properties = jsonElement.EnumerateObject() - .ToDictionary(p => p.Name, p => GetSchemaFromJsonElement(p.Value)); - break; - case JsonValueKind.Array: - schema.Type = "array"; - schema.Items = GetSchemaFromJsonElement(jsonElement.EnumerateArray().FirstOrDefault()); - break; - default: - schema.Type = "object"; - break; - } - - return schema; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Configuration; +using Microsoft.DevProxy.Abstractions; +using Titanium.Web.Proxy.EventArguments; +using Microsoft.OpenApi.Models; +using Microsoft.OpenApi.Extensions; +using System.Text.Json; +using Microsoft.OpenApi.Interfaces; +using Microsoft.OpenApi.Writers; +using Microsoft.OpenApi; +using Titanium.Web.Proxy.Http; +using System.Web; +using System.Collections.Specialized; +using Microsoft.Extensions.Logging; +using Microsoft.DevProxy.Abstractions.LanguageModel; + +namespace Microsoft.DevProxy.Plugins.RequestLogs; + +public class OpenApiSpecGeneratorPluginReportItem +{ + public required string ServerUrl { get; init; } + public required string FileName { get; init; } +} + +public class OpenApiSpecGeneratorPluginReport : List +{ + public OpenApiSpecGeneratorPluginReport() : base() { } + + public OpenApiSpecGeneratorPluginReport(IEnumerable collection) : base(collection) { } +} + +class GeneratedByOpenApiExtension : IOpenApiExtension +{ + public void Write(IOpenApiWriter writer, OpenApiSpecVersion specVersion) + { + writer.WriteStartObject(); + writer.WriteProperty("toolName", "Dev Proxy"); + writer.WriteProperty("toolVersion", ProxyUtils.ProductVersion); + writer.WriteEndObject(); + } +} + +internal class OpenApiSpecGeneratorPluginConfiguration +{ + public bool IncludeOptionsRequests { get; set; } = false; +} + +public class OpenApiSpecGeneratorPlugin(IPluginEvents pluginEvents, IProxyContext context, ILogger logger, ISet urlsToWatch, IConfigurationSection? configSection = null) : BaseReportingPlugin(pluginEvents, context, logger, urlsToWatch, configSection) +{ + // from: https://github.com/jonluca/har-to-openapi/blob/0d44409162c0a127cdaccd60b0a270ecd361b829/src/utils/headers.ts + private static readonly string[] standardHeaders = + [ + ":authority", + ":method", + ":path", + ":scheme", + ":status", + "a-im", + "accept", + "accept-additions", + "accept-ch", + "accept-ch-lifetime", + "accept-charset", + "accept-datetime", + "accept-encoding", + "accept-features", + "accept-language", + "accept-patch", + "accept-post", + "accept-ranges", + "access-control-allow-credentials", + "access-control-allow-headers", + "access-control-allow-methods", + "access-control-allow-origin", + "access-control-expose-headers", + "access-control-max-age", + "access-control-request-headers", + "access-control-request-method", + "age", + "allow", + "alpn", + "alt-svc", + "alternate-protocol", + "alternates", + "amp-access-control-allow-source-origin", + "apply-to-redirect-ref", + "authentication-info", + "authorization", + "c-ext", + "c-man", + "c-opt", + "c-pep", + "c-pep-info", + "cache-control", + "ch", + "connection", + "content-base", + "content-disposition", + "content-dpr", + "content-encoding", + "content-id", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-script-type", + "content-security-policy", + "content-security-policy-report-only", + "content-style-type", + "content-type", + "content-version", + "cookie", + "cookie2", + "cross-origin-resource-policy", + "dasl", + "date", + "dav", + "default-style", + "delta-base", + "depth", + "derived-from", + "destination", + "differential-id", + "digest", + "dnt", + "dpr", + "encryption", + "encryption-key", + "etag", + "expect", + "expect-ct", + "expires", + "ext", + "forwarded", + "from", + "front-end-https", + "getprofile", + "host", + "http2-settings", + "if", + "if-match", + "if-modified-since", + "if-none-match", + "if-range", + "if-schedule-tag-match", + "if-unmodified-since", + "im", + "keep-alive", + "key", + "label", + "last-event-id", + "last-modified", + "link", + "link-template", + "location", + "lock-token", + "man", + "max-forwards", + "md", + "meter", + "mime-version", + "negotiate", + "nice", + "opt", + "ordering-type", + "origin", + "origin-trial", + "overwrite", + "p3p", + "pep", + "pep-info", + "pics-label", + "poe", + "poe-links", + "position", + "pragma", + "prefer", + "preference-applied", + "profileobject", + "protocol", + "protocol-info", + "protocol-query", + "protocol-request", + "proxy-authenticate", + "proxy-authentication-info", + "proxy-authorization", + "proxy-connection", + "proxy-features", + "proxy-instruction", + "public", + "range", + "redirect-ref", + "referer", + "referrer-policy", + "report-to", + "retry-after", + "rw", + "safe", + "save-data", + "schedule-reply", + "schedule-tag", + "sec-ch-ua", + "sec-ch-ua-mobile", + "sec-ch-ua-platform", + "sec-fetch-dest", + "sec-fetch-mode", + "sec-fetch-site", + "sec-fetch-user", + "sec-websocket-accept", + "sec-websocket-extensions", + "sec-websocket-key", + "sec-websocket-protocol", + "sec-websocket-version", + "security-scheme", + "server", + "server-timing", + "set-cookie", + "set-cookie2", + "setprofile", + "slug", + "soapaction", + "status-uri", + "strict-transport-security", + "sunset", + "surrogate-capability", + "surrogate-control", + "tcn", + "te", + "timeout", + "timing-allow-origin", + "tk", + "trailer", + "transfer-encoding", + "upgrade", + "upgrade-insecure-requests", + "uri", + "user-agent", + "variant-vary", + "vary", + "via", + "want-digest", + "warning", + "www-authenticate", + "x-att-deviceid", + "x-csrf-token", + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-proto", + "x-frame-options", + "x-frontend", + "x-http-method-override", + "x-powered-by", + "x-request-id", + "x-requested-with", + "x-uidh", + "x-wap-profile", + "x-xss-protection" + ]; + private static readonly string[] authHeaders = + [ + "access-token", + "api-key", + "auth-token", + "authorization", + "authorization-token", + "cookie", + "key", + "token", + "x-access-token", + "x-access-token", + "x-api-key", + "x-auth", + "x-auth-token", + "x-csrf-token", + "secret", + "x-secret", + "access-key", + "api-key", + "apikey" + ]; + + public override string Name => nameof(OpenApiSpecGeneratorPlugin); + private readonly OpenApiSpecGeneratorPluginConfiguration _configuration = new(); + public static readonly string GeneratedOpenApiSpecsKey = "GeneratedOpenApiSpecs"; + + public override async Task RegisterAsync() + { + await base.RegisterAsync(); + + ConfigSection?.Bind(_configuration); + + PluginEvents.AfterRecordingStop += AfterRecordingStopAsync; + } + + private async Task AfterRecordingStopAsync(object? sender, RecordingArgs e) + { + Logger.LogInformation("Creating OpenAPI spec from recorded requests..."); + + if (!e.RequestLogs.Any()) + { + Logger.LogDebug("No requests to process"); + return; + } + + var openApiDocs = new List(); + + foreach (var request in e.RequestLogs) + { + if (request.MessageType != MessageType.InterceptedResponse || + request.Context is null || + request.Context.Session is null) + { + continue; + } + + if (!_configuration.IncludeOptionsRequests && + string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + { + Logger.LogDebug("Skipping OPTIONS request {url}...", request.Context.Session.HttpClient.Request.RequestUri); + continue; + } + + var methodAndUrlString = request.Message.First(); + Logger.LogDebug("Processing request {methodAndUrlString}...", methodAndUrlString); + + try + { + var pathItem = GetOpenApiPathItem(request.Context.Session); + var parametrizedPath = ParametrizePath(pathItem, request.Context.Session.HttpClient.Request.RequestUri); + var operationInfo = pathItem.Operations.First(); + operationInfo.Value.OperationId = await GetOperationIdAsync( + operationInfo.Key.ToString(), + request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), + parametrizedPath + ); + operationInfo.Value.Description = await GetOperationDescriptionAsync( + operationInfo.Key.ToString(), + request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), + parametrizedPath + ); + AddOrMergePathItem(openApiDocs, pathItem, request.Context.Session.HttpClient.Request.RequestUri, parametrizedPath); + } + catch (Exception ex) + { + Logger.LogError(ex, "Error processing request {methodAndUrl}", methodAndUrlString); + } + } + + Logger.LogDebug("Serializing OpenAPI docs..."); + var generatedOpenApiSpecs = new Dictionary(); + foreach (var openApiDoc in openApiDocs) + { + var server = openApiDoc.Servers.First(); + var fileName = GetFileNameFromServerUrl(server.Url); + var docString = openApiDoc.SerializeAsJson(OpenApiSpecVersion.OpenApi3_0); + + Logger.LogDebug(" Writing OpenAPI spec to {fileName}...", fileName); + File.WriteAllText(fileName, docString); + + generatedOpenApiSpecs.Add(server.Url, fileName); + + Logger.LogInformation("Created OpenAPI spec file {fileName}", fileName); + } + + StoreReport(new OpenApiSpecGeneratorPluginReport( + generatedOpenApiSpecs + .Select(kvp => new OpenApiSpecGeneratorPluginReportItem + { + ServerUrl = kvp.Key, + FileName = kvp.Value + })), e); + + // store the generated OpenAPI specs in the global data + // for use by other plugins + e.GlobalData[GeneratedOpenApiSpecsKey] = generatedOpenApiSpecs; + } + + /** + * Replaces segments in the request URI, that match predefined patters, + * with parameters and adds them to the OpenAPI PathItem. + * @param pathItem The OpenAPI PathItem to parametrize. + * @param requestUri The request URI. + * @returns The parametrized server-relative URL + */ + private static string ParametrizePath(OpenApiPathItem pathItem, Uri requestUri) + { + var segments = requestUri.Segments; + var previousSegment = "item"; + + for (var i = 0; i < segments.Length; i++) + { + var segment = requestUri.Segments[i].Trim('/'); + if (string.IsNullOrEmpty(segment)) + { + continue; + } + + if (IsParametrizable(segment)) + { + var parameterName = $"{previousSegment}-id"; + segments[i] = $"{{{parameterName}}}{(requestUri.Segments[i].EndsWith('/') ? "/" : "")}"; + + pathItem.Parameters.Add(new OpenApiParameter + { + Name = parameterName, + In = ParameterLocation.Path, + Required = true, + Schema = new OpenApiSchema { Type = "string" } + }); + } + else + { + previousSegment = segment; + } + } + + return string.Join(string.Empty, segments); + } + + private static bool IsParametrizable(string segment) + { + return Guid.TryParse(segment.Trim('/'), out _) || + int.TryParse(segment.Trim('/'), out _); + } + + private static string GetLastNonTokenSegment(string[] segments) + { + for (var i = segments.Length - 1; i >= 0; i--) + { + var segment = segments[i].Trim('/'); + if (string.IsNullOrEmpty(segment)) + { + continue; + } + + if (!IsParametrizable(segment)) + { + return segment; + } + } + + return "item"; + } + + private async Task GetOperationIdAsync(string method, string serverUrl, string parametrizedPath) + { + var prompt = $"For the specified request, generate an operation ID, compatible with an OpenAPI spec. Respond with just the ID in plain-text format. For example, for request such as `GET https://api.contoso.com/books/{{books-id}}` you return `getBookById`. For a request like `GET https://api.contoso.com/books/{{books-id}}/authors` you return `getAuthorsForBookById`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}"; + ILanguageModelCompletionResponse? id = null; + if (await Context.LanguageModelClient.IsEnabledAsync()) { + id = await Context.LanguageModelClient.GenerateCompletionAsync(prompt); + } + return id?.Response ?? $"{method}{parametrizedPath.Replace('/', '.')}"; + } + + private async Task GetOperationDescriptionAsync(string method, string serverUrl, string parametrizedPath) + { + var prompt = $"You're an expert in OpenAPI. You help developers build great OpenAPI specs for use with LLMs. For the specified request, generate a one-sentence description. Respond with just the description. For example, for a request such as `GET https://api.contoso.com/books/{{books-id}}` you return `Get a book by ID`. Request: {method.ToUpper()} {serverUrl}{parametrizedPath}"; + ILanguageModelCompletionResponse? description = null; + if (await Context.LanguageModelClient.IsEnabledAsync()) { + description = await Context.LanguageModelClient.GenerateCompletionAsync(prompt); + } + return description?.Response ?? $"{method} {parametrizedPath}"; + } + + /** + * Creates an OpenAPI PathItem from an intercepted request and response pair. + * @param session The intercepted session. + */ + private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session) + { + var request = session.HttpClient.Request; + var response = session.HttpClient.Response; + + var resource = GetLastNonTokenSegment(request.RequestUri.Segments); + var path = new OpenApiPathItem(); + + var method = request.Method?.ToUpperInvariant() switch + { + "DELETE" => OperationType.Delete, + "GET" => OperationType.Get, + "HEAD" => OperationType.Head, + "OPTIONS" => OperationType.Options, + "PATCH" => OperationType.Patch, + "POST" => OperationType.Post, + "PUT" => OperationType.Put, + "TRACE" => OperationType.Trace, + _ => throw new NotSupportedException($"Method {request.Method} is not supported") + }; + var operation = new OpenApiOperation + { + // will be replaced later after the path has been parametrized + Description = $"{method} {resource}", + // will be replaced later after the path has been parametrized + OperationId = $"{method}.{resource}" + }; + SetParametersFromQueryString(operation, HttpUtility.ParseQueryString(request.RequestUri.Query)); + SetParametersFromRequestHeaders(operation, request.Headers); + SetRequestBody(operation, request); + SetResponseFromSession(operation, response); + + path.Operations.Add(method, operation); + + return path; + } + + private void SetRequestBody(OpenApiOperation operation, Request request) + { + if (!request.HasBody) + { + Logger.LogDebug(" Request has no body"); + return; + } + + if (request.ContentType is null) + { + Logger.LogDebug(" Request has no content type"); + return; + } + + Logger.LogDebug(" Processing request body..."); + operation.RequestBody = new OpenApiRequestBody + { + Content = new Dictionary + { + { + GetMediaType(request.ContentType), + new OpenApiMediaType + { + Schema = GetSchemaFromBody(GetMediaType(request.ContentType), request.BodyString) + } + } + } + }; + } + + private void SetParametersFromRequestHeaders(OpenApiOperation operation, HeaderCollection headers) + { + if (headers is null || + !headers.Any()) + { + Logger.LogDebug(" Request has no headers"); + return; + } + + Logger.LogDebug(" Processing request headers..."); + foreach (var header in headers) + { + var lowerCaseHeaderName = header.Name.ToLowerInvariant(); + if (standardHeaders.Contains(lowerCaseHeaderName)) + { + Logger.LogDebug(" Skipping standard header {headerName}", header.Name); + continue; + } + + if (authHeaders.Contains(lowerCaseHeaderName)) + { + Logger.LogDebug(" Skipping auth header {headerName}", header.Name); + continue; + } + + operation.Parameters.Add(new OpenApiParameter + { + Name = header.Name, + In = ParameterLocation.Header, + Required = false, + Schema = new OpenApiSchema { Type = "string" } + }); + Logger.LogDebug(" Added header {headerName}", header.Name); + } + } + + private void SetParametersFromQueryString(OpenApiOperation operation, NameValueCollection queryParams) + { + if (queryParams.AllKeys is null || + queryParams.AllKeys.Length == 0) + { + Logger.LogDebug(" Request has no query string parameters"); + return; + } + + Logger.LogDebug(" Processing query string parameters..."); + var dictionary = (queryParams.AllKeys as string[]).ToDictionary(k => k, k => queryParams[k] as object); + + foreach (var parameter in dictionary) + { + operation.Parameters.Add(new OpenApiParameter + { + Name = parameter.Key, + In = ParameterLocation.Query, + Required = false, + Schema = new OpenApiSchema { Type = "string" } + }); + Logger.LogDebug(" Added query string parameter {parameterKey}", parameter.Key); + } + } + + private void SetResponseFromSession(OpenApiOperation operation, Response response) + { + if (response is null) + { + Logger.LogDebug(" No response to process"); + return; + } + + Logger.LogDebug(" Processing response..."); + + var openApiResponse = new OpenApiResponse + { + Description = response.StatusDescription + }; + var responseCode = response.StatusCode.ToString(); + if (response.HasBody) + { + Logger.LogDebug(" Response has body"); + + openApiResponse.Content.Add(GetMediaType(response.ContentType), new OpenApiMediaType + { + Schema = GetSchemaFromBody(GetMediaType(response.ContentType), response.BodyString) + }); + } + else + { + Logger.LogDebug(" Response doesn't have body"); + } + + if (response.Headers is not null && response.Headers.Any()) + { + Logger.LogDebug(" Response has headers"); + + foreach (var header in response.Headers) + { + var lowerCaseHeaderName = header.Name.ToLowerInvariant(); + if (standardHeaders.Contains(lowerCaseHeaderName)) + { + Logger.LogDebug(" Skipping standard header {headerName}", header.Name); + continue; + } + + if (authHeaders.Contains(lowerCaseHeaderName)) + { + Logger.LogDebug(" Skipping auth header {headerName}", header.Name); + continue; + } + + if (openApiResponse.Headers.ContainsKey(header.Name)) + { + Logger.LogDebug(" Header {headerName} already exists in response", header.Name); + continue; + } + + openApiResponse.Headers.Add(header.Name, new OpenApiHeader + { + Schema = new OpenApiSchema { Type = "string" } + }); + Logger.LogDebug(" Added header {headerName}", header.Name); + } + } + else + { + Logger.LogDebug(" Response doesn't have headers"); + } + + operation.Responses.Add(responseCode, openApiResponse); + } + + private static string GetMediaType(string? contentType) + { + if (string.IsNullOrEmpty(contentType)) + { + return contentType ?? ""; + } + + var mediaType = contentType.Split(';').First().Trim(); + return mediaType; + } + + private OpenApiSchema? GetSchemaFromBody(string? contentType, string body) + { + if (contentType is null) + { + Logger.LogDebug(" No content type to process"); + return null; + } + + if (contentType.StartsWith("application/json")) + { + Logger.LogDebug(" Processing JSON body..."); + return GetSchemaFromJsonString(body); + } + + return null; + } + + private void AddOrMergePathItem(IList openApiDocs, OpenApiPathItem pathItem, Uri requestUri, string parametrizedPath) + { + var serverUrl = requestUri.GetLeftPart(UriPartial.Authority); + var openApiDoc = openApiDocs.FirstOrDefault(d => d.Servers.Any(s => s.Url == serverUrl)); + + if (openApiDoc is null) + { + Logger.LogDebug(" Creating OpenAPI spec for {serverUrl}...", serverUrl); + + openApiDoc = new OpenApiDocument + { + Info = new OpenApiInfo + { + Version = "v1.0", + Title = $"{serverUrl} API", + Description = $"{serverUrl} API", + }, + Servers = + [ + new OpenApiServer { Url = serverUrl } + ], + Paths = [], + Extensions = new Dictionary + { + { "x-ms-generated-by", new GeneratedByOpenApiExtension() } + } + }; + openApiDocs.Add(openApiDoc); + } + else + { + Logger.LogDebug(" Found OpenAPI spec for {serverUrl}...", serverUrl); + } + + if (!openApiDoc.Paths.TryGetValue(parametrizedPath, out OpenApiPathItem? value)) + { + Logger.LogDebug(" Adding path {parametrizedPath} to OpenAPI spec...", parametrizedPath); + value = pathItem; + openApiDoc.Paths.Add(parametrizedPath, value); + // since we've just added the path, we're done + return; + } + + Logger.LogDebug(" Merging path {parametrizedPath} into OpenAPI spec...", parametrizedPath); + var operation = pathItem.Operations.First(); + AddOrMergeOperation(value, operation.Key, operation.Value); + } + + private void AddOrMergeOperation(OpenApiPathItem pathItem, OperationType operationType, OpenApiOperation apiOperation) + { + if (!pathItem.Operations.TryGetValue(operationType, out OpenApiOperation? value)) + { + Logger.LogDebug(" Adding operation {operationType} to path...", operationType); + + pathItem.AddOperation(operationType, apiOperation); + // since we've just added the operation, we're done + return; + } + + Logger.LogDebug(" Merging operation {operationType} into path...", operationType); + + var operation = value; + + AddOrMergeParameters(operation, apiOperation.Parameters); + AddOrMergeRequestBody(operation, apiOperation.RequestBody); + AddOrMergeResponse(operation, apiOperation.Responses); + } + + private void AddOrMergeParameters(OpenApiOperation operation, IList parameters) + { + if (parameters is null || !parameters.Any()) + { + Logger.LogDebug(" No parameters to process"); + return; + } + + Logger.LogDebug(" Processing parameters for operation..."); + + foreach (var parameter in parameters) + { + var paramFromOperation = operation.Parameters.FirstOrDefault(p => p.Name == parameter.Name && p.In == parameter.In); + if (paramFromOperation is null) + { + Logger.LogDebug(" Adding parameter {parameterName} to operation...", parameter.Name); + operation.Parameters.Add(parameter); + continue; + } + + Logger.LogDebug(" Merging parameter {parameterName}...", parameter.Name); + MergeSchema(parameter?.Schema, paramFromOperation?.Schema); + } + } + + private void MergeSchema(OpenApiSchema? source, OpenApiSchema? target) + { + if (source is null || target is null) + { + Logger.LogDebug(" Source or target is null. Skipping..."); + return; + } + + if (source.Type != "object" || target.Type != "object") + { + Logger.LogDebug(" Source or target schema is not an object. Skipping..."); + return; + } + + if (source.Properties is null || !source.Properties.Any()) + { + Logger.LogDebug(" Source has no properties. Skipping..."); + return; + } + + if (target.Properties is null || !target.Properties.Any()) + { + Logger.LogDebug(" Target has no properties. Skipping..."); + return; + } + + foreach (var property in source.Properties) + { + var propertyFromTarget = target.Properties.FirstOrDefault(p => p.Key == property.Key); + if (propertyFromTarget.Value is null) + { + Logger.LogDebug(" Adding property {propertyKey} to schema...", property.Key); + target.Properties.Add(property); + continue; + } + + if (property.Value.Type != "object") + { + Logger.LogDebug(" Property already found but is not an object. Skipping..."); + continue; + } + + Logger.LogDebug(" Merging property {propertyKey}...", property.Key); + MergeSchema(property.Value, propertyFromTarget.Value); + } + } + + private void AddOrMergeRequestBody(OpenApiOperation operation, OpenApiRequestBody requestBody) + { + if (requestBody is null || !requestBody.Content.Any()) + { + Logger.LogDebug(" No request body to process"); + return; + } + + var requestBodyType = requestBody.Content.FirstOrDefault().Key; + operation.RequestBody.Content.TryGetValue(requestBodyType, out OpenApiMediaType? bodyFromOperation); + + if (bodyFromOperation is null) + { + Logger.LogDebug(" Adding request body to operation..."); + + operation.RequestBody.Content.Add(requestBody.Content.FirstOrDefault()); + // since we've just added the request body, we're done + return; + } + + Logger.LogDebug(" Merging request body into operation..."); + MergeSchema(bodyFromOperation.Schema, requestBody.Content.FirstOrDefault().Value.Schema); + } + + private void AddOrMergeResponse(OpenApiOperation operation, OpenApiResponses apiResponses) + { + if (apiResponses is null) + { + Logger.LogDebug(" No response to process"); + return; + } + + var apiResponseInfo = apiResponses.FirstOrDefault(); + var apiResponseStatusCode = apiResponseInfo.Key; + var apiResponse = apiResponseInfo.Value; + operation.Responses.TryGetValue(apiResponseStatusCode, out OpenApiResponse? responseFromOperation); + + if (responseFromOperation is null) + { + Logger.LogDebug(" Adding response {apiResponseStatusCode} to operation...", apiResponseStatusCode); + + operation.Responses.Add(apiResponseStatusCode, apiResponse); + // since we've just added the response, we're done + return; + } + + if (!apiResponse.Content.Any()) + { + Logger.LogDebug(" No response content to process"); + return; + } + + var apiResponseContentType = apiResponse.Content.First().Key; + responseFromOperation.Content.TryGetValue(apiResponseContentType, out OpenApiMediaType? contentFromOperation); + + if (contentFromOperation is null) + { + Logger.LogDebug(" Adding response {apiResponseContentType} to {apiResponseStatusCode} to response...", apiResponseContentType, apiResponseStatusCode); + + responseFromOperation.Content.Add(apiResponse.Content.First()); + // since we've just added the content, we're done + return; + } + + Logger.LogDebug(" Merging response {apiResponseStatusCode}/{apiResponseContentType} into operation...", apiResponseStatusCode, apiResponseContentType); + MergeSchema(contentFromOperation.Schema, apiResponse.Content.First().Value.Schema); + } + + private static string GetFileNameFromServerUrl(string serverUrl) + { + var uri = new Uri(serverUrl); + var fileName = $"{uri.Host}-{DateTime.Now:yyyyMMddHHmmss}.json"; + return fileName; + } + + private static OpenApiSchema GetSchemaFromJsonString(string jsonString) + { + try + { + using var doc = JsonDocument.Parse(jsonString); + JsonElement root = doc.RootElement; + var schema = GetSchemaFromJsonElement(root); + return schema; + } + catch + { + return new OpenApiSchema + { + Type = "object" + }; + } + } + + private static OpenApiSchema GetSchemaFromJsonElement(JsonElement jsonElement) + { + var schema = new OpenApiSchema(); + + switch (jsonElement.ValueKind) + { + case JsonValueKind.String: + schema.Type = "string"; + break; + case JsonValueKind.Number: + schema.Type = "number"; + break; + case JsonValueKind.True: + case JsonValueKind.False: + schema.Type = "boolean"; + break; + case JsonValueKind.Object: + schema.Type = "object"; + schema.Properties = jsonElement.EnumerateObject() + .ToDictionary(p => p.Name, p => GetSchemaFromJsonElement(p.Value)); + break; + case JsonValueKind.Array: + schema.Type = "array"; + schema.Items = GetSchemaFromJsonElement(jsonElement.EnumerateArray().FirstOrDefault()); + break; + default: + schema.Type = "object"; + break; + } + + return schema; + } +} diff --git a/dev-proxy/Logging/ProxyConsoleFormatter.cs b/dev-proxy/Logging/ProxyConsoleFormatter.cs index 670a1748..26c8eddc 100644 --- a/dev-proxy/Logging/ProxyConsoleFormatter.cs +++ b/dev-proxy/Logging/ProxyConsoleFormatter.cs @@ -1,262 +1,262 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Text; -using Microsoft.DevProxy.Abstractions; -using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Logging.Console; -using Microsoft.Extensions.Options; - -namespace Microsoft.DevProxy.Logging; - -public class ProxyConsoleFormatter : ConsoleFormatter -{ - private const string _boxTopLeft = "\u256d "; - private const string _boxLeft = "\u2502 "; - private const string _boxBottomLeft = "\u2570 "; - // used to align single-line messages - private const string _boxSpacing = " "; - private readonly Dictionary> _requestLogs = []; - private readonly ConsoleFormatterOptions _options; - const string labelSpacing = " "; - // label length + 2 - private readonly static string noLabelSpacing = new(' ', 4 + 2); - - public ProxyConsoleFormatter(IOptions options) : base("devproxy") - { - // needed to properly required rounded corners in the box - Console.OutputEncoding = Encoding.UTF8; - _options = options.Value; - } - - public override void Write(in LogEntry logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) - { - if (logEntry.State is RequestLog requestLog) - { - LogRequest(requestLog, scopeProvider, textWriter); - } - else - { - LogMessage(logEntry, scopeProvider, textWriter); - } - } - - private void LogRequest(RequestLog requestLog, IExternalScopeProvider? scopeProvider, TextWriter textWriter) - { - var messageType = requestLog.MessageType; - - // don't log intercepted response to console - if (messageType == MessageType.InterceptedResponse) - { - return; - } - - var requestId = GetRequestIdScope(scopeProvider); - - if (requestId is not null) - { - if (messageType == MessageType.FinishedProcessingRequest) - { - var lastMessage = _requestLogs[requestId.Value].Last(); - // log all request logs for the request - foreach (var log in _requestLogs[requestId.Value]) - { - WriteLogMessageBoxedWithInvertedLabels(log.MessageLines, log.MessageType, textWriter, log == lastMessage); - } - _requestLogs.Remove(requestId.Value); - } - else - { - // buffer request logs until the request is finished processing - if (!_requestLogs.ContainsKey(requestId.Value)) - { - _requestLogs[requestId.Value] = []; - } - _requestLogs[requestId.Value].Add(requestLog); - } - } - } - - private static int? GetRequestIdScope(IExternalScopeProvider? scopeProvider) - { - int? requestId = null; - - scopeProvider?.ForEachScope((scope, state) => - { - if (scope is Dictionary dictionary) - { - if (dictionary.TryGetValue(nameof(requestId), out var req)) - { - requestId = (int)req; - } - } - }, ""); - - return requestId; - } - - private void LogMessage(in LogEntry logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) - { - // regular messages - var logLevel = logEntry.LogLevel; - var message = logEntry.Formatter(logEntry.State, logEntry.Exception); - - WriteMessageBoxedWithInvertedLabels(message, logLevel, scopeProvider, textWriter); - - if (logEntry.Exception is not null) - { - textWriter.Write($" Exception Details: {logEntry.Exception}"); - } - - textWriter.WriteLine(); - } - - private void WriteMessageBoxedWithInvertedLabels(string? message, LogLevel logLevel, IExternalScopeProvider? scopeProvider, TextWriter textWriter) - { - if (message is null) - { - return; - } - - var label = GetLogLevelString(logLevel); - var (bgColor, fgColor) = GetLogLevelColor(logLevel); - - textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); - textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now:T}] " : "")}"); - - if (_options.IncludeScopes && scopeProvider is not null) - { - scopeProvider.ForEachScope((scope, state) => - { - if (scope is null) - { - return; - } - - if (scope is string scopeString) - { - textWriter.Write(scopeString); - textWriter.Write(": "); - } - else if (scope.GetType().Name == "FormattedLogValues") - { - textWriter.Write(scope.ToString()); - textWriter.Write(": "); - } - }, textWriter); - } - - textWriter.Write(message); - } - - private static void WriteLogMessageBoxedWithInvertedLabels(string[] message, MessageType messageType, TextWriter textWriter, bool lastMessage = false) - { - var label = GetMessageTypeString(messageType); - var (bgColor, fgColor) = GetMessageTypeColor(messageType); - - switch (messageType) - { - case MessageType.InterceptedRequest: - // always one line (method + URL) - // print label and top of the box - textWriter.WriteLine(); - textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); - textWriter.WriteLine($"{(label.Length < 4 ? " " : "")}{labelSpacing}{_boxTopLeft}{message[0]}"); - break; - default: - if (message.Length == 1) - { - textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); - textWriter.WriteLine($"{(label.Length < 4 ? " " : "")}{labelSpacing}{(lastMessage ? _boxBottomLeft : _boxLeft)}{message[0]}"); - } - else - { - for (var i = 0; i < message.Length; i++) - { - if (i == 0) - { - // print label and middle of the box - textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); - textWriter.WriteLine($"{(label.Length < 4 ? " " : "")}{labelSpacing}{_boxLeft}{message[i]}"); - } - else if (i < message.Length - 1) - { - // print middle of the box - textWriter.WriteLine($"{noLabelSpacing}{labelSpacing}{_boxLeft}{message[i]}"); - } - else - { - // print end of the box - textWriter.WriteLine($"{noLabelSpacing}{labelSpacing}{(lastMessage ? _boxBottomLeft : _boxLeft)}{message[i]}"); - } - } - } - break; - } - } - - // from https://github.com/dotnet/runtime/blob/198a2596229f69b8e02902bfb4ffc2a30be3b339/src/libraries/Microsoft.Extensions.Logging.Console/src/SimpleConsoleFormatter.cs#L154 - private static string GetLogLevelString(LogLevel logLevel) - { - return logLevel switch - { - LogLevel.Trace => "trce", - LogLevel.Debug => "dbug", - LogLevel.Information => "info", - LogLevel.Warning => "warn", - LogLevel.Error => "fail", - LogLevel.Critical => "crit", - _ => throw new ArgumentOutOfRangeException(nameof(logLevel)) - }; - } - - private static (ConsoleColor bg, ConsoleColor fg) GetLogLevelColor(LogLevel logLevel) - { - var fgColor = Console.ForegroundColor; - var bgColor = Console.BackgroundColor; - - return logLevel switch - { - LogLevel.Information => (bgColor, ConsoleColor.Blue), - LogLevel.Warning => (ConsoleColor.DarkYellow, fgColor), - LogLevel.Error => (ConsoleColor.DarkRed, fgColor), - LogLevel.Debug => (bgColor, ConsoleColor.Gray), - LogLevel.Trace => (bgColor, ConsoleColor.Gray), - _ => (bgColor, fgColor) - }; - } - - private static string GetMessageTypeString(MessageType messageType) - { - return messageType switch - { - MessageType.InterceptedRequest => "req", - MessageType.InterceptedResponse => "res", - MessageType.PassedThrough => "api", - MessageType.Chaos => "oops", - MessageType.Warning => "warn", - MessageType.Mocked => "mock", - MessageType.Failed => "fail", - MessageType.Tip => "tip", - _ => " " - }; - } - - private static (ConsoleColor bg, ConsoleColor fg) GetMessageTypeColor(MessageType messageType) - { - var fgColor = Console.ForegroundColor; - var bgColor = Console.BackgroundColor; - - return messageType switch - { - MessageType.InterceptedRequest => (bgColor, ConsoleColor.Gray), - MessageType.PassedThrough => (ConsoleColor.Gray, fgColor), - MessageType.Chaos => (ConsoleColor.DarkRed, fgColor), - MessageType.Warning => (ConsoleColor.DarkYellow, fgColor), - MessageType.Mocked => (ConsoleColor.DarkMagenta, fgColor), - MessageType.Failed => (ConsoleColor.DarkRed, fgColor), - MessageType.Tip => (ConsoleColor.DarkBlue, fgColor), - _ => (bgColor, fgColor) - }; - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text; +using Microsoft.DevProxy.Abstractions; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging.Console; +using Microsoft.Extensions.Options; + +namespace Microsoft.DevProxy.Logging; + +public class ProxyConsoleFormatter : ConsoleFormatter +{ + private const string _boxTopLeft = "\u256d "; + private const string _boxLeft = "\u2502 "; + private const string _boxBottomLeft = "\u2570 "; + // used to align single-line messages + private const string _boxSpacing = " "; + private readonly Dictionary> _requestLogs = []; + private readonly ProxyConsoleFormatterOptions _options; + const string labelSpacing = " "; + // label length + 2 + private readonly static string noLabelSpacing = new(' ', 4 + 2); + public static readonly string DefaultCategoryName = "devproxy"; + + public ProxyConsoleFormatter(IOptions options) : base(DefaultCategoryName) + { + // needed to properly required rounded corners in the box + Console.OutputEncoding = Encoding.UTF8; + _options = options.Value; + } + + public override void Write(in LogEntry logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) + { + if (logEntry.State is RequestLog requestLog) + { + LogRequest(requestLog, logEntry.Category, scopeProvider, textWriter); + } + else + { + LogMessage(logEntry, scopeProvider, textWriter); + } + } + + private void LogRequest(RequestLog requestLog, string category, IExternalScopeProvider? scopeProvider, TextWriter textWriter) + { + var messageType = requestLog.MessageType; + + // don't log intercepted response to console + if (messageType == MessageType.InterceptedResponse || + (messageType == MessageType.Skipped && !_options.ShowSkipMessages)) + { + return; + } + + var requestId = GetRequestIdScope(scopeProvider); + + if (requestId is not null) + { + if (messageType == MessageType.FinishedProcessingRequest) + { + var lastMessage = _requestLogs[requestId.Value].Last(); + // log all request logs for the request + foreach (var log in _requestLogs[requestId.Value]) + { + WriteLogMessageBoxedWithInvertedLabels(log, scopeProvider, textWriter, log == lastMessage); + } + _requestLogs.Remove(requestId.Value); + } + else + { + // buffer request logs until the request is finished processing + if (!_requestLogs.TryGetValue(requestId.Value, out List? value)) + { + value = ([]); + _requestLogs[requestId.Value] = value; + } + + requestLog.PluginName = category == DefaultCategoryName ? null : category; + value.Add(requestLog); + } + } + } + + private static int? GetRequestIdScope(IExternalScopeProvider? scopeProvider) + { + int? requestId = null; + + scopeProvider?.ForEachScope((scope, state) => + { + if (scope is Dictionary dictionary) + { + if (dictionary.TryGetValue(nameof(requestId), out var req)) + { + requestId = (int)req; + } + } + }, ""); + + return requestId; + } + + private void LogMessage(in LogEntry logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) + { + // regular messages + var logLevel = logEntry.LogLevel; + var message = logEntry.Formatter(logEntry.State, logEntry.Exception); + var category = logEntry.Category == DefaultCategoryName ? null : logEntry.Category; + + WriteMessageBoxedWithInvertedLabels(message, logLevel, category, scopeProvider, textWriter); + + if (logEntry.Exception is not null) + { + textWriter.Write($" Exception Details: {logEntry.Exception}"); + } + + textWriter.WriteLine(); + } + + private void WriteMessageBoxedWithInvertedLabels(string? message, LogLevel logLevel, string? category, IExternalScopeProvider? scopeProvider, TextWriter textWriter) + { + if (message is null) + { + return; + } + + var label = GetLogLevelString(logLevel); + var (bgColor, fgColor) = GetLogLevelColor(logLevel); + + textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); + textWriter.Write($"{labelSpacing}{_boxSpacing}{(logLevel == LogLevel.Debug ? $"[{DateTime.Now:T}] " : "")}"); + + if (_options.IncludeScopes && scopeProvider is not null) + { + scopeProvider.ForEachScope((scope, state) => + { + if (scope is null) + { + return; + } + + if (scope is string scopeString) + { + textWriter.Write(scopeString); + textWriter.Write(": "); + } + else if (scope.GetType().Name == "FormattedLogValues") + { + textWriter.Write(scope.ToString()); + textWriter.Write(": "); + } + }, textWriter); + } + + if (!string.IsNullOrEmpty(category)) + { + textWriter.Write($"{category}: "); + } + + textWriter.Write(message); + } + + private void WriteLogMessageBoxedWithInvertedLabels(RequestLog log, IExternalScopeProvider? scopeProvider, TextWriter textWriter, bool lastMessage = false) + { + var label = GetMessageTypeString(log.MessageType); + var (bgColor, fgColor) = GetMessageTypeColor(log.MessageType); + + void writePluginName() + { + if (_options.IncludeScopes && log.PluginName is not null) + { + textWriter.Write($"{log.PluginName}: "); + } + } + + switch (log.MessageType) + { + case MessageType.InterceptedRequest: + // always one line (method + URL) + // print label and top of the box + textWriter.WriteLine(); + textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); + textWriter.Write($"{(label.Length < 4 ? " " : "")}{labelSpacing}{_boxTopLeft}"); + writePluginName(); + textWriter.WriteLine(log.Message); + break; + default: + textWriter.WriteColoredMessage($" {label} ", bgColor, fgColor); + textWriter.Write($"{(label.Length < 4 ? " " : "")}{labelSpacing}{(lastMessage ? _boxBottomLeft : _boxLeft)}"); + writePluginName(); + textWriter.WriteLine(log.Message); + break; + } + } + + // from https://github.com/dotnet/runtime/blob/198a2596229f69b8e02902bfb4ffc2a30be3b339/src/libraries/Microsoft.Extensions.Logging.Console/src/SimpleConsoleFormatter.cs#L154 + private static string GetLogLevelString(LogLevel logLevel) + { + return logLevel switch + { + LogLevel.Trace => "trce", + LogLevel.Debug => "dbug", + LogLevel.Information => "info", + LogLevel.Warning => "warn", + LogLevel.Error => "fail", + LogLevel.Critical => "crit", + _ => throw new ArgumentOutOfRangeException(nameof(logLevel)) + }; + } + + private static (ConsoleColor bg, ConsoleColor fg) GetLogLevelColor(LogLevel logLevel) + { + var fgColor = Console.ForegroundColor; + var bgColor = Console.BackgroundColor; + + return logLevel switch + { + LogLevel.Information => (bgColor, ConsoleColor.Blue), + LogLevel.Warning => (ConsoleColor.DarkYellow, fgColor), + LogLevel.Error => (ConsoleColor.DarkRed, fgColor), + LogLevel.Debug => (bgColor, ConsoleColor.Gray), + LogLevel.Trace => (bgColor, ConsoleColor.Gray), + _ => (bgColor, fgColor) + }; + } + + private static string GetMessageTypeString(MessageType messageType) + { + return messageType switch + { + MessageType.InterceptedRequest => "req", + MessageType.InterceptedResponse => "res", + MessageType.PassedThrough => "api", + MessageType.Chaos => "oops", + MessageType.Warning => "warn", + MessageType.Mocked => "mock", + MessageType.Failed => "fail", + MessageType.Tip => "tip", + MessageType.Skipped => "skip", + _ => " " + }; + } + + private static (ConsoleColor bg, ConsoleColor fg) GetMessageTypeColor(MessageType messageType) + { + var fgColor = Console.ForegroundColor; + var bgColor = Console.BackgroundColor; + + return messageType switch + { + MessageType.InterceptedRequest => (bgColor, ConsoleColor.Gray), + MessageType.PassedThrough => (ConsoleColor.Gray, fgColor), + MessageType.Skipped => (bgColor, ConsoleColor.Gray), + MessageType.Chaos => (ConsoleColor.DarkRed, fgColor), + MessageType.Warning => (ConsoleColor.DarkYellow, fgColor), + MessageType.Mocked => (ConsoleColor.DarkMagenta, fgColor), + MessageType.Failed => (ConsoleColor.DarkRed, fgColor), + MessageType.Tip => (ConsoleColor.DarkBlue, fgColor), + _ => (bgColor, fgColor) + }; + } } \ No newline at end of file diff --git a/dev-proxy/Logging/ProxyConsoleFormatterOptions.cs b/dev-proxy/Logging/ProxyConsoleFormatterOptions.cs new file mode 100644 index 00000000..066c83a5 --- /dev/null +++ b/dev-proxy/Logging/ProxyConsoleFormatterOptions.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging.Console; + +namespace Microsoft.DevProxy.Logging; + +public class ProxyConsoleFormatterOptions: ConsoleFormatterOptions +{ + public bool ShowSkipMessages { get; set; } = true; +} \ No newline at end of file diff --git a/dev-proxy/PluginLoader.cs b/dev-proxy/PluginLoader.cs index 82a737dd..4d25dd9d 100644 --- a/dev-proxy/PluginLoader.cs +++ b/dev-proxy/PluginLoader.cs @@ -1,163 +1,165 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License - -using Microsoft.DevProxy.Abstractions; -using System.Reflection; -using System.Text.RegularExpressions; - -namespace Microsoft.DevProxy; - -internal class PluginLoaderResult(ISet urlsToWatch, IEnumerable proxyPlugins) -{ - public ISet UrlsToWatch { get; } = urlsToWatch ?? throw new ArgumentNullException(nameof(urlsToWatch)); - public IEnumerable ProxyPlugins { get; } = proxyPlugins ?? throw new ArgumentNullException(nameof(proxyPlugins)); -} - -internal class PluginLoader(ILogger logger) -{ - private PluginConfig? _pluginConfig; - private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - - public async Task LoadPluginsAsync(IPluginEvents pluginEvents, IProxyContext proxyContext) - { - List plugins = []; - var config = PluginConfig; - var globallyWatchedUrls = PluginConfig.UrlsToWatch.Select(ConvertToRegex).ToList(); - var defaultUrlsToWatch = globallyWatchedUrls.ToHashSet(); - var configFileDirectory = Path.GetDirectoryName(Path.GetFullPath(ProxyUtils.ReplacePathTokens(ProxyHost.ConfigFile))); - // key = location - var pluginContexts = new Dictionary(); - - if (!string.IsNullOrEmpty(configFileDirectory)) - { - foreach (PluginReference h in config.Plugins) - { - if (!h.Enabled) continue; - // Load Handler Assembly if enabled - var pluginLocation = Path.GetFullPath(Path.Combine(configFileDirectory, ProxyUtils.ReplacePathTokens(h.PluginPath.Replace('\\', Path.DirectorySeparatorChar)))); - - if (!pluginContexts.TryGetValue(pluginLocation, out PluginLoadContext? pluginLoadContext)) - { - pluginLoadContext = new PluginLoadContext(pluginLocation); - pluginContexts.Add(pluginLocation, pluginLoadContext); - } - - _logger?.LogDebug("Loading plugin {pluginName} from: {pluginLocation}", h.Name, pluginLocation); - var assembly = pluginLoadContext.LoadFromAssemblyName(new AssemblyName(Path.GetFileNameWithoutExtension(pluginLocation))); - var pluginUrlsList = h.UrlsToWatch?.Select(ConvertToRegex); - ISet? pluginUrls = null; - - if (pluginUrlsList is not null) - { - pluginUrls = pluginUrlsList.ToHashSet(); - globallyWatchedUrls.AddRange(pluginUrlsList); - } - - var plugin = CreatePlugin( - assembly, - h, - pluginEvents, - proxyContext, - (pluginUrls != null && pluginUrls.Any()) ? pluginUrls : defaultUrlsToWatch, - h.ConfigSection is null ? null : Configuration.GetSection(h.ConfigSection) - ); - _logger?.LogDebug("Registering plugin {pluginName}...", plugin.Name); - await plugin.RegisterAsync(); - _logger?.LogDebug("Plugin {pluginName} registered.", plugin.Name); - plugins.Add(plugin); - } - } - - return plugins.Count > 0 - ? new PluginLoaderResult(globallyWatchedUrls.ToHashSet(), plugins) - : throw new InvalidDataException("No plugins were loaded"); - } - - private IProxyPlugin CreatePlugin( - Assembly assembly, - PluginReference pluginReference, - IPluginEvents pluginEvents, - IProxyContext context, - ISet urlsToWatch, - IConfigurationSection? configSection = null - ) - { - foreach (Type type in assembly.GetTypes()) - { - if (type.Name == pluginReference.Name && - typeof(IProxyPlugin).IsAssignableFrom(type)) - { - IProxyPlugin? result = Activator.CreateInstance(type, [pluginEvents, context, _logger, urlsToWatch, configSection]) as IProxyPlugin; - if (result is not null && result.Name == pluginReference.Name) - { - return result; - } - } - } - - string availableTypes = string.Join(",", assembly.GetTypes().Select(t => t.FullName)); - throw new ApplicationException( - $"Can't find plugin {pluginReference.Name} which implements IProxyPlugin in {assembly} from {AppContext.BaseDirectory}.\r\n" + - $"Available types: {availableTypes}"); - } - - public static UrlToWatch ConvertToRegex(string stringMatcher) - { - var exclude = false; - if (stringMatcher.StartsWith('!')) - { - exclude = true; - stringMatcher = stringMatcher[1..]; - } - - return new UrlToWatch( - new Regex($"^{Regex.Escape(stringMatcher).Replace("\\*", ".*")}$", RegexOptions.Compiled | RegexOptions.IgnoreCase), - exclude - ); - } - - private PluginConfig PluginConfig - { - get - { - if (_pluginConfig == null) - { - _pluginConfig = new PluginConfig(); - Configuration.Bind(_pluginConfig); - - if (ProxyHost.UrlsToWatch is not null && ProxyHost.UrlsToWatch.Any()) - { - _pluginConfig.UrlsToWatch = ProxyHost.UrlsToWatch.ToList(); - } - } - if (_pluginConfig == null || _pluginConfig.Plugins.Count == 0) - { - throw new InvalidDataException("The configuration must contain at least one plugin"); - } - return _pluginConfig; - } - } - - private IConfigurationRoot Configuration { get => ConfigurationFactory.Value; } - - private readonly Lazy ConfigurationFactory = new(() => - new ConfigurationBuilder() - .AddJsonFile(ProxyHost.ConfigFile, optional: true, reloadOnChange: true) - .Build() - ); -} - -internal class PluginConfig -{ - public List Plugins { get; set; } = []; - public List UrlsToWatch { get; set; } = []; -} - -internal class PluginReference -{ - public bool Enabled { get; set; } = true; - public string? ConfigSection { get; set; } - public string PluginPath { get; set; } = string.Empty; - public string Name { get; set; } = string.Empty; - public List? UrlsToWatch { get; set; } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License + +using Microsoft.DevProxy.Abstractions; +using System.Reflection; +using System.Text.RegularExpressions; + +namespace Microsoft.DevProxy; + +internal class PluginLoaderResult(ISet urlsToWatch, IEnumerable proxyPlugins) +{ + public ISet UrlsToWatch { get; } = urlsToWatch ?? throw new ArgumentNullException(nameof(urlsToWatch)); + public IEnumerable ProxyPlugins { get; } = proxyPlugins ?? throw new ArgumentNullException(nameof(proxyPlugins)); +} + +internal class PluginLoader(ILogger logger, ILoggerFactory loggerFactory) +{ + private PluginConfig? _pluginConfig; + private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + private readonly ILoggerFactory _loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); + + public async Task LoadPluginsAsync(IPluginEvents pluginEvents, IProxyContext proxyContext) + { + List plugins = []; + var config = PluginConfig; + var globallyWatchedUrls = PluginConfig.UrlsToWatch.Select(ConvertToRegex).ToList(); + var defaultUrlsToWatch = globallyWatchedUrls.ToHashSet(); + var configFileDirectory = Path.GetDirectoryName(Path.GetFullPath(ProxyUtils.ReplacePathTokens(ProxyHost.ConfigFile))); + // key = location + var pluginContexts = new Dictionary(); + + if (!string.IsNullOrEmpty(configFileDirectory)) + { + foreach (PluginReference h in config.Plugins) + { + if (!h.Enabled) continue; + // Load Handler Assembly if enabled + var pluginLocation = Path.GetFullPath(Path.Combine(configFileDirectory, ProxyUtils.ReplacePathTokens(h.PluginPath.Replace('\\', Path.DirectorySeparatorChar)))); + + if (!pluginContexts.TryGetValue(pluginLocation, out PluginLoadContext? pluginLoadContext)) + { + pluginLoadContext = new PluginLoadContext(pluginLocation); + pluginContexts.Add(pluginLocation, pluginLoadContext); + } + + _logger?.LogDebug("Loading plugin {pluginName} from: {pluginLocation}", h.Name, pluginLocation); + var assembly = pluginLoadContext.LoadFromAssemblyName(new AssemblyName(Path.GetFileNameWithoutExtension(pluginLocation))); + var pluginUrlsList = h.UrlsToWatch?.Select(ConvertToRegex); + ISet? pluginUrls = null; + + if (pluginUrlsList is not null) + { + pluginUrls = pluginUrlsList.ToHashSet(); + globallyWatchedUrls.AddRange(pluginUrlsList); + } + + var plugin = CreatePlugin( + assembly, + h, + pluginEvents, + proxyContext, + (pluginUrls != null && pluginUrls.Any()) ? pluginUrls : defaultUrlsToWatch, + h.ConfigSection is null ? null : Configuration.GetSection(h.ConfigSection) + ); + _logger?.LogDebug("Registering plugin {pluginName}...", plugin.Name); + await plugin.RegisterAsync(); + _logger?.LogDebug("Plugin {pluginName} registered.", plugin.Name); + plugins.Add(plugin); + } + } + + return plugins.Count > 0 + ? new PluginLoaderResult(globallyWatchedUrls.ToHashSet(), plugins) + : throw new InvalidDataException("No plugins were loaded"); + } + + private IProxyPlugin CreatePlugin( + Assembly assembly, + PluginReference pluginReference, + IPluginEvents pluginEvents, + IProxyContext context, + ISet urlsToWatch, + IConfigurationSection? configSection = null + ) + { + foreach (Type type in assembly.GetTypes()) + { + if (type.Name == pluginReference.Name && + typeof(IProxyPlugin).IsAssignableFrom(type)) + { + var logger = _loggerFactory.CreateLogger(type.Name); + IProxyPlugin? result = Activator.CreateInstance(type, [pluginEvents, context, logger, urlsToWatch, configSection]) as IProxyPlugin; + if (result is not null && result.Name == pluginReference.Name) + { + return result; + } + } + } + + string availableTypes = string.Join(",", assembly.GetTypes().Select(t => t.FullName)); + throw new ApplicationException( + $"Can't find plugin {pluginReference.Name} which implements IProxyPlugin in {assembly} from {AppContext.BaseDirectory}.\r\n" + + $"Available types: {availableTypes}"); + } + + public static UrlToWatch ConvertToRegex(string stringMatcher) + { + var exclude = false; + if (stringMatcher.StartsWith('!')) + { + exclude = true; + stringMatcher = stringMatcher[1..]; + } + + return new UrlToWatch( + new Regex($"^{Regex.Escape(stringMatcher).Replace("\\*", ".*")}$", RegexOptions.Compiled | RegexOptions.IgnoreCase), + exclude + ); + } + + private PluginConfig PluginConfig + { + get + { + if (_pluginConfig == null) + { + _pluginConfig = new PluginConfig(); + Configuration.Bind(_pluginConfig); + + if (ProxyHost.UrlsToWatch is not null && ProxyHost.UrlsToWatch.Any()) + { + _pluginConfig.UrlsToWatch = ProxyHost.UrlsToWatch.ToList(); + } + } + if (_pluginConfig == null || _pluginConfig.Plugins.Count == 0) + { + throw new InvalidDataException("The configuration must contain at least one plugin"); + } + return _pluginConfig; + } + } + + private IConfigurationRoot Configuration { get => ConfigurationFactory.Value; } + + private readonly Lazy ConfigurationFactory = new(() => + new ConfigurationBuilder() + .AddJsonFile(ProxyHost.ConfigFile, optional: true, reloadOnChange: true) + .Build() + ); +} + +internal class PluginConfig +{ + public List Plugins { get; set; } = []; + public List UrlsToWatch { get; set; } = []; +} + +internal class PluginReference +{ + public bool Enabled { get; set; } = true; + public string? ConfigSection { get; set; } + public string PluginPath { get; set; } = string.Empty; + public string Name { get; set; } = string.Empty; + public List? UrlsToWatch { get; set; } +} diff --git a/dev-proxy/Program.cs b/dev-proxy/Program.cs index 9717e02a..c9107943 100644 --- a/dev-proxy/Program.cs +++ b/dev-proxy/Program.cs @@ -1,118 +1,119 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.DevProxy; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Abstractions.LanguageModel; -using Microsoft.DevProxy.CommandHandlers; -using Microsoft.DevProxy.Logging; -using Microsoft.Extensions.Logging.Console; -using System.CommandLine; - -_ = Announcement.ShowAsync(); - -PluginEvents pluginEvents = new(); - -ILogger BuildLogger() -{ - var loggerFactory = LoggerFactory.Create(builder => - { - builder - .AddConsole(options => - { - options.FormatterName = "devproxy"; - options.LogToStandardErrorThreshold = LogLevel.Warning; - }) - .AddConsoleFormatter(options => { - options.IncludeScopes = true; - }) - .AddRequestLogger(pluginEvents) - .SetMinimumLevel(ProxyHost.LogLevel ?? ProxyCommandHandler.Configuration.LogLevel); - }); - return loggerFactory.CreateLogger("devproxy"); -} - -var logger = BuildLogger(); - -var lmClient = new OllamaLanguageModelClient(ProxyCommandHandler.Configuration.LanguageModel, logger); -IProxyContext context = new ProxyContext(ProxyCommandHandler.Configuration, ProxyEngine.Certificate, lmClient); -ProxyHost proxyHost = new(); - -// this is where the root command is created which contains all commands and subcommands -RootCommand rootCommand = proxyHost.GetRootCommand(logger); - -// store the global options that are created automatically for us -// rootCommand doesn't return the global options, so we have to store them manually -string[] globalOptions = ["--version", "--help", "-h", "/h", "-?", "/?"]; - -// check if any of the global options are present -var hasGlobalOption = args.Any(arg => globalOptions.Contains(arg)); - -// get the list of available subcommands -var subCommands = rootCommand.Children.OfType().Select(c => c.Name).ToArray(); - -// check if any of the subcommands are present -var hasSubCommand = args.Any(arg => subCommands.Contains(arg)); - -if (hasGlobalOption || hasSubCommand) -{ - // we don't need to load plugins if the user is using a global option or using a subcommand, so we can exit early - await rootCommand.InvokeAsync(args); - return; -} - -var pluginLoader = new PluginLoader(logger); -PluginLoaderResult loaderResults = await pluginLoader.LoadPluginsAsync(pluginEvents, context); -// have all the plugins init -pluginEvents.RaiseInit(new InitArgs()); - -var options = loaderResults.ProxyPlugins - .SelectMany(p => p.GetOptions()) - // remove duplicates by comparing the option names - .GroupBy(o => o.Name) - .Select(g => g.First()) - .ToList(); -options.ForEach(rootCommand.AddOption); -// register all plugin commands -loaderResults.ProxyPlugins - .SelectMany(p => p.GetCommands()) - .ToList() - .ForEach(rootCommand.AddCommand); - -rootCommand.Handler = proxyHost.GetCommandHandler(pluginEvents, [.. options], loaderResults.UrlsToWatch, logger); - -// filter args to retrieve options -var incomingOptions = args.Where(arg => arg.StartsWith('-')).ToArray(); - -// remove the global options from the incoming options -incomingOptions = incomingOptions.Except(globalOptions).ToArray(); - -// compare the incoming options against the root command options -foreach (var option in rootCommand.Options) -{ - // get the option aliases - var aliases = option.Aliases.ToArray(); - - // iterate over aliases - foreach (string alias in aliases) - { - // if the alias is present - if (incomingOptions.Contains(alias)) - { - // remove the option from the incoming options - incomingOptions = incomingOptions.Where(val => val != alias).ToArray(); - } - } -} - -// list the remaining incoming options as unknown in the output -if (incomingOptions.Length > 0) -{ - logger.LogError("Unknown option(s): {unknownOptions}", string.Join(" ", incomingOptions)); - logger.LogInformation("TIP: Use --help view available options"); - logger.LogInformation("TIP: Are you missing a plugin? See: https://aka.ms/devproxy/plugins"); -} -else -{ - await rootCommand.InvokeAsync(args); -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DevProxy; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Abstractions.LanguageModel; +using Microsoft.DevProxy.CommandHandlers; +using Microsoft.DevProxy.Logging; +using Microsoft.Extensions.Logging.Console; +using System.CommandLine; + +_ = Announcement.ShowAsync(); + +PluginEvents pluginEvents = new(); + +(ILogger, ILoggerFactory) BuildLogger() +{ + var loggerFactory = LoggerFactory.Create(builder => + { + builder + .AddConsole(options => + { + options.FormatterName = ProxyConsoleFormatter.DefaultCategoryName; + options.LogToStandardErrorThreshold = LogLevel.Warning; + }) + .AddConsoleFormatter(options => { + options.IncludeScopes = true; + options.ShowSkipMessages = ProxyCommandHandler.Configuration.ShowSkipMessages; + }) + .AddRequestLogger(pluginEvents) + .SetMinimumLevel(ProxyHost.LogLevel ?? ProxyCommandHandler.Configuration.LogLevel); + }); + return (loggerFactory.CreateLogger(ProxyConsoleFormatter.DefaultCategoryName), loggerFactory); +} + +var (logger, loggerFactory) = BuildLogger(); + +var lmClient = new OllamaLanguageModelClient(ProxyCommandHandler.Configuration.LanguageModel, logger); +IProxyContext context = new ProxyContext(ProxyCommandHandler.Configuration, ProxyEngine.Certificate, lmClient); +ProxyHost proxyHost = new(); + +// this is where the root command is created which contains all commands and subcommands +RootCommand rootCommand = proxyHost.GetRootCommand(logger); + +// store the global options that are created automatically for us +// rootCommand doesn't return the global options, so we have to store them manually +string[] globalOptions = ["--version", "--help", "-h", "/h", "-?", "/?"]; + +// check if any of the global options are present +var hasGlobalOption = args.Any(arg => globalOptions.Contains(arg)); + +// get the list of available subcommands +var subCommands = rootCommand.Children.OfType().Select(c => c.Name).ToArray(); + +// check if any of the subcommands are present +var hasSubCommand = args.Any(arg => subCommands.Contains(arg)); + +if (hasGlobalOption || hasSubCommand) +{ + // we don't need to load plugins if the user is using a global option or using a subcommand, so we can exit early + await rootCommand.InvokeAsync(args); + return; +} + +var pluginLoader = new PluginLoader(logger, loggerFactory); +PluginLoaderResult loaderResults = await pluginLoader.LoadPluginsAsync(pluginEvents, context); +// have all the plugins init +pluginEvents.RaiseInit(new InitArgs()); + +var options = loaderResults.ProxyPlugins + .SelectMany(p => p.GetOptions()) + // remove duplicates by comparing the option names + .GroupBy(o => o.Name) + .Select(g => g.First()) + .ToList(); +options.ForEach(rootCommand.AddOption); +// register all plugin commands +loaderResults.ProxyPlugins + .SelectMany(p => p.GetCommands()) + .ToList() + .ForEach(rootCommand.AddCommand); + +rootCommand.Handler = proxyHost.GetCommandHandler(pluginEvents, [.. options], loaderResults.UrlsToWatch, logger); + +// filter args to retrieve options +var incomingOptions = args.Where(arg => arg.StartsWith('-')).ToArray(); + +// remove the global options from the incoming options +incomingOptions = incomingOptions.Except(globalOptions).ToArray(); + +// compare the incoming options against the root command options +foreach (var option in rootCommand.Options) +{ + // get the option aliases + var aliases = option.Aliases.ToArray(); + + // iterate over aliases + foreach (string alias in aliases) + { + // if the alias is present + if (incomingOptions.Contains(alias)) + { + // remove the option from the incoming options + incomingOptions = incomingOptions.Where(val => val != alias).ToArray(); + } + } +} + +// list the remaining incoming options as unknown in the output +if (incomingOptions.Length > 0) +{ + logger.LogError("Unknown option(s): {unknownOptions}", string.Join(" ", incomingOptions)); + logger.LogInformation("TIP: Use --help view available options"); + logger.LogInformation("TIP: Are you missing a plugin? See: https://aka.ms/devproxy/plugins"); +} +else +{ + await rootCommand.InvokeAsync(args); +} diff --git a/dev-proxy/ProxyConfiguration.cs b/dev-proxy/ProxyConfiguration.cs index 213e25f4..f72c2c64 100755 --- a/dev-proxy/ProxyConfiguration.cs +++ b/dev-proxy/ProxyConfiguration.cs @@ -1,40 +1,41 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Runtime.Serialization; -using System.Text.Json.Serialization; -using Microsoft.DevProxy.Abstractions; -using Microsoft.DevProxy.Abstractions.LanguageModel; - -namespace Microsoft.DevProxy; - -public enum ReleaseType -{ - [EnumMember(Value = "none")] - None, - [EnumMember(Value = "stable")] - Stable, - [EnumMember(Value = "beta")] - Beta -} - -public class ProxyConfiguration : IProxyConfiguration -{ - public int Port { get; set; } = 8000; - public string? IPAddress { get; set; } = "127.0.0.1"; - public bool Record { get; set; } = false; - [JsonConverter(typeof(JsonStringEnumConverter))] - public LogLevel LogLevel { get; set; } = LogLevel.Information; - public IEnumerable WatchPids { get; set; } = new List(); - public IEnumerable WatchProcessNames { get; set; } = []; - public int Rate { get; set; } = 50; - public bool NoFirstRun { get; set; } = false; - public bool AsSystemProxy { get; set; } = true; - public bool InstallCert { get; set; } = true; - public string ConfigFile { get; set; } = "devproxyrc.json"; - [JsonConverter(typeof(JsonStringEnumConverter))] - public ReleaseType NewVersionNotification { get; set; } = ReleaseType.Stable; - public LanguageModelConfiguration? LanguageModel { get; set; } - public MockRequestHeader[]? FilterByHeaders { get; set; } - public int ApiPort { get; set; } = 8897; -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using System.Text.Json.Serialization; +using Microsoft.DevProxy.Abstractions; +using Microsoft.DevProxy.Abstractions.LanguageModel; + +namespace Microsoft.DevProxy; + +public enum ReleaseType +{ + [EnumMember(Value = "none")] + None, + [EnumMember(Value = "stable")] + Stable, + [EnumMember(Value = "beta")] + Beta +} + +public class ProxyConfiguration : IProxyConfiguration +{ + public int Port { get; set; } = 8000; + public string? IPAddress { get; set; } = "127.0.0.1"; + public bool Record { get; set; } = false; + [JsonConverter(typeof(JsonStringEnumConverter))] + public LogLevel LogLevel { get; set; } = LogLevel.Information; + public IEnumerable WatchPids { get; set; } = new List(); + public IEnumerable WatchProcessNames { get; set; } = []; + public int Rate { get; set; } = 50; + public bool NoFirstRun { get; set; } = false; + public bool AsSystemProxy { get; set; } = true; + public bool InstallCert { get; set; } = true; + public string ConfigFile { get; set; } = "devproxyrc.json"; + [JsonConverter(typeof(JsonStringEnumConverter))] + public ReleaseType NewVersionNotification { get; set; } = ReleaseType.Stable; + public LanguageModelConfiguration? LanguageModel { get; set; } + public MockRequestHeader[]? FilterByHeaders { get; set; } + public int ApiPort { get; set; } = 8897; + public bool ShowSkipMessages { get; set; } = true; +} diff --git a/dev-proxy/ProxyEngine.cs b/dev-proxy/ProxyEngine.cs index 11daa7da..2703a90c 100755 --- a/dev-proxy/ProxyEngine.cs +++ b/dev-proxy/ProxyEngine.cs @@ -1,616 +1,616 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using Microsoft.DevProxy.Abstractions; -using Microsoft.VisualStudio.Threading; -using System.Diagnostics; -using System.Net; -using System.Security.Cryptography.X509Certificates; -using System.Text.RegularExpressions; -using Titanium.Web.Proxy; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Helpers; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; - -namespace Microsoft.DevProxy; - -enum ToggleSystemProxyAction -{ - On, - Off -} - -public class ProxyEngine(IProxyConfiguration config, ISet urlsToWatch, IPluginEvents pluginEvents, IProxyState proxyState, ILogger logger) : BackgroundService -{ - private readonly IPluginEvents _pluginEvents = pluginEvents ?? throw new ArgumentNullException(nameof(pluginEvents)); - private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - private readonly IProxyConfiguration _config = config ?? throw new ArgumentNullException(nameof(config)); - private static readonly ProxyServer? _proxyServer; - private ExplicitProxyEndPoint? _explicitEndPoint; - // lists of URLs to watch, used for intercepting requests - private readonly ISet _urlsToWatch = urlsToWatch ?? throw new ArgumentNullException(nameof(urlsToWatch)); - // lists of hosts to watch extracted from urlsToWatch, - // used for deciding which URLs to decrypt for further inspection - private readonly ISet _hostsToWatch = new HashSet(); - private readonly IProxyState _proxyState = proxyState ?? throw new ArgumentNullException(nameof(proxyState)); - // Dictionary for plugins to store data between requests - // the key is HashObject of the SessionEventArgs object - private readonly Dictionary> _pluginData = []; - - public static X509Certificate2? Certificate => _proxyServer?.CertificateManager.RootCertificate; - - private ExceptionHandler ExceptionHandler => ex => _logger.LogError(ex, "An error occurred in a plugin"); - - static ProxyEngine() - { - _proxyServer = new ProxyServer(); - _proxyServer.CertificateManager.RootCertificateName = "Dev Proxy CA"; - _proxyServer.CertificateManager.CertificateStorage = new CertificateDiskCache(); - // we need to change this to a value lower than 397 - // to avoid the ERR_CERT_VALIDITY_TOO_LONG error in Edge - _proxyServer.CertificateManager.CertificateValidDays = 365; - - var joinableTaskContext = new JoinableTaskContext(); - var joinableTaskFactory = new JoinableTaskFactory(joinableTaskContext); - _ = joinableTaskFactory.Run(async () => await _proxyServer.CertificateManager.LoadOrCreateRootCertificateAsync()); - } - - private static void ToggleSystemProxy(ToggleSystemProxyAction toggle, string? ipAddress = null, int? port = null) - { - var bashScriptPath = Path.Join(ProxyUtils.AppFolder, "toggle-proxy.sh"); - var args = toggle switch - { - ToggleSystemProxyAction.On => $"on {ipAddress} {port}", - ToggleSystemProxyAction.Off => "off", - _ => throw new NotImplementedException() - }; - - ProcessStartInfo startInfo = new ProcessStartInfo() - { - FileName = "/bin/bash", - Arguments = $"{bashScriptPath} {args}", - RedirectStandardOutput = true, - UseShellExecute = false, - CreateNoWindow = true - }; - - var process = new Process() { StartInfo = startInfo }; - process.Start(); - process.WaitForExit(); - } - - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - Debug.Assert(_proxyServer is not null, "Proxy server is not initialized"); - - if (!_urlsToWatch.Any()) - { - _logger.LogInformation("No URLs to watch configured. Please add URLs to watch in the devproxyrc.json config file."); - return; - } - - LoadHostNamesFromUrls(); - - _proxyServer.BeforeRequest += OnRequestAsync; - _proxyServer.BeforeResponse += OnBeforeResponseAsync; - _proxyServer.AfterResponse += OnAfterResponseAsync; - _proxyServer.ServerCertificateValidationCallback += OnCertificateValidationAsync; - _proxyServer.ClientCertificateSelectionCallback += OnCertificateSelectionAsync; - - var ipAddress = string.IsNullOrEmpty(_config.IPAddress) ? IPAddress.Any : IPAddress.Parse(_config.IPAddress); - _explicitEndPoint = new ExplicitProxyEndPoint(ipAddress, _config.Port, true); - // Fired when a CONNECT request is received - _explicitEndPoint.BeforeTunnelConnectRequest += OnBeforeTunnelConnectRequestAsync; - if (_config.InstallCert) - { - await _proxyServer.CertificateManager.EnsureRootCertificateAsync(stoppingToken); - } - else - { - _explicitEndPoint.GenericCertificate = await _proxyServer - .CertificateManager - .LoadRootCertificateAsync(stoppingToken); - } - - _proxyServer.AddEndPoint(_explicitEndPoint); - await _proxyServer.StartAsync(cancellationToken: stoppingToken); - - // run first-run setup on macOS - FirstRunSetup(); - - foreach (var endPoint in _proxyServer.ProxyEndPoints) - { - _logger.LogInformation("Dev Proxy listening on {ipAddress}:{port}...", endPoint.IpAddress, endPoint.Port); - } - - if (_config.AsSystemProxy) - { - if (RunTime.IsWindows) - { - _proxyServer.SetAsSystemHttpProxy(_explicitEndPoint); - _proxyServer.SetAsSystemHttpsProxy(_explicitEndPoint); - } - else if (RunTime.IsMac) - { - ToggleSystemProxy(ToggleSystemProxyAction.On, _config.IPAddress, _config.Port); - } - else - { - _logger.LogWarning("Configure your operating system to use this proxy's port and address {ipAddress}:{port}", _config.IPAddress, _config.Port); - } - } - else - { - _logger.LogInformation("Configure your application to use this proxy's port and address"); - } - - var isInteractive = !Console.IsInputRedirected && - Environment.GetEnvironmentVariable("CI") is null; - - if (isInteractive) - { - // only print hotkeys when they can be used - PrintHotkeys(); - } - - if (_config.Record) - { - StartRecording(); - } - _pluginEvents.AfterRequestLog += AfterRequestLogAsync; - - while (!stoppingToken.IsCancellationRequested && _proxyServer.ProxyRunning) - { - while (!Console.KeyAvailable) - { - await Task.Delay(10, stoppingToken); - } - // we need this check or proxy will fail with an exception - // when run for example in VSCode's integrated terminal - if (isInteractive) - { - await ReadKeysAsync(); - } - } - } - - private void FirstRunSetup() - { - if (!RunTime.IsMac || - _config.NoFirstRun || - !IsFirstRun() || - !_config.InstallCert) - { - return; - } - - var bashScriptPath = Path.Join(ProxyUtils.AppFolder, "trust-cert.sh"); - ProcessStartInfo startInfo = new() - { - FileName = "/bin/bash", - Arguments = bashScriptPath, - UseShellExecute = true, - CreateNoWindow = false - }; - - var process = new Process() { StartInfo = startInfo }; - process.Start(); - process.WaitForExit(); - } - - private static bool IsFirstRun() - { - var firstRunFilePath = Path.Combine(ProxyUtils.AppFolder!, ".hasrun"); - if (File.Exists(firstRunFilePath)) - { - return false; - } - - try - { - File.WriteAllText(firstRunFilePath, ""); - } - catch { } - - return true; - } - - private Task AfterRequestLogAsync(object? sender, RequestLogArgs e) - { - if (!_proxyState.IsRecording) - { - return Task.CompletedTask; - } - - _proxyState.RequestLogs.Add(e.RequestLog); - return Task.CompletedTask; - } - - private async Task ReadKeysAsync() - { - var key = Console.ReadKey(true).Key; - switch (key) - { - case ConsoleKey.R: - StartRecording(); - break; - case ConsoleKey.S: - await StopRecordingAsync(); - break; - case ConsoleKey.C: - Console.Clear(); - PrintHotkeys(); - break; - case ConsoleKey.W: - await _proxyState.RaiseMockRequestAsync(); - break; - } - } - - private void StartRecording() - { - if (_proxyState.IsRecording) - { - return; - } - - _proxyState.StartRecording(); - } - - private async Task StopRecordingAsync() - { - if (!_proxyState.IsRecording) - { - return; - } - - await _proxyState.StopRecordingAsync(); - } - - // Convert strings from config to regexes. - // From the list of URLs, extract host names and convert them to regexes. - // We need this because before we decrypt a request, we only have access - // to the host name, not the full URL. - private void LoadHostNamesFromUrls() - { - foreach (var urlToWatch in _urlsToWatch) - { - // extract host from the URL - string urlToWatchPattern = Regex.Unescape(urlToWatch.Url.ToString()).Replace(".*", "*"); - string hostToWatch; - if (urlToWatchPattern.ToString().Contains("://")) - { - // if the URL contains a protocol, extract the host from the URL - var urlChunks = urlToWatchPattern.Split("://"); - var slashPos = urlChunks[1].IndexOf('/'); - hostToWatch = slashPos < 0 ? urlChunks[1] : urlChunks[1][..slashPos]; - } - else - { - // if the URL doesn't contain a protocol, - // we assume the whole URL is a host name - hostToWatch = urlToWatchPattern; - } - - // remove port number if present - var portPos = hostToWatch.IndexOf(':'); - if (portPos > 0) - { - hostToWatch = hostToWatch[..portPos]; - } - - var hostToWatchRegexString = Regex.Escape(hostToWatch).Replace("\\*", ".*"); - Regex hostRegex = new($"^{hostToWatchRegexString}$", RegexOptions.Compiled | RegexOptions.IgnoreCase); - // don't add the same host twice - if (!_hostsToWatch.Any(h => h.Url.ToString() == hostRegex.ToString())) - { - _hostsToWatch.Add(new UrlToWatch(hostRegex)); - } - } - } - - private void StopProxy() - { - // Unsubscribe & Quit - try - { - if (_explicitEndPoint != null) - { - _explicitEndPoint.BeforeTunnelConnectRequest -= OnBeforeTunnelConnectRequestAsync; - } - - if (_proxyServer is not null) - { - _proxyServer.BeforeRequest -= OnRequestAsync; - _proxyServer.BeforeResponse -= OnBeforeResponseAsync; - _proxyServer.AfterResponse -= OnAfterResponseAsync; - _proxyServer.ServerCertificateValidationCallback -= OnCertificateValidationAsync; - _proxyServer.ClientCertificateSelectionCallback -= OnCertificateSelectionAsync; - - _proxyServer.Stop(); - } - - if (RunTime.IsMac && _config.AsSystemProxy) - { - ToggleSystemProxy(ToggleSystemProxyAction.Off); - } - } - catch (Exception ex) - { - _logger.LogError(ex, "An error occurred while stopping the proxy"); - } - } - - public override async Task StopAsync(CancellationToken cancellationToken) - { - await StopRecordingAsync(); - StopProxy(); - - await base.StopAsync(cancellationToken); - } - - async Task OnBeforeTunnelConnectRequestAsync(object sender, TunnelConnectSessionEventArgs e) - { - // Ensures that only the targeted Https domains are proxyied - if (!IsProxiedHost(e.HttpClient.Request.RequestUri.Host) || - !IsProxiedProcess(e)) - { - e.DecryptSsl = false; - } - await Task.CompletedTask; - } - - private static int GetProcessId(TunnelConnectSessionEventArgs e) - { - if (RunTime.IsWindows) - { - return e.HttpClient.ProcessId.Value; - } - - var psi = new ProcessStartInfo - { - FileName = "lsof", - Arguments = $"-i :{e.ClientRemoteEndPoint?.Port}", - UseShellExecute = false, - RedirectStandardOutput = true, - CreateNoWindow = true - }; - var proc = new Process - { - StartInfo = psi - }; - proc.Start(); - var output = proc.StandardOutput.ReadToEnd(); - proc.WaitForExit(); - - var lines = output.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries); - var matchingLine = lines.FirstOrDefault(l => l.Contains($"{e.ClientRemoteEndPoint?.Port}->")); - if (matchingLine is null) - { - return -1; - } - var pidString = Regex.Matches(matchingLine, @"^.*?\s+(\d+)")?.FirstOrDefault()?.Groups[1]?.Value; - if (pidString is null) - { - return -1; - } - - if (int.TryParse(pidString, out var pid)) - { - return pid; - } - else - { - return -1; - } - } - - private bool IsProxiedProcess(TunnelConnectSessionEventArgs e) - { - // If no process names or IDs are specified, we proxy all processes - if (!_config.WatchPids.Any() && - !_config.WatchProcessNames.Any()) - { - return true; - } - - var processId = GetProcessId(e); - if (processId == -1) - { - return false; - } - - if (_config.WatchPids.Any() && - _config.WatchPids.Contains(processId)) - { - return true; - } - - if (_config.WatchProcessNames.Any()) - { - var processName = Process.GetProcessById(processId).ProcessName; - if (_config.WatchProcessNames.Contains(processName)) - { - return true; - } - } - - return false; - } - - async Task OnRequestAsync(object sender, SessionEventArgs e) - { - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host) && - IsIncludedByHeaders(e.HttpClient.Request.Headers)) - { - _pluginData.Add(e.GetHashCode(), []); - var responseState = new ResponseState(); - var proxyRequestArgs = new ProxyRequestArgs(e, responseState) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyState.GlobalData - }; - if (!proxyRequestArgs.HasRequestUrlMatch(_urlsToWatch)) - { - return; - } - - // we need to keep the request body for further processing - // by plugins - e.HttpClient.Request.KeepBody = true; - if (e.HttpClient.Request.HasBody) - { - await e.GetRequestBodyAsString(); - } - - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); - - e.UserData = e.HttpClient.Request; - _logger.LogRequest([$"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}"], MessageType.InterceptedRequest, new LoggingContext(e)); - await HandleRequestAsync(e, proxyRequestArgs); - } - } - - private async Task HandleRequestAsync(SessionEventArgs e, ProxyRequestArgs proxyRequestArgs) - { - await _pluginEvents.RaiseProxyBeforeRequestAsync(proxyRequestArgs, ExceptionHandler); - - // We only need to set the proxy header if the proxy has not set a response and the request is going to be sent to the target. - if (!proxyRequestArgs.ResponseState.HasBeenSet) - { - _logger?.LogRequest(["Passed through"], MessageType.PassedThrough, new LoggingContext(e)); - AddProxyHeader(e.HttpClient.Request); - } - } - - private static void AddProxyHeader(Request r) => r.Headers?.AddHeader("Via", $"{r.HttpVersion} dev-proxy/{ProxyUtils.ProductVersion}"); - - private bool IsProxiedHost(string hostName) => _hostsToWatch.Any(h => h.Url.IsMatch(hostName)); - - private bool IsIncludedByHeaders(HeaderCollection requestHeaders) - { - if (_config.FilterByHeaders is null) - { - return true; - } - - foreach (var header in _config.FilterByHeaders) - { - _logger.LogDebug("Checking header {header} with value {value}...", - header.Name, - string.IsNullOrEmpty(header.Value) ? "(any)" : header.Value - ); - - if (requestHeaders.HeaderExists(header.Name)) - { - if (string.IsNullOrEmpty(header.Value)) - { - _logger.LogDebug("Request has header {header}", header.Name); - return true; - } - - if (requestHeaders.GetHeaders(header.Name)!.Any(h => h.Value.Contains(header.Value))) - { - _logger.LogDebug("Request header {header} contains value {value}", header.Name, header.Value); - return true; - } - } - else - { - _logger.LogDebug("Request doesn't have header {header}", header.Name); - } - } - - _logger.LogDebug("Request doesn't match any header filter. Ignoring"); - return false; - } - - // Modify response - async Task OnBeforeResponseAsync(object sender, SessionEventArgs e) - { - // read response headers - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) - { - var proxyResponseArgs = new ProxyResponseArgs(e, new ResponseState()) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyState.GlobalData - }; - if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) - { - return; - } - - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); - - // necessary to make the response body available to plugins - e.HttpClient.Response.KeepBody = true; - if (e.HttpClient.Response.HasBody) - { - await e.GetResponseBody(); - } - - await _pluginEvents.RaiseProxyBeforeResponseAsync(proxyResponseArgs, ExceptionHandler); - } - } - async Task OnAfterResponseAsync(object sender, SessionEventArgs e) - { - // read response headers - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) - { - var proxyResponseArgs = new ProxyResponseArgs(e, new ResponseState()) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyState.GlobalData - }; - if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) - { - // clean up - _pluginData.Remove(e.GetHashCode()); - return; - } - - // necessary to repeat to make the response body - // of mocked requests available to plugins - e.HttpClient.Response.KeepBody = true; - - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); - - var message = $"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}"; - _logger.LogRequest([message], MessageType.InterceptedResponse, new LoggingContext(e)); - await _pluginEvents.RaiseProxyAfterResponseAsync(proxyResponseArgs, ExceptionHandler); - _logger.LogRequest([message], MessageType.FinishedProcessingRequest, new LoggingContext(e)); - - // clean up - _pluginData.Remove(e.GetHashCode()); - } - } - - // Allows overriding default certificate validation logic - Task OnCertificateValidationAsync(object sender, CertificateValidationEventArgs e) - { - // set IsValid to true/false based on Certificate Errors - if (e.SslPolicyErrors == System.Net.Security.SslPolicyErrors.None) - { - e.IsValid = true; - } - - return Task.CompletedTask; - } - - // Allows overriding default client certificate selection logic during mutual authentication - Task OnCertificateSelectionAsync(object sender, CertificateSelectionEventArgs e) - { - // set e.clientCertificate to override - return Task.CompletedTask; - } - - private static void PrintHotkeys() - { - Console.WriteLine(""); - Console.WriteLine("Hotkeys: issue (w)eb request, (r)ecord, (s)top recording, (c)lear screen"); - Console.WriteLine("Press CTRL+C to stop Dev Proxy"); - Console.WriteLine(""); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DevProxy.Abstractions; +using Microsoft.VisualStudio.Threading; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography.X509Certificates; +using System.Text.RegularExpressions; +using Titanium.Web.Proxy; +using Titanium.Web.Proxy.EventArguments; +using Titanium.Web.Proxy.Helpers; +using Titanium.Web.Proxy.Http; +using Titanium.Web.Proxy.Models; + +namespace Microsoft.DevProxy; + +enum ToggleSystemProxyAction +{ + On, + Off +} + +public class ProxyEngine(IProxyConfiguration config, ISet urlsToWatch, IPluginEvents pluginEvents, IProxyState proxyState, ILogger logger) : BackgroundService +{ + private readonly IPluginEvents _pluginEvents = pluginEvents ?? throw new ArgumentNullException(nameof(pluginEvents)); + private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + private readonly IProxyConfiguration _config = config ?? throw new ArgumentNullException(nameof(config)); + private static readonly ProxyServer? _proxyServer; + private ExplicitProxyEndPoint? _explicitEndPoint; + // lists of URLs to watch, used for intercepting requests + private readonly ISet _urlsToWatch = urlsToWatch ?? throw new ArgumentNullException(nameof(urlsToWatch)); + // lists of hosts to watch extracted from urlsToWatch, + // used for deciding which URLs to decrypt for further inspection + private readonly ISet _hostsToWatch = new HashSet(); + private readonly IProxyState _proxyState = proxyState ?? throw new ArgumentNullException(nameof(proxyState)); + // Dictionary for plugins to store data between requests + // the key is HashObject of the SessionEventArgs object + private readonly Dictionary> _pluginData = []; + + public static X509Certificate2? Certificate => _proxyServer?.CertificateManager.RootCertificate; + + private ExceptionHandler ExceptionHandler => ex => _logger.LogError(ex, "An error occurred in a plugin"); + + static ProxyEngine() + { + _proxyServer = new ProxyServer(); + _proxyServer.CertificateManager.RootCertificateName = "Dev Proxy CA"; + _proxyServer.CertificateManager.CertificateStorage = new CertificateDiskCache(); + // we need to change this to a value lower than 397 + // to avoid the ERR_CERT_VALIDITY_TOO_LONG error in Edge + _proxyServer.CertificateManager.CertificateValidDays = 365; + + var joinableTaskContext = new JoinableTaskContext(); + var joinableTaskFactory = new JoinableTaskFactory(joinableTaskContext); + _ = joinableTaskFactory.Run(async () => await _proxyServer.CertificateManager.LoadOrCreateRootCertificateAsync()); + } + + private static void ToggleSystemProxy(ToggleSystemProxyAction toggle, string? ipAddress = null, int? port = null) + { + var bashScriptPath = Path.Join(ProxyUtils.AppFolder, "toggle-proxy.sh"); + var args = toggle switch + { + ToggleSystemProxyAction.On => $"on {ipAddress} {port}", + ToggleSystemProxyAction.Off => "off", + _ => throw new NotImplementedException() + }; + + ProcessStartInfo startInfo = new ProcessStartInfo() + { + FileName = "/bin/bash", + Arguments = $"{bashScriptPath} {args}", + RedirectStandardOutput = true, + UseShellExecute = false, + CreateNoWindow = true + }; + + var process = new Process() { StartInfo = startInfo }; + process.Start(); + process.WaitForExit(); + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + Debug.Assert(_proxyServer is not null, "Proxy server is not initialized"); + + if (!_urlsToWatch.Any()) + { + _logger.LogInformation("No URLs to watch configured. Please add URLs to watch in the devproxyrc.json config file."); + return; + } + + LoadHostNamesFromUrls(); + + _proxyServer.BeforeRequest += OnRequestAsync; + _proxyServer.BeforeResponse += OnBeforeResponseAsync; + _proxyServer.AfterResponse += OnAfterResponseAsync; + _proxyServer.ServerCertificateValidationCallback += OnCertificateValidationAsync; + _proxyServer.ClientCertificateSelectionCallback += OnCertificateSelectionAsync; + + var ipAddress = string.IsNullOrEmpty(_config.IPAddress) ? IPAddress.Any : IPAddress.Parse(_config.IPAddress); + _explicitEndPoint = new ExplicitProxyEndPoint(ipAddress, _config.Port, true); + // Fired when a CONNECT request is received + _explicitEndPoint.BeforeTunnelConnectRequest += OnBeforeTunnelConnectRequestAsync; + if (_config.InstallCert) + { + await _proxyServer.CertificateManager.EnsureRootCertificateAsync(stoppingToken); + } + else + { + _explicitEndPoint.GenericCertificate = await _proxyServer + .CertificateManager + .LoadRootCertificateAsync(stoppingToken); + } + + _proxyServer.AddEndPoint(_explicitEndPoint); + await _proxyServer.StartAsync(cancellationToken: stoppingToken); + + // run first-run setup on macOS + FirstRunSetup(); + + foreach (var endPoint in _proxyServer.ProxyEndPoints) + { + _logger.LogInformation("Dev Proxy listening on {ipAddress}:{port}...", endPoint.IpAddress, endPoint.Port); + } + + if (_config.AsSystemProxy) + { + if (RunTime.IsWindows) + { + _proxyServer.SetAsSystemHttpProxy(_explicitEndPoint); + _proxyServer.SetAsSystemHttpsProxy(_explicitEndPoint); + } + else if (RunTime.IsMac) + { + ToggleSystemProxy(ToggleSystemProxyAction.On, _config.IPAddress, _config.Port); + } + else + { + _logger.LogWarning("Configure your operating system to use this proxy's port and address {ipAddress}:{port}", _config.IPAddress, _config.Port); + } + } + else + { + _logger.LogInformation("Configure your application to use this proxy's port and address"); + } + + var isInteractive = !Console.IsInputRedirected && + Environment.GetEnvironmentVariable("CI") is null; + + if (isInteractive) + { + // only print hotkeys when they can be used + PrintHotkeys(); + } + + if (_config.Record) + { + StartRecording(); + } + _pluginEvents.AfterRequestLog += AfterRequestLogAsync; + + while (!stoppingToken.IsCancellationRequested && _proxyServer.ProxyRunning) + { + while (!Console.KeyAvailable) + { + await Task.Delay(10, stoppingToken); + } + // we need this check or proxy will fail with an exception + // when run for example in VSCode's integrated terminal + if (isInteractive) + { + await ReadKeysAsync(); + } + } + } + + private void FirstRunSetup() + { + if (!RunTime.IsMac || + _config.NoFirstRun || + !IsFirstRun() || + !_config.InstallCert) + { + return; + } + + var bashScriptPath = Path.Join(ProxyUtils.AppFolder, "trust-cert.sh"); + ProcessStartInfo startInfo = new() + { + FileName = "/bin/bash", + Arguments = bashScriptPath, + UseShellExecute = true, + CreateNoWindow = false + }; + + var process = new Process() { StartInfo = startInfo }; + process.Start(); + process.WaitForExit(); + } + + private static bool IsFirstRun() + { + var firstRunFilePath = Path.Combine(ProxyUtils.AppFolder!, ".hasrun"); + if (File.Exists(firstRunFilePath)) + { + return false; + } + + try + { + File.WriteAllText(firstRunFilePath, ""); + } + catch { } + + return true; + } + + private Task AfterRequestLogAsync(object? sender, RequestLogArgs e) + { + if (!_proxyState.IsRecording) + { + return Task.CompletedTask; + } + + _proxyState.RequestLogs.Add(e.RequestLog); + return Task.CompletedTask; + } + + private async Task ReadKeysAsync() + { + var key = Console.ReadKey(true).Key; + switch (key) + { + case ConsoleKey.R: + StartRecording(); + break; + case ConsoleKey.S: + await StopRecordingAsync(); + break; + case ConsoleKey.C: + Console.Clear(); + PrintHotkeys(); + break; + case ConsoleKey.W: + await _proxyState.RaiseMockRequestAsync(); + break; + } + } + + private void StartRecording() + { + if (_proxyState.IsRecording) + { + return; + } + + _proxyState.StartRecording(); + } + + private async Task StopRecordingAsync() + { + if (!_proxyState.IsRecording) + { + return; + } + + await _proxyState.StopRecordingAsync(); + } + + // Convert strings from config to regexes. + // From the list of URLs, extract host names and convert them to regexes. + // We need this because before we decrypt a request, we only have access + // to the host name, not the full URL. + private void LoadHostNamesFromUrls() + { + foreach (var urlToWatch in _urlsToWatch) + { + // extract host from the URL + string urlToWatchPattern = Regex.Unescape(urlToWatch.Url.ToString()).Replace(".*", "*"); + string hostToWatch; + if (urlToWatchPattern.ToString().Contains("://")) + { + // if the URL contains a protocol, extract the host from the URL + var urlChunks = urlToWatchPattern.Split("://"); + var slashPos = urlChunks[1].IndexOf('/'); + hostToWatch = slashPos < 0 ? urlChunks[1] : urlChunks[1][..slashPos]; + } + else + { + // if the URL doesn't contain a protocol, + // we assume the whole URL is a host name + hostToWatch = urlToWatchPattern; + } + + // remove port number if present + var portPos = hostToWatch.IndexOf(':'); + if (portPos > 0) + { + hostToWatch = hostToWatch[..portPos]; + } + + var hostToWatchRegexString = Regex.Escape(hostToWatch).Replace("\\*", ".*"); + Regex hostRegex = new($"^{hostToWatchRegexString}$", RegexOptions.Compiled | RegexOptions.IgnoreCase); + // don't add the same host twice + if (!_hostsToWatch.Any(h => h.Url.ToString() == hostRegex.ToString())) + { + _hostsToWatch.Add(new UrlToWatch(hostRegex)); + } + } + } + + private void StopProxy() + { + // Unsubscribe & Quit + try + { + if (_explicitEndPoint != null) + { + _explicitEndPoint.BeforeTunnelConnectRequest -= OnBeforeTunnelConnectRequestAsync; + } + + if (_proxyServer is not null) + { + _proxyServer.BeforeRequest -= OnRequestAsync; + _proxyServer.BeforeResponse -= OnBeforeResponseAsync; + _proxyServer.AfterResponse -= OnAfterResponseAsync; + _proxyServer.ServerCertificateValidationCallback -= OnCertificateValidationAsync; + _proxyServer.ClientCertificateSelectionCallback -= OnCertificateSelectionAsync; + + _proxyServer.Stop(); + } + + if (RunTime.IsMac && _config.AsSystemProxy) + { + ToggleSystemProxy(ToggleSystemProxyAction.Off); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "An error occurred while stopping the proxy"); + } + } + + public override async Task StopAsync(CancellationToken cancellationToken) + { + await StopRecordingAsync(); + StopProxy(); + + await base.StopAsync(cancellationToken); + } + + async Task OnBeforeTunnelConnectRequestAsync(object sender, TunnelConnectSessionEventArgs e) + { + // Ensures that only the targeted Https domains are proxyied + if (!IsProxiedHost(e.HttpClient.Request.RequestUri.Host) || + !IsProxiedProcess(e)) + { + e.DecryptSsl = false; + } + await Task.CompletedTask; + } + + private static int GetProcessId(TunnelConnectSessionEventArgs e) + { + if (RunTime.IsWindows) + { + return e.HttpClient.ProcessId.Value; + } + + var psi = new ProcessStartInfo + { + FileName = "lsof", + Arguments = $"-i :{e.ClientRemoteEndPoint?.Port}", + UseShellExecute = false, + RedirectStandardOutput = true, + CreateNoWindow = true + }; + var proc = new Process + { + StartInfo = psi + }; + proc.Start(); + var output = proc.StandardOutput.ReadToEnd(); + proc.WaitForExit(); + + var lines = output.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries); + var matchingLine = lines.FirstOrDefault(l => l.Contains($"{e.ClientRemoteEndPoint?.Port}->")); + if (matchingLine is null) + { + return -1; + } + var pidString = Regex.Matches(matchingLine, @"^.*?\s+(\d+)")?.FirstOrDefault()?.Groups[1]?.Value; + if (pidString is null) + { + return -1; + } + + if (int.TryParse(pidString, out var pid)) + { + return pid; + } + else + { + return -1; + } + } + + private bool IsProxiedProcess(TunnelConnectSessionEventArgs e) + { + // If no process names or IDs are specified, we proxy all processes + if (!_config.WatchPids.Any() && + !_config.WatchProcessNames.Any()) + { + return true; + } + + var processId = GetProcessId(e); + if (processId == -1) + { + return false; + } + + if (_config.WatchPids.Any() && + _config.WatchPids.Contains(processId)) + { + return true; + } + + if (_config.WatchProcessNames.Any()) + { + var processName = Process.GetProcessById(processId).ProcessName; + if (_config.WatchProcessNames.Contains(processName)) + { + return true; + } + } + + return false; + } + + async Task OnRequestAsync(object sender, SessionEventArgs e) + { + if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host) && + IsIncludedByHeaders(e.HttpClient.Request.Headers)) + { + _pluginData.Add(e.GetHashCode(), []); + var responseState = new ResponseState(); + var proxyRequestArgs = new ProxyRequestArgs(e, responseState) + { + SessionData = _pluginData[e.GetHashCode()], + GlobalData = _proxyState.GlobalData + }; + if (!proxyRequestArgs.HasRequestUrlMatch(_urlsToWatch)) + { + return; + } + + // we need to keep the request body for further processing + // by plugins + e.HttpClient.Request.KeepBody = true; + if (e.HttpClient.Request.HasBody) + { + await e.GetRequestBodyAsString(); + } + + using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + + e.UserData = e.HttpClient.Request; + _logger.LogRequest($"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}", MessageType.InterceptedRequest, new LoggingContext(e)); + await HandleRequestAsync(e, proxyRequestArgs); + } + } + + private async Task HandleRequestAsync(SessionEventArgs e, ProxyRequestArgs proxyRequestArgs) + { + await _pluginEvents.RaiseProxyBeforeRequestAsync(proxyRequestArgs, ExceptionHandler); + + // We only need to set the proxy header if the proxy has not set a response and the request is going to be sent to the target. + if (!proxyRequestArgs.ResponseState.HasBeenSet) + { + _logger?.LogRequest("Passed through", MessageType.PassedThrough, new LoggingContext(e)); + AddProxyHeader(e.HttpClient.Request); + } + } + + private static void AddProxyHeader(Request r) => r.Headers?.AddHeader("Via", $"{r.HttpVersion} dev-proxy/{ProxyUtils.ProductVersion}"); + + private bool IsProxiedHost(string hostName) => _hostsToWatch.Any(h => h.Url.IsMatch(hostName)); + + private bool IsIncludedByHeaders(HeaderCollection requestHeaders) + { + if (_config.FilterByHeaders is null) + { + return true; + } + + foreach (var header in _config.FilterByHeaders) + { + _logger.LogDebug("Checking header {header} with value {value}...", + header.Name, + string.IsNullOrEmpty(header.Value) ? "(any)" : header.Value + ); + + if (requestHeaders.HeaderExists(header.Name)) + { + if (string.IsNullOrEmpty(header.Value)) + { + _logger.LogDebug("Request has header {header}", header.Name); + return true; + } + + if (requestHeaders.GetHeaders(header.Name)!.Any(h => h.Value.Contains(header.Value))) + { + _logger.LogDebug("Request header {header} contains value {value}", header.Name, header.Value); + return true; + } + } + else + { + _logger.LogDebug("Request doesn't have header {header}", header.Name); + } + } + + _logger.LogDebug("Request doesn't match any header filter. Ignoring"); + return false; + } + + // Modify response + async Task OnBeforeResponseAsync(object sender, SessionEventArgs e) + { + // read response headers + if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) + { + var proxyResponseArgs = new ProxyResponseArgs(e, new ResponseState()) + { + SessionData = _pluginData[e.GetHashCode()], + GlobalData = _proxyState.GlobalData + }; + if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) + { + return; + } + + using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + + // necessary to make the response body available to plugins + e.HttpClient.Response.KeepBody = true; + if (e.HttpClient.Response.HasBody) + { + await e.GetResponseBody(); + } + + await _pluginEvents.RaiseProxyBeforeResponseAsync(proxyResponseArgs, ExceptionHandler); + } + } + async Task OnAfterResponseAsync(object sender, SessionEventArgs e) + { + // read response headers + if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) + { + var proxyResponseArgs = new ProxyResponseArgs(e, new ResponseState()) + { + SessionData = _pluginData[e.GetHashCode()], + GlobalData = _proxyState.GlobalData + }; + if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) + { + // clean up + _pluginData.Remove(e.GetHashCode()); + return; + } + + // necessary to repeat to make the response body + // of mocked requests available to plugins + e.HttpClient.Response.KeepBody = true; + + using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + + var message = $"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}"; + _logger.LogRequest(message, MessageType.InterceptedResponse, new LoggingContext(e)); + await _pluginEvents.RaiseProxyAfterResponseAsync(proxyResponseArgs, ExceptionHandler); + _logger.LogRequest(message, MessageType.FinishedProcessingRequest, new LoggingContext(e)); + + // clean up + _pluginData.Remove(e.GetHashCode()); + } + } + + // Allows overriding default certificate validation logic + Task OnCertificateValidationAsync(object sender, CertificateValidationEventArgs e) + { + // set IsValid to true/false based on Certificate Errors + if (e.SslPolicyErrors == System.Net.Security.SslPolicyErrors.None) + { + e.IsValid = true; + } + + return Task.CompletedTask; + } + + // Allows overriding default client certificate selection logic during mutual authentication + Task OnCertificateSelectionAsync(object sender, CertificateSelectionEventArgs e) + { + // set e.clientCertificate to override + return Task.CompletedTask; + } + + private static void PrintHotkeys() + { + Console.WriteLine(""); + Console.WriteLine("Hotkeys: issue (w)eb request, (r)ecord, (s)top recording, (c)lear screen"); + Console.WriteLine("Press CTRL+C to stop Dev Proxy"); + Console.WriteLine(""); + } +} diff --git a/dev-proxy/devproxyrc.json b/dev-proxy/devproxyrc.json index bb9a6a11..fbeb8491 100644 --- a/dev-proxy/devproxyrc.json +++ b/dev-proxy/devproxyrc.json @@ -21,5 +21,6 @@ }, "rate": 50, "logLevel": "information", - "newVersionNotification": "stable" + "newVersionNotification": "stable", + "showSkipMessages": true } \ No newline at end of file