|
| 1 | +{-# LANGUAGE OverloadedStrings #-} |
| 2 | +{-# LANGUAGE RankNTypes #-} |
| 3 | +{-# LANGUAGE RecordWildCards #-} |
| 4 | +{-# LANGUAGE LambdaCase #-} |
| 5 | +{-# LANGUAGE DeriveAnyClass #-} |
| 6 | +{-# LANGUAGE DerivingStrategies #-} |
| 7 | +{-# LANGUAGE GeneralizedNewtypeDeriving #-} |
| 8 | + |
| 9 | +module Chainweb.Utils.Throttling |
| 10 | + ( ThrottleEconomy(..) |
| 11 | + , ThrottledException(..) |
| 12 | + , throttleMiddleware |
| 13 | + ) where |
| 14 | + |
| 15 | +import Data.LogMessage |
| 16 | +import Data.Text (Text) |
| 17 | +import qualified Network.Wai as Wai |
| 18 | +import qualified Network.Wai.Internal as Wai.Internal |
| 19 | +import Chainweb.Utils.TokenLimiting |
| 20 | +import Control.Exception.Safe |
| 21 | +import Network.HTTP.Types.Status |
| 22 | +import qualified Data.ByteString as BS |
| 23 | +import qualified Data.Text as T |
| 24 | +import Data.Hashable |
| 25 | +import Network.Socket (SockAddr(..)) |
| 26 | +import qualified Data.ByteString.Builder as BSB |
| 27 | +import System.IO.Unsafe (unsafeInterleaveIO) |
| 28 | +import qualified Data.ByteString.Lazy as LBS |
| 29 | + |
| 30 | +data ThrottleEconomy = ThrottleEconomy |
| 31 | + { requestCost :: Int |
| 32 | + , requestBody100ByteCost :: Int |
| 33 | + , responseBody100ByteCost :: Int |
| 34 | + , maxBudget :: Int |
| 35 | + , freeRate :: Int |
| 36 | + } |
| 37 | + |
| 38 | +data ThrottledException = ThrottledException Text |
| 39 | + deriving (Show, Exception) |
| 40 | + |
| 41 | +hashWithSalt' :: Hashable a => a -> Int -> Int |
| 42 | +hashWithSalt' = flip hashWithSalt |
| 43 | + |
| 44 | +newtype HashableSockAddr = HashableSockAddr SockAddr |
| 45 | + deriving newtype Eq |
| 46 | +instance Hashable HashableSockAddr where |
| 47 | + hashWithSalt salt (HashableSockAddr sockAddr) = case sockAddr of |
| 48 | + SockAddrInet port hostAddr -> |
| 49 | + -- constructor tag |
| 50 | + hashWithSalt' (1 :: Word) |
| 51 | + . hashWithSalt' (fromIntegral port :: Word) |
| 52 | + . hashWithSalt' hostAddr |
| 53 | + $ salt |
| 54 | + SockAddrInet6 port flowInfo hostAddr scopeId -> |
| 55 | + hashWithSalt' (2 :: Word) |
| 56 | + . hashWithSalt' (fromIntegral port :: Word) |
| 57 | + . hashWithSalt' flowInfo |
| 58 | + . hashWithSalt' hostAddr |
| 59 | + . hashWithSalt' scopeId |
| 60 | + $ salt |
| 61 | + SockAddrUnix str -> |
| 62 | + hashWithSalt' (3 :: Word) |
| 63 | + . hashWithSalt' str |
| 64 | + $ salt |
| 65 | + |
| 66 | +debitOrDie :: Hashable k => TokenLimitMap k -> (Text, k) -> Int -> IO () |
| 67 | +debitOrDie tokenLimitMap (name, k) cost = do |
| 68 | + tryDebit cost k tokenLimitMap >>= \case |
| 69 | + True -> return () |
| 70 | + False -> throwIO (ThrottledException name) |
| 71 | + |
| 72 | +throttleMiddleware :: LogFunction -> Text -> ThrottleEconomy -> (Wai.Middleware -> IO r) -> IO r |
| 73 | +throttleMiddleware logfun name ThrottleEconomy{..} k = |
| 74 | + withTokenLimitMap logfun ("request-throttler-" <> name) limitCachePolicy limitConfig $ \tokenLimitMap -> do |
| 75 | + k $ middleware tokenLimitMap |
| 76 | + where |
| 77 | + middleware tokenLimitMap app request respond = do |
| 78 | + debitOrDie' requestCost |
| 79 | + meteredRequest <- meterRequest debitOrDie' request |
| 80 | + app meteredRequest (meterResponse debitOrDie' respond) |
| 81 | + where |
| 82 | + host = HashableSockAddr $ Wai.remoteHost request |
| 83 | + hostText = T.pack $ show (Wai.remoteHost request) |
| 84 | + debitOrDie' = debitOrDie tokenLimitMap (hostText, host) |
| 85 | + |
| 86 | + limitCachePolicy = TokenLimitCachePolicy 30 |
| 87 | + limitConfig = defaultLimitConfig |
| 88 | + { maxBucketTokens = maxBudget |
| 89 | + , initialBucketTokens = maxBudget |
| 90 | + , bucketRefillTokensPerSecond = freeRate |
| 91 | + } |
| 92 | + |
| 93 | + meterRequest debit request |
| 94 | + | requestBody100ByteCost == 0 = return request |
| 95 | + | otherwise = case Wai.requestBodyLength request of |
| 96 | + Wai.KnownLength requestBodyLen -> do |
| 97 | + () <- debit $ (requestBody100ByteCost * fromIntegral requestBodyLen) `div` 100 |
| 98 | + return request |
| 99 | + Wai.ChunkedBody -> |
| 100 | + return (Wai.setRequestBodyChunks (getMeteredRequestBodyChunk debit request) request) |
| 101 | + |
| 102 | + getMeteredRequestBodyChunk debit request = do |
| 103 | + chunk <- Wai.getRequestBodyChunk request |
| 104 | + -- charge *after* receiving a request body chunk |
| 105 | + () <- debit $ (requestBody100ByteCost * BS.length chunk) `div` 100 |
| 106 | + return chunk |
| 107 | + |
| 108 | + -- the only way to match on responses without using internal API is via |
| 109 | + -- responseToStream, which converts any response into a streaming response. |
| 110 | + -- unfortunately: |
| 111 | + -- * all of the responses produced by servant are builder responses, |
| 112 | + -- not streaming responses |
| 113 | + -- * streaming responses are not supported by http2; we try to use http2 |
| 114 | + -- (see https://hackage.haskell.org/package/http2-5.3.5/docs/src/Network.HTTP2.Server.Run.html#runIO) |
| 115 | + -- * a streaming response body may be less efficient than a builder |
| 116 | + -- response body, in particular because it needs to use a chunked |
| 117 | + -- encoding |
| 118 | + -- |
| 119 | + meterResponse |
| 120 | + :: (Int -> IO ()) |
| 121 | + -> (Wai.Response -> IO a) -> Wai.Response -> IO a |
| 122 | + meterResponse _ respond response |
| 123 | + | responseBody100ByteCost == 0 = respond response |
| 124 | + meterResponse debit respond (Wai.Internal.ResponseStream status headers responseBody) = do |
| 125 | + respond |
| 126 | + $ Wai.responseStream status headers |
| 127 | + $ meterStreamingResponseBody debit responseBody |
| 128 | + meterResponse debit respond (Wai.Internal.ResponseBuilder status headers responseBody) = do |
| 129 | + respond |
| 130 | + <$> Wai.responseLBS status headers . LBS.fromChunks |
| 131 | + =<< meterBuilderResponseBody debit (LBS.toChunks $ BSB.toLazyByteString responseBody) |
| 132 | + meterResponse _ _ _ = error "unrecognized response type" |
| 133 | + |
| 134 | + meterStreamingResponseBody debit responseBody send flush = responseBody |
| 135 | + (\chunkBSBuilder -> do |
| 136 | + let chunkBS = BS.toStrict (BSB.toLazyByteString chunkBSBuilder) |
| 137 | + () <- debit $ (responseBody100ByteCost * BS.length chunkBS) `div` 100 |
| 138 | + -- charger *before* sending a response body chunk |
| 139 | + send (BSB.byteString chunkBS) |
| 140 | + ) |
| 141 | + flush |
| 142 | + meterBuilderResponseBody debit (chunk:chunks) = unsafeInterleaveIO $ do |
| 143 | + () <- debit $ (responseBody100ByteCost * BS.length chunk) `div` 100 |
| 144 | + (chunk:) <$> meterBuilderResponseBody debit chunks |
| 145 | + meterBuilderResponseBody _ [] = return [] |
0 commit comments