Update 1/model.py

response with random string value
This commit is contained in:
cheetahadmin 2025-09-08 07:56:36 +00:00
parent 5a7db3cc84
commit 1287c02c91

@ -1,6 +1,10 @@
import torch import torch
import numpy as np import numpy as np
import json import json
import random
import string
import triton_python_backend_utils as pb_utils import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig from peft import PeftModel, PeftConfig
@ -105,58 +109,61 @@ class TritonPythonModel:
# 각 추론 요청을 순회하며 처리합니다. # 각 추론 요청을 순회하며 처리합니다.
for request in requests: for request in requests:
# Triton 입력 파싱 # Triton 입력 파싱
input_text = self._get_input_value(request, "text_input") # input_text = self._get_input_value(request, "text_input")
text = "" # text = ""
conversation = "" # conversation = ""
input_token_length = 0 # 입력 토큰 길이를 저장할 변수 # input_token_length = 0 # 입력 토큰 길이를 저장할 변수
# 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다. # 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다.
try: # try:
conversation = json.loads(input_text) # conversation = json.loads(input_text)
is_chat = True # is_chat = True
self.logger.log_info(f"입력 conversation 출력:\n{conversation}") # self.logger.log_info(f"입력 conversation 출력:\n{conversation}")
except: # except:
# JSON 파싱에 실패하면 일반 텍스트로 처리합니다. # # JSON 파싱에 실패하면 일반 텍스트로 처리합니다.
text = input_text # text = input_text
is_chat = False # is_chat = False
self.logger.log_info(f"입력 text 출력:\n{text}") # self.logger.log_info(f"입력 text 출력:\n{text}")
# 입력 텍스트를 토큰화합니다. # # 입력 텍스트를 토큰화합니다.
if self.supports_chat_template and is_chat: # if self.supports_chat_template and is_chat:
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.") # self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
inputs = self.tokenizer.apply_chat_template( # inputs = self.tokenizer.apply_chat_template(
conversation, # conversation,
tokenize=True, # tokenize=True,
add_generation_prompt=True, # add_generation_prompt=True,
return_tensors="pt", # return_tensors="pt",
return_dict=True # return_dict=True
).to(device=self.model.device) # ).to(device=self.model.device)
else: # else:
self.logger.log_info(f"입력 텍스트를 토큰화합니다.") # self.logger.log_info(f"입력 텍스트를 토큰화합니다.")
inputs = self.tokenizer( # inputs = self.tokenizer(
text, # text,
return_tensors="pt").to(device=self.model.device) # return_tensors="pt").to(device=self.model.device)
input_ids = inputs["input_ids"] # input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"] # attention_mask = inputs["attention_mask"]
input_token_length = inputs["input_ids"].shape[-1] # input_token_length = inputs["input_ids"].shape[-1]
# 언어 모델을 사용하여 텍스트를 생성합니다. # # 언어 모델을 사용하여 텍스트를 생성합니다.
gened = self.model.generate( # gened = self.model.generate(
input_ids=input_ids, # input_ids=input_ids,
attention_mask=attention_mask, # attention_mask=attention_mask,
generation_config=self._process_generation_config(request), # generation_config=self._process_generation_config(request),
pad_token_id=self.tokenizer.pad_token_id, # pad_token_id=self.tokenizer.pad_token_id,
) # )
# 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다. # # 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다.
generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱 # generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱
gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) # gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}") # self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}")
output = gened_text.strip() random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
# output = gened_text.strip()
output = random_string
# 생성된 텍스트를 Triton 출력 텐서로 변환합니다. # 생성된 텍스트를 Triton 출력 텐서로 변환합니다.
output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_)) output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_))
@ -238,7 +245,6 @@ class TritonPythonModel:
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}") self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
return False return False
def _get_input_value(self, request, input_name: str, default=None): def _get_input_value(self, request, input_name: str, default=None):
""" """
Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다. Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다.