Update 1/model.py

This commit is contained in:
cheetahadmin 2025-05-28 07:48:20 +00:00
parent ece50c848b
commit 30e3d35853

@ -2,7 +2,6 @@ import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import numpy as np import numpy as np
import json import json
import os
class TritonPythonModel: class TritonPythonModel:
def initialize(self, args): def initialize(self, args):
@ -13,35 +12,31 @@ class TritonPythonModel:
""" """
self.logger = pb_utils.Logger self.logger = pb_utils.Logger
current_file_path = os.path.abspath(__file__)
self.logger.log_info(f"current_file_path: {current_file_path}")
self.model_name = args["model_name"] self.model_name = args["model_name"]
model_repository = args["model_repository"] self.model_path = self._get_config_parameter("model_path")
#model_path = model_repository self.enable_inference_trace = self._get_config_parameter("enable_inference_trace")
model_path = "/cheetah/input/model/groupuser/base-gemma-3-1b-it"
self.logger.log_info(f"model_name: {self.model_name}") self.logger.log_info(f"'self.model_name: {self.model_name}'")
self.logger.log_info(f"model_repository: {model_repository}") self.logger.log_info(f"'model_path: {self.model_path}'")
self.logger.log_info(f"model_path: {model_path}") self.logger.log_info(f"'enable_inference_trace: {self.enable_inference_trace}'")
#model_repository = args["model_repository"]
#model_path = f"{model_repository}/{self.model_name}"
self.model_config = json.loads(args["model_config"]) self.model_config = json.loads(args["model_config"])
# Hugging Face Transformers 라이브러리에서 사전 학습된 토크나이저를 로드합니다. # Hugging Face Transformers 라이브러리에서 사전 학습된 토크나이저를 로드합니다.
self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.supports_chat_template = self._check_chat_template_support() self.supports_chat_template = self._check_chat_template_support()
# Hugging Face Transformers 라이브러리에서 사전 학습된 언어 모델을 로드합니다. # Hugging Face Transformers 라이브러리에서 사전 학습된 언어 모델을 로드합니다.
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path, pretrained_model_name_or_path=self.model_path,
local_files_only=True, local_files_only=True,
device_map="auto" device_map="auto"
) )
self.enable_inference_trace = self._get_inference_trace_setting()
self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료") self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료")
@ -149,6 +144,15 @@ class TritonPythonModel:
return generation_config return generation_config
def _get_config_parameter(self, parameter_name):
self.parameters = self.model_config.get('parameters', {})
parameter_dict = self.parameters.get(parameter_name)
if isinstance(parameter_dict, dict) and 'string_value' in parameter_dict:
return parameter_dict['string_value']
return None
def _get_inference_trace_setting(self): def _get_inference_trace_setting(self):
""" """
모델 설정(config.pbxt)에서 'enable_inference_trace' 값을 추출하여 반환합니다. 모델 설정(config.pbxt)에서 'enable_inference_trace' 값을 추출하여 반환합니다.
@ -158,12 +162,20 @@ class TritonPythonModel:
Returns: Returns:
bool: 추론 추적 활성화 여부 (True 또는 False). bool: 추론 추적 활성화 여부 (True 또는 False).
""" """
parameters = self.model_config.get('parameters', {}) trace_config_dict = self.parameters.get('enable_inference_trace')
trace_config = parameters.get('enable_inference_trace') if isinstance(trace_config_dict, dict) and 'string_value' in trace_config_dict:
if isinstance(trace_config, dict) and 'string_value' in trace_config: return trace_config_dict['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환
return trace_config['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환
return False return False
def _get_model_path(self):
"""
모델 설정(config.pbxt)에서 'model_path' 값을 추출하여 반환합니다.
"""
model_path_dict = self.parameters.get('model_path')
if isinstance(model_path_dict, dict) and 'string_value' in model_path_dict:
return model_path_dict['string_value']
return ""
def _check_chat_template_support(self): def _check_chat_template_support(self):
""" """
@ -232,4 +244,3 @@ class TritonPythonModel:
종료 전에 필요한 모든 정리 작업을 수행할 있습니다. 종료 전에 필요한 모든 정리 작업을 수행할 있습니다.
""" """
pass pass