create by CHEETAH

This commit is contained in:
groupuser 2025-03-28 04:31:46 +00:00
parent b77a93033e
commit d39bb15e63

84
model.py Normal file

@ -0,0 +1,84 @@
import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import numpy as np
import json
class TritonPythonModel:
def initialize(self, args):
"""모델 초기화. Triton이 서버 시작 시 실행."""
self.logger = pb_utils.Logger
model_repository = args["model_repository"]
model_name = args["model_name"]
model_path = f"{model_repository}/{model_name}"
self.model_config = json.loads(args["model_config"])
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
local_files_only=True,
device_map="auto"
)
self.logger.log_info(f"\n### {model_name} 모델 초기화 완료")
def execute(self, requests):
"""Triton이 호출하는 Inference 실행 함수."""
responses = []
input_name = self.model_config.get("input")[0]['name']
output_name = self.model_config.get("output")[0]['name']
for request in requests:
# Triton 입력 파싱
input_tensor = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy()[0]
input_text = input_tensor.decode('utf-8')
self.logger.log_info(f"### INPUT_TEXT: {input_text}")
# 토크나이징
inputs = self.tokenizer(
f"### 질문: {input_text}\n\n### 답변:",
return_tensors="pt").to(device=self.model.device)
input_ids = inputs["input_ids"].to(device=self.model.device)
attention_mask = inputs["attention_mask"].to(device=self.model.device)
generation_config = GenerationConfig(
max_new_tokens=256,
)
# 모델 추론
gened = self.model.generate(
generation_config=generation_config,
input_ids=input_ids,
attention_mask=attention_mask,
pad_token_id=2,
repetition_penalty=1.1,
)
# 생성된 텍스트 디코딩
answer = self.tokenizer.decode(gened[0])
self.logger.log_info(f"### MODEL ANSWER:\n{answer}")
# 답변 내용 후처리
output = self.post_process(answer)
self.logger.log_info(f"### OUTPUT_TEXT: {output}")
# Triton에 텐서로 반환
output_tensor = pb_utils.Tensor(output_name, np.array(output.encode('utf-8'), dtype=np.bytes_))
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
return responses
def post_process(self, text):
try:
return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())
except IndexError:
return text
def finalize(self):
"""서버 종료 시 정리 (옵션)."""
pass