diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index cb54b4c3ba66..c1f3bb0ca33c 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -106,6 +106,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_type=self.attn_type,
         )
 
     def _init_qkv(
@@ -134,12 +135,7 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=self.attn_type)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
 
         output, _ = self.out_proj(attn_output)
 
@@ -164,6 +160,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=prefix,
+            attn_type=AttentionType.ENCODER_DECODER,
         )
 
     def _init_qkv(
@@ -207,12 +204,13 @@ def forward(
         else:
             k = v = None
 
-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=AttentionType.ENCODER_DECODER)
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            kv_cache,
+            attn_metadata,
+        )
 
         output, _ = self.out_proj(attn_output)
 
@@ -734,4 +732,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_weights = [(name, loaded_weight)
                           for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
\ No newline at end of file
+        return loader.load_weights(loaded_weights, mapper=mapper)
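
The pattern this patch applies: the attention type is pinned once when the Attention layer is constructed, instead of being threaded through every forward() call. A minimal sketch of the resulting wiring, assuming vLLM's Attention(num_heads, head_size, scale, ...) positional signature; only the keyword arguments visible in the hunks above are confirmed by this diff, and build_cross_attn is a hypothetical helper, not part of the patch.

# Sketch of the post-patch construction pattern (assumptions noted above).
from vllm.attention import Attention, AttentionType

def build_cross_attn(num_heads: int, head_dim: int,
                     cache_config=None, quant_config=None,
                     prefix: str = "decoder.cross_attn") -> Attention:
    # attn_type is now fixed per layer at construction time, so the layer
    # is later invoked as attn(q, k, v, kv_cache, attn_metadata) with no
    # per-call attn_type keyword, matching the forward() hunks above.
    return Attention(
        num_heads,
        head_dim,
        head_dim**-0.5,  # standard 1/sqrt(d) scaling
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        attn_type=AttentionType.ENCODER_DECODER,
    )

Fixing the type at construction lets the backend be chosen once per layer rather than re-dispatched on every call, which is why the forward() signatures in both attention classes shrink to (q, k, v, kv_cache, attn_metadata).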