Compare commits
3 Commits
main
...
refs/deplo
| Author | SHA1 | Date | |
|---|---|---|---|
| 766fca2517 | |||
| 5b9d078db3 | |||
| 98684d5814 |
39
1/model.py
39
1/model.py
@ -2,7 +2,6 @@ import triton_python_backend_utils as pb_utils
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
def initialize(self, args):
|
def initialize(self, args):
|
||||||
@ -13,27 +12,9 @@ class TritonPythonModel:
|
|||||||
"""
|
"""
|
||||||
self.logger = pb_utils.Logger
|
self.logger = pb_utils.Logger
|
||||||
|
|
||||||
current_file_path = os.path.abspath(__file__)
|
|
||||||
self.logger.log_info(f"current_file_path: {current_file_path}")
|
|
||||||
|
|
||||||
|
|
||||||
self.model_name = args["model_name"]
|
self.model_name = args["model_name"]
|
||||||
model_repository = args["model_repository"]
|
model_repository = args["model_repository"]
|
||||||
model_path = f"{model_repository}/{self.model_name}"
|
model_path = f"{model_repository}/{self.model_name}"
|
||||||
#model_path = "/cheetah/input/model/gemma-3-1b-it/gemma-3-1b-it"
|
|
||||||
|
|
||||||
input_model_path = model_path
|
|
||||||
|
|
||||||
if os.path.exists(input_model_path):
|
|
||||||
file_list = os.listdir(input_model_path)
|
|
||||||
self.logger.log_info(f"'{input_model_path}' 디렉토리의 파일 목록:")
|
|
||||||
for file_name in file_list:
|
|
||||||
self.logger.log_info(file_name)
|
|
||||||
else:
|
|
||||||
self.logger.log_info(f"'{input_model_path}' 디렉토리가 존재하지 않습니다.")
|
|
||||||
|
|
||||||
self.logger.log_info(f"model_repository: {model_repository}")
|
|
||||||
self.logger.log_info(f"model_path: {model_path}")
|
|
||||||
|
|
||||||
self.model_config = json.loads(args["model_config"])
|
self.model_config = json.loads(args["model_config"])
|
||||||
|
|
||||||
@ -62,9 +43,9 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
# 각 추론 요청을 순회하며 처리합니다.
|
# 각 추론 요청을 순회하며 처리합니다.
|
||||||
for request in requests:
|
for request in requests:
|
||||||
# Triton 입력 파싱
|
# Triton 입력 파싱
|
||||||
input_text = self._get_input_value(request, "text_input")
|
input_text = self._get_input_value(request, "text_input")
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
conversation = ""
|
conversation = ""
|
||||||
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
|
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
|
||||||
@ -79,7 +60,7 @@ class TritonPythonModel:
|
|||||||
text = input_text
|
text = input_text
|
||||||
is_chat = False
|
is_chat = False
|
||||||
self.logger.log_info(f"입력 text 출력:\n{text}")
|
self.logger.log_info(f"입력 text 출력:\n{text}")
|
||||||
|
|
||||||
# 입력 텍스트를 토큰화합니다.
|
# 입력 텍스트를 토큰화합니다.
|
||||||
if self.supports_chat_template and is_chat:
|
if self.supports_chat_template and is_chat:
|
||||||
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
|
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
|
||||||
@ -123,7 +104,7 @@ class TritonPythonModel:
|
|||||||
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def _process_generation_config(self, request):
|
def _process_generation_config(self, request):
|
||||||
"""
|
"""
|
||||||
추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다.
|
추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다.
|
||||||
@ -172,12 +153,12 @@ class TritonPythonModel:
|
|||||||
if isinstance(trace_config, dict) and 'string_value' in trace_config:
|
if isinstance(trace_config, dict) and 'string_value' in trace_config:
|
||||||
return trace_config['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환
|
return trace_config['string_value'].lower() == 'true' # 문자열 값을 bool로 변환하여 반환
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _check_chat_template_support(self):
|
def _check_chat_template_support(self):
|
||||||
"""
|
"""
|
||||||
주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
|
주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Chat 템플릿 지원 여부 (True 또는 False).
|
bool: Chat 템플릿 지원 여부 (True 또는 False).
|
||||||
"""
|
"""
|
||||||
@ -193,7 +174,7 @@ class TritonPythonModel:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
|
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _get_input_value(self, request, input_name: str, default=None):
|
def _get_input_value(self, request, input_name: str, default=None):
|
||||||
"""
|
"""
|
||||||
@ -211,9 +192,9 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
if tensor_value is None:
|
if tensor_value is None:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
return self._np_decoder(tensor_value.as_numpy()[0])
|
return self._np_decoder(tensor_value.as_numpy()[0])
|
||||||
|
|
||||||
def _np_decoder(self, obj):
|
def _np_decoder(self, obj):
|
||||||
"""
|
"""
|
||||||
NumPy 객체의 데이터 타입을 확인하고 Python 기본 타입으로 변환합니다.
|
NumPy 객체의 데이터 타입을 확인하고 Python 기본 타입으로 변환합니다.
|
||||||
@ -240,4 +221,4 @@ class TritonPythonModel:
|
|||||||
`finalize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은
|
`finalize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은
|
||||||
종료 전에 필요한 모든 정리 작업을 수행할 수 있습니다.
|
종료 전에 필요한 모든 정리 작업을 수행할 수 있습니다.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@ -1,6 +1,3 @@
|
|||||||
test.txt
|
test.txt
|
||||||
test.txt
|
test.txt
|
||||||
test.txt
|
test.txt
|
||||||
test.txt
|
|
||||||
test.txt
|
|
||||||
|
|
||||||
@ -1,7 +1,6 @@
|
|||||||
# Triton backend to use
|
|
||||||
name: "gemma-3-1b-it"
|
name: "gemma-3-1b-it"
|
||||||
backend: "python"
|
|
||||||
max_batch_size: 0
|
max_batch_size: 0
|
||||||
|
backend: "python"
|
||||||
|
|
||||||
# Triton should expect as input a single string
|
# Triton should expect as input a single string
|
||||||
# input of variable length named 'text_input'
|
# input of variable length named 'text_input'
|
||||||
@ -64,7 +63,7 @@ input [
|
|||||||
|
|
||||||
# Triton should expect to respond with a single string
|
# Triton should expect to respond with a single string
|
||||||
# output of variable length named 'text_output'
|
# output of variable length named 'text_output'
|
||||||
output [
|
output [
|
||||||
{
|
{
|
||||||
name: "text_output"
|
name: "text_output"
|
||||||
data_type: TYPE_STRING
|
data_type: TYPE_STRING
|
||||||
@ -72,7 +71,6 @@ output [
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
parameters: [
|
parameters: [
|
||||||
{
|
{
|
||||||
key: "enable_inference_trace",
|
key: "enable_inference_trace",
|
||||||
@ -85,5 +83,4 @@ instance_group [
|
|||||||
kind: KIND_AUTO,
|
kind: KIND_AUTO,
|
||||||
count: 1
|
count: 1
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user