fix(phi-1_5): Checks length of attention_mask if it is passed as a direct tensor.
parent 3128bb636a
commit f9f2ac7c45
@@ -35,7 +35,7 @@ from __future__ import annotations
 
 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field
 
 import torch
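
Adding Union here presumably lets the mask argument be annotated as either a plain tensor or the packed (mask, cu_seqlens, max_seqlen) tuple. The alias below is only a guess at what that annotation might look like; the actual signature is not part of this hunk, and the alias name is illustrative.

from typing import Optional, Tuple, Union

import torch

# Assumed annotation for the mask argument; the name and tuple layout are
# illustrative only and do not appear in this diff.
AttentionMaskArg = Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, int]]]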
@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
 
         if attention_mask is not None:
-            attention_mask, cu_seqlens, max_seqlen = attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)
 
             attention_kwargs = {"attention_mask": attention_mask}
 
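
The old code unconditionally unpacked attention_mask as a (mask, cu_seqlens, max_seqlen) tuple, which fails when callers pass the mask as a direct tensor. The new code keeps only the first element when a tuple is given and casts the result to bool before moving it onto the qkv device. Below is a minimal, self-contained sketch of that normalization using a hypothetical _normalize_attention_mask helper; in the real MHA module the same logic is inlined rather than factored out.

from typing import Optional, Tuple, Union

import torch


def _normalize_attention_mask(
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, int]]],
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Accept either a plain mask tensor or a packed (mask, cu_seqlens, max_seqlen) tuple."""
    if attention_mask is None:
        return None
    # If the caller passed the packed tuple, keep only the mask itself.
    attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
    # Cast to bool and move to the target device (the qkv device in the module).
    return attention_mask.bool().to(device)


# Usage: both call styles now yield the same boolean mask.
mask = torch.ones(2, 16, dtype=torch.long)
direct = _normalize_attention_mask(mask, torch.device("cpu"))
packed = _normalize_attention_mask((mask, torch.tensor([0, 16]), 16), torch.device("cpu"))
assert torch.equal(direct, packed)

The assert at the end is the point of the commit: a direct tensor and the packed tuple now produce equivalent masks instead of the tensor case raising an unpacking error.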