Update 1/model.py
This commit is contained in:
parent
536131678e
commit
e177542aca
413
1/model.py
413
1/model.py
@ -1,262 +1,238 @@
|
|||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
from typing import List, Dict, Any, Union, Tuple
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
GenerationConfig,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
from peft import PeftModel, PeftConfig
|
from peft import PeftModel, PeftConfig
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
def initialize(self, args):
|
def initialize(self, args: Dict[str, str]):
|
||||||
"""
|
"""
|
||||||
모델이 로드될 때 딱 한 번만 호출됩니다.
|
모델 초기화: 설정 로드, 로거 설정, 모델 및 토크나이저 로드
|
||||||
`initialize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은
|
|
||||||
이 모델과 관련된 모든 상태를 초기화할 수 있습니다.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
args : dict
|
|
||||||
Both keys and values are strings. The dictionary keys and values are:
|
|
||||||
* model_config: A JSON string containing the model configuration
|
|
||||||
* model_instance_kind: A string containing model instance kind
|
|
||||||
* model_instance_device_id: A string containing model instance device
|
|
||||||
ID
|
|
||||||
* model_repository: Model repository path
|
|
||||||
* model_version: Model version
|
|
||||||
* model_name: Model name
|
|
||||||
"""
|
"""
|
||||||
self.logger = pb_utils.Logger
|
self.logger = pb_utils.Logger
|
||||||
|
|
||||||
self.model_config = json.loads(args["model_config"])
|
self.model_config = json.loads(args["model_config"])
|
||||||
|
|
||||||
self.model_name = args["model_name"]
|
self.model_name = args["model_name"]
|
||||||
self.base_model_path = self._get_config_parameter("base_model_path")
|
|
||||||
self.is_adapter_model = self._get_config_parameter("is_adapter_model").strip().lower() == "true"
|
|
||||||
self.adapter_model_path = self._get_config_parameter("adapter_model_path")
|
|
||||||
self.quantization = self._get_config_parameter("quantization")
|
|
||||||
|
|
||||||
self.logger.log_info(f"base_model_path: {self.base_model_path}")
|
# 설정 파라미터 로드
|
||||||
self.logger.log_info(f"is_adapter_model: {self.is_adapter_model}")
|
self.base_model_path = self._get_config_param("base_model_path")
|
||||||
self.logger.log_info(f"adapter_model_path: {self.adapter_model_path}")
|
self.is_adapter_model = self._get_config_param("is_adapter_model", "false").lower() == "true"
|
||||||
self.logger.log_info(f"quantization: {self.quantization}")
|
self.adapter_model_path = self._get_config_param("adapter_model_path")
|
||||||
|
self.quantization = self._get_config_param("quantization")
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
self.load_model()
|
# 로그 출력
|
||||||
|
self.logger.log_info(f"================ {self.model_name} Setup ================")
|
||||||
|
self.logger.log_info(f"Base Model: {self.base_model_path}")
|
||||||
|
self.logger.log_info(f"Adapter Mode: {self.is_adapter_model} ({self.adapter_model_path})")
|
||||||
|
self.logger.log_info(f"Quantization: {self.quantization}")
|
||||||
|
self.logger.log_info(f"Device: {self.device}")
|
||||||
|
|
||||||
def load_model(self):
|
self._load_model_and_tokenizer()
|
||||||
"""
|
self.logger.log_info(f"Model initialized successfully.")
|
||||||
Load model
|
|
||||||
"""
|
|
||||||
self.bnb_config = None
|
|
||||||
torch_dtype = torch.float16 # 기본 dtype (필요시 bfloat16 등으로 조절)
|
|
||||||
|
|
||||||
# 양자화 옵션 체크
|
def _load_model_and_tokenizer(self):
|
||||||
|
"""모델과 토크나이저를 로드하고 설정합니다."""
|
||||||
|
# 1. Quantization 설정
|
||||||
|
bnb_config = self._get_bnb_config()
|
||||||
|
|
||||||
|
# 2. Base Model 로드
|
||||||
|
# Adapter 모델인 경우 Config에서 Base 경로를 덮어쓸 수 있음
|
||||||
|
load_path = self.base_model_path
|
||||||
|
if self.is_adapter_model:
|
||||||
|
peft_config = PeftConfig.from_pretrained(self.adapter_model_path)
|
||||||
|
load_path = peft_config.base_model_name_or_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
load_path,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
device_map="auto",
|
||||||
|
local_files_only=True,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.log_error(f"Failed to load base model: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# 3. Adapter 병합 (필요 시)
|
||||||
|
if self.is_adapter_model:
|
||||||
|
self.model = PeftModel.from_pretrained(self.model, self.adapter_model_path)
|
||||||
|
|
||||||
|
# 추론 모드 설정
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# 4. Tokenizer 로드
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(load_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
# Pad Token 설정 (없을 경우 EOS로 대체)
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||||
|
self.logger.log_info("Pad token was None. Set to EOS token.")
|
||||||
|
|
||||||
|
# Chat Template 지원 여부 확인
|
||||||
|
self.supports_chat_template = (
|
||||||
|
hasattr(self.tokenizer, "chat_template") and
|
||||||
|
self.tokenizer.chat_template is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_bnb_config(self) -> Union[BitsAndBytesConfig, None]:
|
||||||
|
"""양자화 설정 객체를 반환합니다."""
|
||||||
if self.quantization == "int4":
|
if self.quantization == "int4":
|
||||||
self.bnb_config = BitsAndBytesConfig(
|
return BitsAndBytesConfig(
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
bnb_4bit_compute_dtype=torch_dtype
|
bnb_4bit_compute_dtype=torch.float16
|
||||||
)
|
)
|
||||||
elif self.quantization == "int8":
|
elif self.quantization == "int8":
|
||||||
self.bnb_config = BitsAndBytesConfig(
|
return BitsAndBytesConfig(
|
||||||
load_in_8bit=True,
|
load_in_8bit=True,
|
||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=True
|
llm_int8_has_fp16_weight=True
|
||||||
)
|
)
|
||||||
else:
|
return None
|
||||||
self.bnb_config = None
|
|
||||||
|
|
||||||
if self.is_adapter_model:
|
|
||||||
# 어댑터 설정 정보 로드
|
|
||||||
peft_config = PeftConfig.from_pretrained(self.adapter_model_path)
|
|
||||||
self.base_model_path = peft_config.base_model_name_or_path
|
|
||||||
|
|
||||||
# base 모델 로드
|
|
||||||
base_model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
peft_config.base_model_name_or_path,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
quantization_config=self.bnb_config if self.bnb_config else None,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
|
|
||||||
# adapter 모델 로드 (base 모델 위에 덧씌움)
|
|
||||||
self.model = PeftModel.from_pretrained(base_model, self.adapter_model_path)
|
|
||||||
else:
|
|
||||||
# 일반 모델인 경우 로드
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=self.base_model_path,
|
|
||||||
local_files_only=True,
|
|
||||||
quantization_config=self.bnb_config if self.bnb_config else None,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tokenizer는 base model의 tokenizer 사용
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
|
|
||||||
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
|
||||||
self.supports_chat_template = self._check_chat_template_support()
|
|
||||||
|
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료")
|
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
"""
|
"""Triton Inference Request 처리 메인 루프"""
|
||||||
Triton이 각 추론 요청에 대해 호출하는 실행 함수입니다.
|
|
||||||
"""
|
|
||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
# 각 추론 요청을 순회하며 처리합니다.
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
# Triton 입력 파싱
|
|
||||||
input_text = self._get_input_value(request, "text_input")
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
conversation = ""
|
|
||||||
input_token_length = 0 # 입력 토큰 길이를 저장할 변수
|
|
||||||
|
|
||||||
# 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다.
|
|
||||||
try:
|
try:
|
||||||
conversation = json.loads(input_text)
|
# 1. 입력 데이터 파싱
|
||||||
is_chat = True
|
input_data, is_chat = self._parse_input(request)
|
||||||
self.logger.log_info(f"입력 conversation 출력:\n{conversation}")
|
|
||||||
except:
|
# 2. Generation Config 생성
|
||||||
# JSON 파싱에 실패하면 일반 텍스트로 처리합니다.
|
gen_config = self._create_generation_config(request)
|
||||||
text = input_text
|
|
||||||
is_chat = False
|
|
||||||
self.logger.log_info(f"입력 text 출력:\n{text}")
|
|
||||||
|
|
||||||
# 입력 텍스트를 토큰화합니다.
|
# 3. 토크나이징
|
||||||
if self.supports_chat_template and is_chat:
|
inputs = self._tokenize(input_data, is_chat)
|
||||||
self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.")
|
|
||||||
inputs = self.tokenizer.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
tokenize=True,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
return_dict=True
|
|
||||||
).to(device=self.model.device)
|
|
||||||
else:
|
|
||||||
self.logger.log_info(f"입력 텍스트를 토큰화합니다.")
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
text,
|
|
||||||
return_tensors="pt").to(device=self.model.device)
|
|
||||||
|
|
||||||
input_ids = inputs["input_ids"]
|
# 4. 모델 추론 (Generate)
|
||||||
attention_mask = inputs["attention_mask"]
|
output_text = self._generate(inputs, gen_config)
|
||||||
input_token_length = inputs["input_ids"].shape[-1]
|
|
||||||
|
|
||||||
|
# 5. 응답 생성
|
||||||
|
responses.append(self._create_response(output_text))
|
||||||
|
|
||||||
# 언어 모델을 사용하여 텍스트를 생성합니다.
|
except Exception as e:
|
||||||
gened = self.model.generate(
|
self.logger.log_error(f"Error during execution: {e}")
|
||||||
input_ids=input_ids,
|
# 에러 발생 시 빈 문자열 또는 에러 메시지 반환 (클라이언트 처리에 따라 변경 가능)
|
||||||
attention_mask=attention_mask,
|
err_tensor = pb_utils.Tensor("text_output", np.array([str(e).encode('utf-8')], dtype=np.bytes_))
|
||||||
generation_config=self._process_generation_config(request),
|
responses.append(pb_utils.InferenceResponse(output_tensors=[err_tensor]))
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다.
|
|
||||||
generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱
|
|
||||||
gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
|
||||||
self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}")
|
|
||||||
|
|
||||||
output = gened_text.strip()
|
|
||||||
|
|
||||||
# 생성된 텍스트를 Triton 출력 텐서로 변환합니다.
|
|
||||||
output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_))
|
|
||||||
|
|
||||||
# 응답 객체를 생성하고 출력 텐서를 추가합니다.
|
|
||||||
responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
|
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def _process_generation_config(self, request):
|
def _parse_input(self, request) -> Tuple[Union[str, List[Dict]], bool]:
|
||||||
"""
|
"""입력 텐서를 파싱하여 텍스트 또는 대화 목록과 타입(채팅 여부)을 반환합니다."""
|
||||||
추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다.
|
input_text = self._get_input_scalar(request, "text_input")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# JSON 형식의 대화 기록인지 시도
|
||||||
|
conversation = json.loads(input_text)
|
||||||
|
# 리스트 형태여야 채팅 히스토리로 간주
|
||||||
|
if isinstance(conversation, list):
|
||||||
|
return conversation, True
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return input_text, False
|
||||||
|
|
||||||
Args:
|
def _tokenize(self, input_data, is_chat: bool):
|
||||||
request (pb_utils.InferenceRequest): Triton 추론 요청 객체.
|
"""입력 데이터를 토큰화하여 모델 입력 텐서를 반환합니다."""
|
||||||
|
if self.supports_chat_template and is_chat:
|
||||||
|
return self.tokenizer.apply_chat_template(
|
||||||
|
input_data,
|
||||||
|
tokenize=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True
|
||||||
|
).to(self.device)
|
||||||
|
else:
|
||||||
|
# 일반 텍스트인 경우
|
||||||
|
if is_chat: # Chat 형식이지만 템플릿 미지원 시 문자열 변환 시도
|
||||||
|
input_data = str(input_data)
|
||||||
|
|
||||||
|
return self.tokenizer(
|
||||||
|
input_data,
|
||||||
|
return_tensors="pt"
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
Returns:
|
def _generate(self, inputs, gen_config: GenerationConfig) -> str:
|
||||||
transformers.GenerationConfig: GenerationConfig 객체.
|
"""모델 생성 로직 수행 및 디코딩"""
|
||||||
"""
|
input_ids = inputs["input_ids"]
|
||||||
max_length = self._get_input_value(request, "max_length", default=20)
|
input_len = input_ids.shape[-1]
|
||||||
max_new_tokens = self._get_input_value(request, "max_new_tokens")
|
|
||||||
temperature = self._get_input_value(request, "temperature")
|
|
||||||
do_sample = self._get_input_value(request, "do_sample")
|
|
||||||
top_k = self._get_input_value(request, "top_k")
|
|
||||||
top_p = self._get_input_value(request, "top_p")
|
|
||||||
repetition_penalty = self._get_input_value(request, "repetition_penalty")
|
|
||||||
stream = self._get_input_value(request, "stream")
|
|
||||||
|
|
||||||
generation_config = GenerationConfig(
|
with torch.no_grad():
|
||||||
max_length=max_length,
|
outputs = self.model.generate(
|
||||||
max_new_tokens=max_new_tokens,
|
**inputs,
|
||||||
temperature=temperature,
|
generation_config=gen_config,
|
||||||
do_sample=do_sample,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
top_k=top_k,
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
top_p=top_p,
|
)
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
stream=stream,
|
# 입력 토큰을 제외하고 생성된 토큰만 슬라이싱
|
||||||
|
generated_tokens = outputs[0][input_len:]
|
||||||
|
decoded_output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return decoded_output.strip()
|
||||||
|
|
||||||
|
def _create_generation_config(self, request) -> GenerationConfig:
|
||||||
|
"""Request에서 파라미터를 추출하여 GenerationConfig 객체 생성"""
|
||||||
|
# 기본값 설정 및 입력값 추출 Helper
|
||||||
|
def get_param(name, default=None, cast_type=None):
|
||||||
|
val = self._get_input_scalar(request, name, default)
|
||||||
|
if val is not None and cast_type:
|
||||||
|
return cast_type(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
return GenerationConfig(
|
||||||
|
max_length=get_param("max_length", 1024, int), # max_length 기본값은 넉넉하게
|
||||||
|
max_new_tokens=get_param("max_new_tokens", 256, int),
|
||||||
|
temperature=get_param("temperature", 1.0, float),
|
||||||
|
do_sample=get_param("do_sample", False, bool),
|
||||||
|
top_k=get_param("top_k", 50, int),
|
||||||
|
top_p=get_param("top_p", 1.0, float),
|
||||||
|
repetition_penalty=get_param("repetition_penalty", 1.0, float),
|
||||||
|
# stream=get_param("stream", False, bool) # Python Backend에서 Stream은 별도 구현 필요
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.log_info(f"추론 요청 GenerationConfig:\n{generation_config}")
|
def _create_response(self, output_text: str):
|
||||||
|
"""생성된 텍스트를 Triton Response 객체로 변환"""
|
||||||
|
output_tensor = pb_utils.Tensor(
|
||||||
|
"text_output",
|
||||||
|
np.array([output_text.encode('utf-8')], dtype=np.bytes_)
|
||||||
|
)
|
||||||
|
return pb_utils.InferenceResponse(output_tensors=[output_tensor])
|
||||||
|
|
||||||
return generation_config
|
def _get_config_param(self, key: str, default: str = None) -> str:
|
||||||
|
"""config.pbtxt 파라미터 조회 Helper"""
|
||||||
|
params = self.model_config.get('parameters', {})
|
||||||
|
if key in params:
|
||||||
|
return params[key].get('string_value', default)
|
||||||
|
return default
|
||||||
|
|
||||||
def _get_config_parameter(self, parameter_name):
|
def _get_input_scalar(self, request, name: str, default=None):
|
||||||
"""
|
"""입력 텐서에서 스칼라 값을 추출하는 Helper"""
|
||||||
모델 설정(config.pbtxt)에서 특정 파라미터의 문자열 값을 가져옵니다.
|
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
||||||
|
if tensor is None:
|
||||||
Args:
|
|
||||||
parameter_name (str): 가져올 파라미터의 이름.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str or None: 파라미터의 'string_value' 또는 해당 파라미터가 없거나 'string_value' 키가 없는 경우 None.
|
|
||||||
"""
|
|
||||||
self.parameters = self.model_config.get('parameters', {})
|
|
||||||
parameter_dict = self.parameters.get(parameter_name)
|
|
||||||
|
|
||||||
if isinstance(parameter_dict, dict) and 'string_value' in parameter_dict:
|
|
||||||
return parameter_dict['string_value']
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _check_chat_template_support(self):
|
|
||||||
"""
|
|
||||||
주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Chat 템플릿 지원 여부 (True 또는 False).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
|
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 지원합니다.")
|
|
||||||
self.logger.log_info("Chat 템플릿 내용:")
|
|
||||||
self.logger.log_info(self.tokenizer.chat_template)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 직접적으로 지원하지 않거나, Chat 템플릿 정보가 없습니다.")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_input_value(self, request, input_name: str, default=None):
|
|
||||||
"""
|
|
||||||
Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (pb_utils.InferenceRequest): Triton 추론 요청 객체.
|
|
||||||
input_name (str): 가져올 입력 텐서의 이름.
|
|
||||||
default (any, optional): 입력 텐서가 없을 경우 반환할 기본값. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
any: 디코딩된 입력 텐서의 값. 텐서가 없으면 기본값을 반환합니다.
|
|
||||||
"""
|
|
||||||
tensor_value = pb_utils.get_input_tensor_by_name(request, input_name)
|
|
||||||
|
|
||||||
if tensor_value is None:
|
|
||||||
return default
|
return default
|
||||||
|
|
||||||
return self._np_decoder(tensor_value.as_numpy()[0])
|
#value = tensor.as_numpy().item() # item()을 사용하여 스칼라 값 추출
|
||||||
|
value = self._np_decoder(tensor.as_numpy()[0])
|
||||||
|
|
||||||
|
# 바이트 타입 디코딩
|
||||||
|
# if isinstance(value, bytes):
|
||||||
|
# return value.decode('utf-8')
|
||||||
|
return value
|
||||||
|
|
||||||
def _np_decoder(self, obj):
|
def _np_decoder(self, obj):
|
||||||
"""
|
"""
|
||||||
@ -279,9 +255,8 @@ class TritonPythonModel:
|
|||||||
return bool(obj)
|
return bool(obj)
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
"""
|
"""리소스 정리"""
|
||||||
모델 실행이 완료된 후 Triton 서버가 종료될 때 한 번 호출되는 함수입니다.
|
self.logger.log_info(f"Finalizing model {self.model_name}")
|
||||||
`finalize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은
|
self.model = None
|
||||||
종료 전에 필요한 모든 정리 작업을 수행할 수 있습니다.
|
self.tokenizer = None
|
||||||
"""
|
torch.cuda.empty_cache()
|
||||||
pass
|
|
||||||
Loading…
Reference in New Issue
Block a user