create by CHEETAH
This commit is contained in:
parent
b9753c31c3
commit
a7e6791e85
84
1/model.py
Normal file
84
1/model.py
Normal file
@ -0,0 +1,84 @@
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
class TritonPythonModel:
|
||||
def initialize(self, args):
|
||||
"""모델 초기화. Triton이 서버 시작 시 실행."""
|
||||
self.logger = pb_utils.Logger
|
||||
|
||||
model_repository = args["model_repository"]
|
||||
model_name = args["model_name"]
|
||||
model_path = f"{model_repository}/{model_name}"
|
||||
|
||||
self.model_config = json.loads(args["model_config"])
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
self.logger.log_info(f"\n### {model_name} 모델 초기화 완료")
|
||||
|
||||
|
||||
def execute(self, requests):
|
||||
"""Triton이 호출하는 Inference 실행 함수."""
|
||||
responses = []
|
||||
|
||||
input_name = self.model_config.get("input")[0]['name']
|
||||
output_name = self.model_config.get("output")[0]['name']
|
||||
|
||||
for request in requests:
|
||||
# Triton 입력 파싱
|
||||
input_tensor = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy()[0]
|
||||
input_text = input_tensor.decode('utf-8')
|
||||
self.logger.log_info(f"### INPUT_TEXT: {input_text}")
|
||||
|
||||
# 토크나이징
|
||||
inputs = self.tokenizer(
|
||||
f"### 질문: {input_text}\n\n### 답변:",
|
||||
return_tensors="pt").to(device=self.model.device)
|
||||
|
||||
input_ids = inputs["input_ids"].to(device=self.model.device)
|
||||
attention_mask = inputs["attention_mask"].to(device=self.model.device)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
# 모델 추론
|
||||
gened = self.model.generate(
|
||||
generation_config=generation_config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pad_token_id=2,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
|
||||
# 생성된 텍스트 디코딩
|
||||
answer = self.tokenizer.decode(gened[0])
|
||||
self.logger.log_info(f"### MODEL ANSWER:\n{answer}")
|
||||
|
||||
# 답변 내용 후처리
|
||||
output = self.post_process(answer)
|
||||
self.logger.log_info(f"### OUTPUT_TEXT: {output}")
|
||||
|
||||
# Triton에 텐서로 반환
|
||||
output_tensor = pb_utils.Tensor(output_name, np.array(output.encode('utf-8'), dtype=np.bytes_))
|
||||
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
||||
|
||||
return responses
|
||||
|
||||
def post_process(self, text):
|
||||
try:
|
||||
return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())
|
||||
except IndexError:
|
||||
return text
|
||||
|
||||
def finalize(self):
|
||||
"""서버 종료 시 정리 (옵션)."""
|
||||
pass
|
||||
Loading…
Reference in New Issue
Block a user