Update 1/model.py
This commit is contained in:
parent
7437bf7190
commit
203f9072f1
95
1/model.py
95
1/model.py
@ -1,13 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
# AutoModelForCausalLM 대신 AutoModel을 가져옵니다.
|
|
||||||
from transformers import AutoModel, AutoTokenizer, GenerationConfig
|
from transformers import AutoModel, AutoTokenizer, GenerationConfig
|
||||||
from peft import PeftModel, PeftConfig
|
from peft import PeftModel, PeftConfig
|
||||||
|
|
||||||
# BitsAndBytesConfig가 import 되지 않아 임시로 주석 처리하거나, 필요하다면 설치 후 주석 해제해야 합니다.
|
|
||||||
# from bitsandbytes import BitsAndBytesConfig
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
def initialize(self, args):
|
def initialize(self, args):
|
||||||
@ -21,81 +19,30 @@ class TritonPythonModel:
|
|||||||
self.model_name = args["model_name"]
|
self.model_name = args["model_name"]
|
||||||
self.base_model_path = self._get_config_parameter("base_model_path")
|
self.base_model_path = self._get_config_parameter("base_model_path")
|
||||||
|
|
||||||
# CodeSage는 임베딩 모델이므로 LoRA 등의 어댑터 로드는 지원하지 않거나 일반적이지 않습니다.
|
|
||||||
# 기존 로직은 유지하되, 실제로 사용하지 않을 경우 config.pbtxt에서 해당 파라미터를 제거하는 것이 좋습니다.
|
|
||||||
self.is_adapter_model = self._get_config_parameter("is_adapter_model").strip().lower() == "true"
|
|
||||||
self.adapter_model_path = self._get_config_parameter("adapter_model_path")
|
|
||||||
self.quantization = self._get_config_parameter("quantization")
|
|
||||||
|
|
||||||
self.logger.log_info(f"base_model_path: {self.base_model_path}")
|
self.logger.log_info(f"base_model_path: {self.base_model_path}")
|
||||||
self.logger.log_info(f"is_adapter_model: {self.is_adapter_model}")
|
|
||||||
self.logger.log_info(f"adapter_model_path: {self.adapter_model_path}")
|
|
||||||
self.logger.log_info(f"quantization: {self.quantization}")
|
|
||||||
|
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""
|
torch_dtype = torch.float16
|
||||||
Load model: CodeSage에 맞게 AutoModel과 trust_remote_code=True를 사용하도록 수정
|
|
||||||
"""
|
|
||||||
self.bnb_config = None
|
|
||||||
torch_dtype = torch.float16 # 기본 dtype
|
|
||||||
|
|
||||||
# Note: CodeSage는 AutoModel을 사용하며, 일반적인 CausalLM 양자화 옵션이 적용되는지
|
# AutoModel로 모델 로드
|
||||||
# 확인이 필요하지만, 일단 기존 로직은 유지합니다.
|
self.model = AutoModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=self.base_model_path,
|
||||||
# BitsAndBytesConfig가 정의되지 않았으므로 주석 처리합니다.
|
local_files_only=True,
|
||||||
# if self.quantization == "int4":
|
device_map="auto",
|
||||||
# # ... (int4 config)
|
trust_remote_code=False
|
||||||
# pass
|
)
|
||||||
# elif self.quantization == "int8":
|
|
||||||
# # ... (int8 config)
|
|
||||||
# pass
|
|
||||||
# else:
|
|
||||||
# self.bnb_config = None
|
|
||||||
|
|
||||||
if self.is_adapter_model:
|
|
||||||
# CodeSage는 임베딩 모델이므로 어댑터 사용은 일반적이지 않으나,
|
|
||||||
# 기존 템플릿 로직은 유지합니다.
|
|
||||||
peft_config = PeftConfig.from_pretrained(self.adapter_model_path)
|
|
||||||
self.base_model_path = peft_config.base_model_name_or_path
|
|
||||||
|
|
||||||
# base 모델 로드: AutoModel로 변경
|
|
||||||
base_model = AutoModel.from_pretrained(
|
|
||||||
peft_config.base_model_name_or_path,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
quantization_config=self.bnb_config if self.bnb_config else None,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True # 필수 옵션 추가
|
|
||||||
)
|
|
||||||
|
|
||||||
# adapter 모델 로드
|
|
||||||
self.model = PeftModel.from_pretrained(base_model, self.adapter_model_path)
|
|
||||||
else:
|
|
||||||
# 일반 모델인 경우 로드: AutoModel로 변경
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=self.base_model_path,
|
|
||||||
# local_files_only=True, # CodeSage는 허브에서 로드될 수 있으므로 주석 처리
|
|
||||||
quantization_config=self.bnb_config if self.bnb_config else None,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True # 필수 옵션 추가
|
|
||||||
)
|
|
||||||
|
|
||||||
# 모델을 평가 모드로 설정
|
# 모델을 평가 모드로 설정
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Tokenizer 로드: trust_remote_code=True와 add_eos_token=True 추가
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.base_model_path,
|
self.base_model_path,
|
||||||
trust_remote_code=True,
|
trust_remote_code=False,
|
||||||
add_eos_token=True # 필수 옵션 추가
|
add_eos_token=True
|
||||||
)
|
)
|
||||||
# 임베딩 모델이므로 pad_token_id 설정은 불필요하거나, 주의가 필요함. 일단 주석 처리
|
|
||||||
# self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
|
||||||
|
|
||||||
# 임베딩 모델에서는 Chat Template 지원 체크는 불필요하므로 제거
|
|
||||||
# self.supports_chat_template = self._check_chat_template_support()
|
|
||||||
self.supports_chat_template = False # 항상 False로 설정
|
|
||||||
|
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료 (Code Embedding Mode)")
|
self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료 (Code Embedding Mode)")
|
||||||
|
|
||||||
@ -158,23 +105,7 @@ class TritonPythonModel:
|
|||||||
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
# 임베딩 모델에서는 불필요하므로 제거합니다.
|
|
||||||
# def _process_generation_config(self, request):
|
|
||||||
# """
|
|
||||||
# 추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다.
|
|
||||||
# """
|
|
||||||
# # ... (기존 로직)
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# 임베딩 모델에서는 불필요하므로 제거합니다.
|
|
||||||
# def _check_chat_template_support(self):
|
|
||||||
# """
|
|
||||||
# 주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
|
|
||||||
# """
|
|
||||||
# # ... (기존 로직)
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
def _get_config_parameter(self, parameter_name):
|
def _get_config_parameter(self, parameter_name):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user