Update model.py

This commit is contained in:
groupuser 2025-11-24 00:57:49 +00:00
parent 023ef1fe92
commit 3ba54c6ce7

@ -1,213 +1,241 @@
import triton_python_backend_utils as pb_utils """
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig [Transformer-LLM 백엔드 가이드]
import numpy as np
파일은 NVIDIA Triton Server에서 Hugging Face `AutoModelForCausalLM` 기반 모델을 손쉽게 배포하기 위해 제공되는 커스텀 Python 백엔드 템플릿입니다.
1. 모델 호환성
- Hugging Face의 `AutoModelForCausalLM` 클래스와 호환되는 모든 Causal Language Model을 지원합니다.
- [확인] 배포할 모델 `config.json` `architectures` 항목이 `...ForCausalLM` 형식인지 확인.
2. 토크나이저 호환성
- `AutoTokenizer` 호환되는 토크나이저를 지원하며, 모델과 동일한 경로에서 자동으로 로드됩니다.
3. 커스터마이징 안내
- 템플릿은 범용적인 사용을 위해 작성되었습니다.
- 특정 모델의 동작 방식이나 예외 처리가 필요한 경우, 파일(`model.py`) 설정 파일(`config.pbtxt`) 직접 수정하여 사용하시기 바랍니다.
"""
import json import json
import torch
import numpy as np
import triton_python_backend_utils as pb_utils
import uuid
from typing import List, Dict, Any, Union, Tuple
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
BitsAndBytesConfig,
)
from peft import PeftModel, PeftConfig
class TritonPythonModel: class TritonPythonModel:
def initialize(self, args): def initialize(self, args: Dict[str, str]):
""" """
모델이 로드될 번만 호출됩니다. 모델 초기화: 설정 로드, 로거 설정, 모델 토크나이저 로드
`initialize` 함수를 구현하는 것은 선택 사항입니다. 함수를 통해 모델은
모델과 관련된 모든 상태를 초기화할 있습니다.
""" """
self.logger = pb_utils.Logger self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"]) self.model_config = json.loads(args["model_config"])
self.model_name = args["model_name"] self.model_name = args["model_name"]
self.model_path = self._get_config_parameter("model_path")
self.enable_inference_trace = self._get_config_parameter("enable_inference_trace")
self.logger.log_info(f"model_path: {self.model_path}") # 설정 파라미터 로드
self.base_model_path = self._get_config_param("base_model_path")
self.is_adapter_model = self._get_config_param("is_adapter_model", "false").lower() == "true"
self.adapter_model_path = self._get_config_param("adapter_model_path")
self.quantization = self._get_config_param("quantization")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Transformers 라이브러리에서 사전 학습된 토크나이저를 로드합니다. # 설정 로그 출력
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) self.logger.log_info(f"================ {self.model_name} Setup ================")
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.logger.log_info(f"Base Model: {self.base_model_path}")
self.supports_chat_template = self._check_chat_template_support() self.logger.log_info(f"Adapter Mode: {self.is_adapter_model} ({self.adapter_model_path})")
self.logger.log_info(f"Quantization: {self.quantization}")
self.logger.log_info(f"Device: {self.device}")
# Hugging Face Transformers 라이브러리에서 사전 학습된 언어 모델을 로드합니다. self._load_model_and_tokenizer()
self.model = AutoModelForCausalLM.from_pretrained( self.logger.log_info(f"Model initialized successfully.")
pretrained_model_name_or_path=self.model_path,
local_files_only=True, def _load_model_and_tokenizer(self):
device_map="auto" """모델과 토크나이저를 로드하고 설정합니다."""
# 1. Quantization 설정
bnb_config = self._get_bnb_config()
# 2. Base Model 로드
load_path = self.base_model_path
if self.is_adapter_model:
peft_config = PeftConfig.from_pretrained(self.adapter_model_path)
load_path = peft_config.base_model_name_or_path
try:
self.model = AutoModelForCausalLM.from_pretrained(
load_path,
torch_dtype=torch.float16,
quantization_config=bnb_config,
device_map="auto",
local_files_only=True,
trust_remote_code=True
)
except Exception as e:
self.logger.log_error(f"Failed to load base model: {e}")
raise e
# 3. Adapter 병합 (필요 시)
if self.is_adapter_model:
self.model = PeftModel.from_pretrained(self.model, self.adapter_model_path)
self.model.eval()
# 4. Tokenizer 로드
self.tokenizer = AutoTokenizer.from_pretrained(load_path, trust_remote_code=True)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.logger.log_info("Pad token was None. Set to EOS token.")
self.supports_chat_template = (
hasattr(self.tokenizer, "chat_template") and
self.tokenizer.chat_template is not None
) )
self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료") self.logger.log_info(f"Supports Chat Template: {self.supports_chat_template}")
if self.supports_chat_template:
self.logger.log_info(f"Chat Template Content:\n{self.tokenizer.chat_template}")
def _get_bnb_config(self) -> Union[BitsAndBytesConfig, None]:
if self.quantization == "int4":
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
elif self.quantization == "int8":
return BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=True
)
return None
def execute(self, requests): def execute(self, requests):
""" """Triton Inference Request 처리 메인 루프"""
Triton이 추론 요청에 대해 호출하는 실행 함수입니다.
"""
responses = [] responses = []
# 각 추론 요청을 순회하며 처리합니다.
for request in requests: for request in requests:
# Triton 입력 파싱 # [ID 생성 로직] - 로그 추적용으로 유지 (Response에는 포함 X)
input_text = self._get_input_value(request, "text_input") request_id = request.request_id()
if not request_id:
request_id = str(uuid.uuid4())
text = ""
conversation = ""
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
# 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다.
try: try:
conversation = json.loads(input_text) # 1. 입력 데이터 파싱
is_chat = True input_data, is_chat = self._parse_input(request)
self.logger.log_info(f"입력 conversation 출력:\n{conversation}")
except:
# JSON 파싱에 실패하면 일반 텍스트로 처리합니다.
text = input_text
is_chat = False
self.logger.log_info(f"입력 text 출력:\n{text}")
# 입력 텍스트를 토큰화합니다. # [LOGGING] Request ID 포함하여 로그 출력
if self.supports_chat_template and is_chat: log_input_str = json.dumps(input_data, ensure_ascii=False) if isinstance(input_data, (list, dict)) else str(input_data)
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.") self.logger.log_info(f"\n[RID: {request_id}] >>> [{'CHAT' if is_chat else 'TEXT'}][Input]: {log_input_str}")
inputs = self.tokenizer.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True
).to(device=self.model.device)
else:
self.logger.log_info(f"입력 텍스트를 토큰화합니다.")
inputs = self.tokenizer(
text,
return_tensors="pt").to(device=self.model.device)
input_ids = inputs["input_ids"] # 2. Generation Config 생성
attention_mask = inputs["attention_mask"] gen_config = self._create_generation_config(request)
input_token_length = inputs["input_ids"].shape[-1]
# 3. 토크나이징
inputs = self._tokenize(input_data, is_chat)
# 언어 모델을 사용하여 텍스트를 생성합니다. # 4. 모델 추론 (Generate)
gened = self.model.generate( output_text = self._generate(inputs, gen_config)
input_ids=input_ids,
attention_mask=attention_mask,
generation_config=self._process_generation_config(request),
pad_token_id=self.tokenizer.pad_token_id,
)
# 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다. # [LOGGING] Request ID 포함하여 결과 출력
generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱 self.logger.log_info(f"\n[RID: {request_id}] <<< [Output]: {output_text}")
gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}")
output = gened_text.strip() # 5. 응답 생성
responses.append(self._create_response(output_text, request_id))
# 생성된 텍스트를 Triton 출력 텐서로 변환합니다. except Exception as e:
output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_)) self.logger.log_error(f"[RID: {request_id}] Error during execution: {e}")
err_tensor = pb_utils.Tensor("text_output", np.array([str(e).encode('utf-8')], dtype=np.bytes_))
# 응답 객체를 생성하고 출력 텐서를 추가합니다. responses.append(pb_utils.InferenceResponse(output_tensors=[err_tensor]))
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
return responses return responses
def _process_generation_config(self, request): def _parse_input(self, request) -> Tuple[Union[str, List[Dict]], bool]:
""" input_text = self._get_input_scalar(request, "text_input")
추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다. try:
conversation = json.loads(input_text)
if isinstance(conversation, list):
return conversation, True
except (json.JSONDecodeError, TypeError):
pass
return input_text, False
Args: def _tokenize(self, input_data, is_chat: bool):
request (pb_utils.InferenceRequest): Triton 추론 요청 객체. if self.supports_chat_template and is_chat:
return self.tokenizer.apply_chat_template(
input_data,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True
).to(self.device)
else:
if is_chat:
input_data = str(input_data)
return self.tokenizer(input_data, return_tensors="pt").to(self.device)
Returns: def _generate(self, inputs, gen_config: GenerationConfig) -> str:
transformers.GenerationConfig: GenerationConfig 객체. input_ids = inputs["input_ids"]
""" input_len = input_ids.shape[-1]
max_length = self._get_input_value(request, "max_length", default=20)
max_new_tokens = self._get_input_value(request, "max_new_tokens")
temperature = self._get_input_value(request, "temperature")
do_sample = self._get_input_value(request, "do_sample")
top_k = self._get_input_value(request, "top_k")
top_p = self._get_input_value(request, "top_p")
repetition_penalty = self._get_input_value(request, "repetition_penalty")
stream = self._get_input_value(request, "stream")
generation_config = GenerationConfig( with torch.no_grad():
max_length=max_length, outputs = self.model.generate(
max_new_tokens=max_new_tokens, **inputs,
temperature=temperature, generation_config=gen_config,
do_sample=do_sample, pad_token_id=self.tokenizer.pad_token_id,
top_k=top_k, eos_token_id=self.tokenizer.eos_token_id
top_p=top_p, )
repetition_penalty=repetition_penalty,
stream=stream, generated_tokens = outputs[0][input_len:]
decoded_output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return decoded_output.strip()
def _create_generation_config(self, request) -> GenerationConfig:
def get_param(name, default=None, cast_type=None):
val = self._get_input_scalar(request, name, default)
if val is not None and cast_type:
return cast_type(val)
return val
return GenerationConfig(
max_length=get_param("max_length", 1024, int),
max_new_tokens=get_param("max_new_tokens", 256, int),
temperature=get_param("temperature", 1.0, float),
do_sample=get_param("do_sample", False, bool),
top_k=get_param("top_k", 50, int),
top_p=get_param("top_p", 1.0, float),
repetition_penalty=get_param("repetition_penalty", 1.0, float),
) )
self.logger.log_info(f"추론 요청 GenerationConfig:\n{generation_config}") def _create_response(self, output_text: str, request_id: str):
"""생성된 텍스트를 Triton Response 객체로 변환"""
output_tensor = pb_utils.Tensor(
"text_output",
np.array([output_text.encode('utf-8')], dtype=np.bytes_)
)
return pb_utils.InferenceResponse(output_tensors=[output_tensor])
return generation_config def _get_config_param(self, key: str, default: str = None) -> str:
params = self.model_config.get('parameters', {})
if key in params:
return params[key].get('string_value', default)
return default
def _get_config_parameter(self, parameter_name): def _get_input_scalar(self, request, name: str, default=None):
""" tensor = pb_utils.get_input_tensor_by_name(request, name)
모델 설정(config.pbtxt)에서 특정 파라미터의 문자열 값을 가져옵니다. if tensor is None:
Args:
parameter_name (str): 가져올 파라미터의 이름.
Returns:
str or None: 파라미터의 'string_value' 또는 해당 파라미터가 없거나 'string_value' 키가 없는 경우 None.
"""
self.parameters = self.model_config.get('parameters', {})
parameter_dict = self.parameters.get(parameter_name)
if isinstance(parameter_dict, dict) and 'string_value' in parameter_dict:
return parameter_dict['string_value']
return None
def _check_chat_template_support(self):
"""
주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
Returns:
bool: Chat 템플릿 지원 여부 (True 또는 False).
"""
try:
if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 지원합니다.")
self.logger.log_info("Chat 템플릿 내용:")
self.logger.log_info(self.tokenizer.chat_template)
return True
else:
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 직접적으로 지원하지 않거나, Chat 템플릿 정보가 없습니다.")
return False
except Exception as e:
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
return False
def _get_input_value(self, request, input_name: str, default=None):
"""
Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다.
Args:
request (pb_utils.InferenceRequest): Triton 추론 요청 객체.
input_name (str): 가져올 입력 텐서의 이름.
default (any, optional): 입력 텐서가 없을 경우 반환할 기본값. Defaults to None.
Returns:
any: 디코딩된 입력 텐서의 . 텐서가 없으면 기본값을 반환합니다.
"""
tensor_value = pb_utils.get_input_tensor_by_name(request, input_name)
if tensor_value is None:
return default return default
return self._np_decoder(tensor.as_numpy()[0])
return self._np_decoder(tensor_value.as_numpy()[0])
def _np_decoder(self, obj): def _np_decoder(self, obj):
"""
NumPy 객체의 데이터 타입을 확인하고 Python 기본 타입으로 변환합니다.
Args:
obj (numpy.ndarray element): 변환할 NumPy 배열의 요소.
Returns:
any: 해당 NumPy 요소에 대응하는 Python 기본 타입 (str, int, float, bool).
bytes 타입인 경우 UTF-8 디코딩합니다.
"""
if isinstance(obj, bytes): if isinstance(obj, bytes):
return obj.decode('utf-8') return obj.decode('utf-8')
if np.issubdtype(obj, np.integer): if np.issubdtype(obj, np.integer):
@ -218,9 +246,7 @@ class TritonPythonModel:
return bool(obj) return bool(obj)
def finalize(self): def finalize(self):
""" self.logger.log_info(f"Finalizing model {self.model_name}")
모델 실행이 완료된 Triton 서버가 종료될 호출되는 함수입니다. self.model = None
`finalize` 함수를 구현하는 것은 선택 사항입니다. 함수를 통해 모델은 self.tokenizer = None
종료 전에 필요한 모든 정리 작업을 수행할 있습니다. torch.cuda.empty_cache()
"""
pass