diff --git a/1/model.py b/1/model.py index 3f23d1b..a097585 100644 --- a/1/model.py +++ b/1/model.py @@ -2,7 +2,6 @@ import triton_python_backend_utils as pb_utils from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig import numpy as np import json -import os class TritonPythonModel: def initialize(self, args): @@ -13,35 +12,31 @@ class TritonPythonModel: """ self.logger = pb_utils.Logger - current_file_path = os.path.abspath(__file__) - self.logger.log_info(f"current_file_path: {current_file_path}") - - self.model_name = args["model_name"] - model_repository = args["model_repository"] - #model_path = model_repository - model_path = "/cheetah/input/model/groupuser/base-gemma-3-1b-it" + self.model_path = self._get_config_parameter("model_path") + self.enable_inference_trace = self._get_config_parameter("enable_inference_trace") - self.logger.log_info(f"model_name: {self.model_name}") - self.logger.log_info(f"model_repository: {model_repository}") - self.logger.log_info(f"model_path: {model_path}") + self.logger.log_info(f"'self.model_name: {self.model_name}'") + self.logger.log_info(f"'model_path: {self.model_path}'") + self.logger.log_info(f"'enable_inference_trace: {self.enable_inference_trace}'") + + #model_repository = args["model_repository"] + #model_path = f"{model_repository}/{self.model_name}" self.model_config = json.loads(args["model_config"]) # Hugging Face Transformers 라이브러리에서 사전 학습된 토크나이저를 로드합니다. - self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.supports_chat_template = self._check_chat_template_support() # Hugging Face Transformers 라이브러리에서 사전 학습된 언어 모델을 로드합니다. self.model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=model_path, + pretrained_model_name_or_path=self.model_path, local_files_only=True, device_map="auto" ) - - self.enable_inference_trace = self._get_inference_trace_setting() - + self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료") @@ -149,6 +144,15 @@ class TritonPythonModel: return generation_config + def _get_config_parameter(self, parameter_name): + self.parameters = self.model_config.get('parameters', {}) + parameter_dict = self.parameters.get(parameter_name) + + if isinstance(parameter_dict, dict) and 'string_value' in parameter_dict: + return parameter_dict['string_value'] + + return None + def _get_inference_trace_setting(self): """ 모델 설정(config.pbxt)에서 'enable_inference_trace' 값을 추출하여 반환합니다. @@ -158,11 +162,19 @@ class TritonPythonModel: Returns: bool: 추론 추적 활성화 여부 (True 또는 False). """ - parameters = self.model_config.get('parameters', {}) - trace_config = parameters.get('enable_inference_trace') - if isinstance(trace_config, dict) and 'string_value' in trace_config: - return trace_config['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환 + trace_config_dict = self.parameters.get('enable_inference_trace') + if isinstance(trace_config_dict, dict) and 'string_value' in trace_config_dict: + return trace_config_dict['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환 return False + + def _get_model_path(self): + """ + 모델 설정(config.pbxt)에서 'model_path' 값을 추출하여 반환합니다. + """ + model_path_dict = self.parameters.get('model_path') + if isinstance(model_path_dict, dict) and 'string_value' in model_path_dict: + return model_path_dict['string_value'] + return "" def _check_chat_template_support(self): @@ -231,5 +243,4 @@ class TritonPythonModel: `finalize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은 종료 전에 필요한 모든 정리 작업을 수행할 수 있습니다. """ - pass - + pass \ No newline at end of file