Update 1/model.py
response with random string value
This commit is contained in:
parent
5a7db3cc84
commit
1287c02c91
94
1/model.py
94
1/model.py
@ -1,6 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from peft import PeftModel, PeftConfig
|
||||
@ -105,58 +109,61 @@ class TritonPythonModel:
|
||||
# 각 추론 요청을 순회하며 처리합니다.
|
||||
for request in requests:
|
||||
# Triton 입력 파싱
|
||||
input_text = self._get_input_value(request, "text_input")
|
||||
# input_text = self._get_input_value(request, "text_input")
|
||||
|
||||
text = ""
|
||||
conversation = ""
|
||||
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
|
||||
# text = ""
|
||||
# conversation = ""
|
||||
# input_token_length = 0 # 입력 토큰 길이를 저장할 변수
|
||||
|
||||
# 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다.
|
||||
try:
|
||||
conversation = json.loads(input_text)
|
||||
is_chat = True
|
||||
self.logger.log_info(f"입력 conversation 출력:\n{conversation}")
|
||||
except:
|
||||
# JSON 파싱에 실패하면 일반 텍스트로 처리합니다.
|
||||
text = input_text
|
||||
is_chat = False
|
||||
self.logger.log_info(f"입력 text 출력:\n{text}")
|
||||
# try:
|
||||
# conversation = json.loads(input_text)
|
||||
# is_chat = True
|
||||
# self.logger.log_info(f"입력 conversation 출력:\n{conversation}")
|
||||
# except:
|
||||
# # JSON 파싱에 실패하면 일반 텍스트로 처리합니다.
|
||||
# text = input_text
|
||||
# is_chat = False
|
||||
# self.logger.log_info(f"입력 text 출력:\n{text}")
|
||||
|
||||
# 입력 텍스트를 토큰화합니다.
|
||||
if self.supports_chat_template and is_chat:
|
||||
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True
|
||||
).to(device=self.model.device)
|
||||
else:
|
||||
self.logger.log_info(f"입력 텍스트를 토큰화합니다.")
|
||||
inputs = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt").to(device=self.model.device)
|
||||
# # 입력 텍스트를 토큰화합니다.
|
||||
# if self.supports_chat_template and is_chat:
|
||||
# self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
|
||||
# inputs = self.tokenizer.apply_chat_template(
|
||||
# conversation,
|
||||
# tokenize=True,
|
||||
# add_generation_prompt=True,
|
||||
# return_tensors="pt",
|
||||
# return_dict=True
|
||||
# ).to(device=self.model.device)
|
||||
# else:
|
||||
# self.logger.log_info(f"입력 텍스트를 토큰화합니다.")
|
||||
# inputs = self.tokenizer(
|
||||
# text,
|
||||
# return_tensors="pt").to(device=self.model.device)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
input_token_length = inputs["input_ids"].shape[-1]
|
||||
# input_ids = inputs["input_ids"]
|
||||
# attention_mask = inputs["attention_mask"]
|
||||
# input_token_length = inputs["input_ids"].shape[-1]
|
||||
|
||||
|
||||
# 언어 모델을 사용하여 텍스트를 생성합니다.
|
||||
gened = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=self._process_generation_config(request),
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
# # 언어 모델을 사용하여 텍스트를 생성합니다.
|
||||
# gened = self.model.generate(
|
||||
# input_ids=input_ids,
|
||||
# attention_mask=attention_mask,
|
||||
# generation_config=self._process_generation_config(request),
|
||||
# pad_token_id=self.tokenizer.pad_token_id,
|
||||
# )
|
||||
|
||||
# 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다.
|
||||
generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱
|
||||
gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}")
|
||||
# # 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다.
|
||||
# generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱
|
||||
# gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
# self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}")
|
||||
|
||||
output = gened_text.strip()
|
||||
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
|
||||
# output = gened_text.strip()
|
||||
output = random_string
|
||||
|
||||
# 생성된 텍스트를 Triton 출력 텐서로 변환합니다.
|
||||
output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_))
|
||||
@ -238,7 +245,6 @@ class TritonPythonModel:
|
||||
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _get_input_value(self, request, input_name: str, default=None):
|
||||
"""
|
||||
Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user