create by CHEETAH
This commit is contained in:
parent
b77a93033e
commit
d39bb15e63
84
model.py
Normal file
84
model.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
def initialize(self, args):
|
||||||
|
"""모델 초기화. Triton이 서버 시작 시 실행."""
|
||||||
|
self.logger = pb_utils.Logger
|
||||||
|
|
||||||
|
model_repository = args["model_repository"]
|
||||||
|
model_name = args["model_name"]
|
||||||
|
model_path = f"{model_repository}/{model_name}"
|
||||||
|
|
||||||
|
self.model_config = json.loads(args["model_config"])
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
local_files_only=True,
|
||||||
|
device_map="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.log_info(f"\n### {model_name} 모델 초기화 완료")
|
||||||
|
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""Triton이 호출하는 Inference 실행 함수."""
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
input_name = self.model_config.get("input")[0]['name']
|
||||||
|
output_name = self.model_config.get("output")[0]['name']
|
||||||
|
|
||||||
|
for request in requests:
|
||||||
|
# Triton 입력 파싱
|
||||||
|
input_tensor = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy()[0]
|
||||||
|
input_text = input_tensor.decode('utf-8')
|
||||||
|
self.logger.log_info(f"### INPUT_TEXT: {input_text}")
|
||||||
|
|
||||||
|
# 토크나이징
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
f"### 질문: {input_text}\n\n### 답변:",
|
||||||
|
return_tensors="pt").to(device=self.model.device)
|
||||||
|
|
||||||
|
input_ids = inputs["input_ids"].to(device=self.model.device)
|
||||||
|
attention_mask = inputs["attention_mask"].to(device=self.model.device)
|
||||||
|
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
max_new_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 모델 추론
|
||||||
|
gened = self.model.generate(
|
||||||
|
generation_config=generation_config,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pad_token_id=2,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 생성된 텍스트 디코딩
|
||||||
|
answer = self.tokenizer.decode(gened[0])
|
||||||
|
self.logger.log_info(f"### MODEL ANSWER:\n{answer}")
|
||||||
|
|
||||||
|
# 답변 내용 후처리
|
||||||
|
output = self.post_process(answer)
|
||||||
|
self.logger.log_info(f"### OUTPUT_TEXT: {output}")
|
||||||
|
|
||||||
|
# Triton에 텐서로 반환
|
||||||
|
output_tensor = pb_utils.Tensor(output_name, np.array(output.encode('utf-8'), dtype=np.bytes_))
|
||||||
|
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def post_process(self, text):
|
||||||
|
try:
|
||||||
|
return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())
|
||||||
|
except IndexError:
|
||||||
|
return text
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""서버 종료 시 정리 (옵션)."""
|
||||||
|
pass
|
||||||
Loading…
Reference in New Issue
Block a user