fix(phi-1_5): Checks length of attention_mask if it is passed as a direct tensor.

Author: Gustavo de Rosa
Date: 2023-09-26 21:21:45 +00:00
Committed by: huggingface-web
Parent: 3128bb636a
Commit: f9f2ac7c45

@@ -35,7 +35,7 @@ from __future__ import annotations

 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field

 import torch
@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)

         if attention_mask is not None:
-            attention_mask, cu_seqlens, max_seqlen = attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)

         attention_kwargs = {"attention_mask": attention_mask}
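
For reference, a minimal sketch of what the changed branch now does, assuming the mask may arrive either as a plain tensor or as an (attention_mask, cu_seqlens, max_seqlen) tuple from the padded-input path; the helper name normalize_attention_mask is illustrative and not part of the model code:

import torch

def normalize_attention_mask(attention_mask, device):
    # Hypothetical helper mirroring the new logic: accept either a direct
    # tensor or a (mask, cu_seqlens, max_seqlen) tuple, keep only the mask,
    # cast it to bool, and move it to the device of the qkv projections.
    if attention_mask is None:
        return None
    if isinstance(attention_mask, tuple):
        attention_mask = attention_mask[0]
    return attention_mask.bool().to(device)

# Both call forms now resolve to the same boolean mask.
mask = torch.ones(2, 16, dtype=torch.long)
direct = normalize_attention_mask(mask, torch.device("cpu"))
packed = normalize_attention_mask((mask, None, None), torch.device("cpu"))
assert torch.equal(direct, packed)

The previous code unconditionally unpacked the tuple, so a mask passed as a direct tensor would fail or silently mis-split along its first dimension; the isinstance check removes that assumption, and the added .bool() cast ensures the mask is boolean before it reaches the attention kwargs.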