import json

import numpy as np
import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class TritonPythonModel:
    def initialize(self, args):
        """Model initialization. Triton runs this once at server startup."""
        self.logger = pb_utils.Logger

        # The model artifacts live inside the Triton model repository,
        # under <model_repository>/<model_name>.
        model_repository = args["model_repository"]
        model_name = args["model_name"]
        model_path = f"{model_repository}/{model_name}"

        # config.pbtxt for this model, passed in by Triton as a JSON string.
        self.model_config = json.loads(args["model_config"])

        # Load tokenizer and weights from local files only (no Hub download),
        # letting accelerate place the model across available devices.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            local_files_only=True,
            device_map="auto",
        )

        self.logger.log_info(f"\n### {model_name} model initialization complete")
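
    # A minimal sketch of the matching config.pbtxt, included as an assumption
    # (the real file is not shown here). execute() reads the tensor names
    # dynamically, so "INPUT_TEXT"/"OUTPUT_TEXT" are placeholders; only the
    # python backend, the TYPE_STRING dtype, and the 1-D dims are load-bearing:
    #
    #   backend: "python"
    #   input [ { name: "INPUT_TEXT", data_type: TYPE_STRING, dims: [ 1 ] } ]
    #   output [ { name: "OUTPUT_TEXT", data_type: TYPE_STRING, dims: [ 1 ] } ]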

    def execute(self, requests):
        """Inference entry point. Triton calls this with a batch of requests."""
        responses = []

        # Tensor names as declared in config.pbtxt.
        input_name = self.model_config["input"][0]["name"]
        output_name = self.model_config["output"][0]["name"]

        for request in requests:
            # Parse the Triton input: a BYTES tensor carrying one UTF-8 string.
            input_tensor = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy()[0]
            input_text = input_tensor.decode("utf-8")
            self.logger.log_info(f"### INPUT_TEXT: {input_text}")

            # Tokenize, wrapping the text in the model's Korean prompt template
            # ("### 질문:" marks the question, "### 답변:" the answer) and moving
            # the tensors to the model's device.
            inputs = self.tokenizer(
                f"### 질문: {input_text}\n\n### 답변:",
                return_tensors="pt",
            ).to(device=self.model.device)

            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            generation_config = GenerationConfig(
                max_new_tokens=256,
            )
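
            # Note: with only max_new_tokens set, this config leaves do_sample
            # at its False default, so the generate() call below decodes greedily.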

            # Model inference
            gened = self.model.generate(
                generation_config=generation_config,
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=2,  # hard-coded pad token id for this model
                repetition_penalty=1.1,
            )

            # Decode the generated tokens (prompt included) back into text.
            answer = self.tokenizer.decode(gened[0])
            self.logger.log_info(f"### MODEL ANSWER:\n{answer}")

            # Post-process: keep only the answer part of the decoded text.
            output = self.post_process(answer)
            self.logger.log_info(f"### OUTPUT_TEXT: {output}")

            # Return the answer to Triton as a 1-D BYTES tensor.
            output_tensor = pb_utils.Tensor(
                output_name,
                np.array([output.encode("utf-8")], dtype=np.bytes_),
            )
            responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))

        return responses

    def post_process(self, text):
        """Extract the text between the "### 답변:" and "### 질문:" markers."""
        try:
            return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())
        except IndexError:
            # No answer marker found; return the raw decoded text as-is.
            return text
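
    # For example, post_process("### 질문: hi\n\n### 답변: hello</s>") returns
    # "hello</s>": everything after "### 답변:" is kept, so any special tokens
    # left in by tokenizer.decode() pass through unchanged.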

    def finalize(self):
        """Optional cleanup at server shutdown."""
        pass
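

# A minimal client-side sketch for calling this model, included as an assumption
# (not part of the original file). The URL, the model name "my_llm", and the
# tensor names "INPUT_TEXT"/"OUTPUT_TEXT" are placeholders that must match the
# deployment and its config.pbtxt:
#
#   import numpy as np
#   import tritonclient.http as httpclient
#
#   client = httpclient.InferenceServerClient(url="localhost:8000")
#   data = np.array(["Hello?".encode("utf-8")], dtype=np.object_)
#   inp = httpclient.InferInput("INPUT_TEXT", [1], "BYTES")
#   inp.set_data_from_numpy(data)
#   result = client.infer("my_llm", inputs=[inp])
#   print(result.as_numpy("OUTPUT_TEXT")[0].decode("utf-8"))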