# gemma-2b/1/model.py
import os
os.environ["HF_HOME"] = "/opt/tritonserver/model_repository/nlp_models/hf_cache"
import json
import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils
# Added: HuggingFace classes for tokenization, model loading, quantization, and generation.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GenerationConfig,
)
import time
import sys
from datetime import datetime
torch.cuda.empty_cache()
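
# Note: the matching config.pbtxt is not included in this file. A rough sketch,
# assuming the Python backend with unbatched string tensors (dims and
# max_batch_size are guesses; the tensor and parameter names come from the
# code below):
#
#   backend: "python"
#   max_batch_size: 0
#   input  [ { name: "text_input",  data_type: TYPE_STRING, dims: [ -1 ] } ]
#   output [ { name: "text_output", data_type: TYPE_STRING, dims: [ -1 ] } ]
#   parameters { key: "huggingface_model", value: { string_value: "/models/gemma-2b/gemma-2b" } }
#   parameters { key: "max_output_length", value: { string_value: "15" } }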


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "/models/gemma-2b/gemma-2b"
        default_max_gen_length = "15"
        # Check for a user-specified model path in the model config parameters.
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        # Check for a user-specified max output length in the model config parameters.
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )
        self.base_model = hf_model
        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        self.config_model()
        self.model_eval()
        self.model_compile()
        self.logger.log_info("Initialized...")

    def config_model(self):
        # 4-bit quantization settings (currently unused; see the commented-out
        # quantization_config argument below).
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.logger.log_info("tokenizer loaded")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            # quantization_config=self.bnb_config,
            local_files_only=True,
            trust_remote_code=True,
            device_map="auto",
        )
        self.model.enable_input_require_grads()
        self.logger.log_info(f"base.model.device : {self.model.device}")
        self.logger.log_info("base model loaded")

    def model_eval(self):
        self.logger.log_info("...model eval start")
        self.model.eval()
        self.model.config.use_cache = True
        self.logger.log_info("...model eval end")

    def model_compile(self):
        self.logger.log_info("...model compile start")
        # torch.compile is only available on PyTorch >= 2.0 and is not supported on Windows.
        if torch.__version__ >= "2" and sys.platform != "win32":
            self.model = torch.compile(self.model)
            self.logger.log_info("...model compiled!")
        self.logger.log_info("...model compile end")

    def execute(self, requests):
        self.logger.log_info("### inference start")
        responses = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            input_text = input_tensor.as_numpy()[0].decode("utf-8")
            self.logger.log_info("### Receive Time: {}\n".format(datetime.now()))
            self.logger.log_info("### Question : {}\n".format(input_text.strip()))
            self.logger.log_info(f"model device type is : {self.model.device}")
            start = time.time()
            generation_config = GenerationConfig(
                temperature=0.15,
                top_k=40,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                early_stopping=True,
                max_new_tokens=self.max_output_length,
            )
            # Wrap the raw input in the question/answer prompt template
            # ("### 질문:" = question, "### 답변:" = answer).
            inputs = self.tokenizer(
                f"### 질문: {input_text}\n\n### 답변:",
                return_tensors="pt",
                padding=True,
                return_token_type_ids=False,
            ).to(device=self.model.device)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            self.logger.log_info(f"inputs: {inputs}")
            gened = self.model.generate(
                generation_config=generation_config,
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            answer = self.tokenizer.decode(gened[0])
            end = time.time()
            self.logger.log_info(f"decoded_output answer : {answer}")
            self.logger.log_info(f"### generate elapsed time is {end - start}")
            output = self.post_process(answer)
            self.logger.log_info(f"output type is: {type(output)}")
            self.logger.log_info(f"### final output : {output}")
            output_tensor_0 = pb_utils.Tensor(
                "text_output", np.array(output.encode("utf-8"), dtype=np.bytes_)
            )
            self.logger.log_info(f"type : {type(output_tensor_0)}, value: {output_tensor_0}")
            response = pb_utils.InferenceResponse(output_tensors=[output_tensor_0])
            responses.append(response)
        self.logger.log_info("### inference end")
        return responses

    def post_process(self, text):
        # Keep only the text between the "### 답변:" (answer) marker and any
        # following "### 질문:" (question) marker.
        return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())

    def finalize(self):
        print("Cleaning up...")