Compare commits

..

7 Commits

Author SHA1 Message Date
a91c637c60 Update model.py 2025-05-13 08:40:07 +00:00
a3dab09801 Create New File 2025-04-30 07:31:55 +00:00
203891813b modify model path 2025-04-29 13:59:26 +09:00
ddb8145bd1 print file list at input model directory 2025-04-29 13:46:56 +09:00
941523384f modify model path 2025-04-29 13:29:04 +09:00
6f5c252f55 add log 2025-04-29 13:00:21 +09:00
b19c4d2940 restructurized 2025-04-29 10:47:51 +09:00
12 changed files with 38 additions and 13 deletions

@ -2,6 +2,7 @@ import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import numpy as np
import json
import os
class TritonPythonModel:
def initialize(self, args):
@ -12,9 +13,27 @@ class TritonPythonModel:
"""
self.logger = pb_utils.Logger
current_file_path = os.path.abspath(__file__)
self.logger.log_info(f"current_file_path: {current_file_path}")
self.model_name = args["model_name"]
model_repository = args["model_repository"]
model_path = f"{model_repository}/{self.model_name}"
#model_path = "/cheetah/input/model/gemma-3-1b-it/gemma-3-1b-it"
input_model_path = model_path
if os.path.exists(input_model_path):
file_list = os.listdir(input_model_path)
self.logger.log_info(f"'{input_model_path}' 디렉토리의 파일 목록:")
for file_name in file_list:
self.logger.log_info(file_name)
else:
self.logger.log_info(f"'{input_model_path}' 디렉토리가 존재하지 않습니다.")
self.logger.log_info(f"model_repository: {model_repository}")
self.logger.log_info(f"model_path: {model_path}")
self.model_config = json.loads(args["model_config"])
@ -43,9 +62,9 @@ class TritonPythonModel:
# 각 추론 요청을 순회하며 처리합니다.
for request in requests:
# Triton 입력 파싱
# Triton 입력 파싱
input_text = self._get_input_value(request, "text_input")
text = ""
conversation = ""
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
@ -60,7 +79,7 @@ class TritonPythonModel:
text = input_text
is_chat = False
self.logger.log_info(f"입력 text 출력:\n{text}")
# 입력 텍스트를 토큰화합니다.
if self.supports_chat_template and is_chat:
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
@ -104,7 +123,7 @@ class TritonPythonModel:
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
return responses
def _process_generation_config(self, request):
"""
추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다.
@ -153,12 +172,12 @@ class TritonPythonModel:
if isinstance(trace_config, dict) and 'string_value' in trace_config:
return trace_config['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환
return False
def _check_chat_template_support(self):
"""
주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
Returns:
bool: Chat 템플릿 지원 여부 (True 또는 False).
"""
@ -174,7 +193,7 @@ class TritonPythonModel:
except Exception as e:
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
return False
def _get_input_value(self, request, input_name: str, default=None):
"""
@ -192,9 +211,9 @@ class TritonPythonModel:
if tensor_value is None:
return default
return self._np_decoder(tensor_value.as_numpy()[0])
def _np_decoder(self, obj):
"""
NumPy 객체의 데이터 타입을 확인하고 Python 기본 타입으로 변환합니다.
@ -221,4 +240,4 @@ class TritonPythonModel:
`finalize` 함수를 구현하는 것은 선택 사항입니다. 함수를 통해 모델은
종료 전에 필요한 모든 정리 작업을 수행할 있습니다.
"""
pass
pass

@ -1,6 +1,7 @@
# Triton backend to use
name: "gemma-3-1b-it"
max_batch_size: 0
backend: "python"
max_batch_size: 0
# Triton should expect as input a single string
# input of variable length named 'text_input'
@ -63,7 +64,7 @@ input [
# Triton should expect to respond with a single string
# output of variable length named 'text_output'
output [
output [
{
name: "text_output"
data_type: TYPE_STRING
@ -71,6 +72,7 @@ output [
}
]
parameters: [
{
key: "enable_inference_trace",
@ -83,4 +85,5 @@ instance_group [
kind: KIND_AUTO,
count: 1
}
]
]

@ -1,3 +1,6 @@
test.txt
test.txt
test.txt
test.txt
test.txt