diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py
index d22bbc2..22b75f1 100644
--- a/modeling_mixformer_sequential.py
+++ b/modeling_mixformer_sequential.py
@@ -35,7 +35,7 @@ from __future__ import annotations
 
 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field
 
 import torch
@@ -541,8 +541,8 @@ class MHA(nn.Module):
 
             kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
 
             if attention_mask is not None:
-                attention_mask, cu_seqlens, max_seqlen = attention_mask
-                attention_mask = attention_mask.to(qkv.device)
+                attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+                attention_mask = attention_mask.bool().to(qkv.device)
 
             attention_kwargs = {"attention_mask": attention_mask}
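
Note: the old code unconditionally unpacked the mask as a (mask, cu_seqlens, max_seqlen) tuple, which raises when callers pass a plain tensor. The patched lines accept either form and cast to bool before dispatch. A minimal sketch of that normalization, with a hypothetical helper name (normalize_attention_mask is not part of the patch):

import torch

def normalize_attention_mask(attention_mask, device):
    # Unpack the flash-attn style (mask, cu_seqlens, max_seqlen) tuple if
    # present; a plain tensor passes through unchanged.
    attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
    # Cast to bool on the target device, matching the patched lines above.
    return attention_mask.bool().to(device)

mask = torch.ones(1, 8, dtype=torch.long)
assert normalize_attention_mask(mask, "cpu").dtype == torch.bool
assert normalize_attention_mask((mask, None, None), "cpu").dtype == torch.bool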