fix(phi-1_5): Checks length of attention_maskif it is passed as direct tensor.

This commit is contained in:
Gustavo de Rosa 2023-09-26 21:21:45 +00:00 committed by huggingface-web
parent 3128bb636a
commit f9f2ac7c45

@ -35,7 +35,7 @@ from __future__ import annotations
import math
import copy
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import dataclass, field
import torch
@ -541,8 +541,8 @@ class MHA(nn.Module):
kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
if attention_mask is not None:
attention_mask, cu_seqlens, max_seqlen = attention_mask
attention_mask = attention_mask.to(qkv.device)
attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
attention_mask = attention_mask.bool().to(qkv.device)
attention_kwargs = {"attention_mask": attention_mask}