fix(phi-1_5): Checks length of attention_mask if it is passed as a direct tensor.
commit f9f2ac7c45
parent 3128bb636a
@@ -35,7 +35,7 @@ from __future__ import annotations
 
 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field
 
 import torch
@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
 
         if attention_mask is not None:
-            attention_mask, cu_seqlens, max_seqlen = attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)
 
         attention_kwargs = {"attention_mask": attention_mask}
 
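The net effect of the hunk above is that the mask handling now tolerates both calling conventions: a bare attention_mask tensor, or the (attention_mask, cu_seqlens, max_seqlen) tuple produced by the padding/unpadding path, with the mask cast to bool before being moved to the qkv device. Below is a minimal standalone sketch of that normalization; normalize_attention_mask is a hypothetical helper written only for illustration and is not part of the modified module.

import torch

def normalize_attention_mask(attention_mask, device):
    # Accept either a bare mask tensor or a (mask, cu_seqlens, max_seqlen)
    # tuple, and return a boolean mask on the target device, mirroring the
    # two changed lines in the hunk above.
    if attention_mask is None:
        return None
    if isinstance(attention_mask, tuple):
        attention_mask = attention_mask[0]  # keep only the mask tensor
    return attention_mask.bool().to(device)

# Both call styles resolve to the same boolean mask:
mask = torch.ones(2, 16, dtype=torch.long)
assert torch.equal(
    normalize_attention_mask(mask, "cpu"),
    normalize_attention_mask((mask, None, 16), "cpu"),
)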