diff --git a/1/model.py b/1/model.py index c3634d8..dc255ac 100644 --- a/1/model.py +++ b/1/model.py @@ -69,7 +69,7 @@ class TritonPythonModel: try: self.model = AutoModelForCausalLM.from_pretrained( load_path, - torch_dtype=torch.float16, + torch_dtype="auto", quantization_config=bnb_config, device_map="auto", local_files_only=True,