import triton_python_backend_utils as pb_utils
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import numpy as np
import json


class TritonPythonModel:
    """Triton Python-backend model wrapping a HuggingFace causal LM.

    Loads a local model/tokenizer from the Triton model repository and
    serves text-generation requests using a fixed Korean Q&A prompt
    template ("### 질문:" / "### 답변:").
    """

    def initialize(self, args):
        """Initialize the model. Triton runs this once at server startup.

        Parameters
        ----------
        args : dict
            Triton-provided startup arguments; this block reads
            ``model_repository``, ``model_name`` and ``model_config``
            (a JSON string of config.pbtxt).
        """
        self.logger = pb_utils.Logger

        model_repository = args["model_repository"]
        model_name = args["model_name"]
        model_path = f"{model_repository}/{model_name}"

        self.model_config = json.loads(args["model_config"])

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            local_files_only=True,  # weights must already be in the repository
            device_map="auto",
        )

        # Generation settings are request-invariant, so build them once here
        # instead of re-creating a GenerationConfig on every execute() call.
        self.generation_config = GenerationConfig(max_new_tokens=256)

        self.logger.log_info(f"\n### {model_name} 모델 초기화 완료")

    def execute(self, requests):
        """Inference entry point called by Triton.

        Parameters
        ----------
        requests : list
            ``pb_utils.InferenceRequest`` objects, one per client request.

        Returns
        -------
        list
            One ``pb_utils.InferenceResponse`` per request, in order.
        """
        responses = []

        # Input/output tensor names come from config.pbtxt via model_config.
        input_name = self.model_config.get("input")[0]['name']
        output_name = self.model_config.get("output")[0]['name']

        for request in requests:
            # Parse the Triton input: element [0] of a string tensor, as bytes.
            input_tensor = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy()[0]
            input_text = input_tensor.decode('utf-8')
            self.logger.log_info(f"### INPUT_TEXT: {input_text}")

            # Tokenize with the prompt template; .to() moves the whole
            # BatchEncoding (input_ids AND attention_mask) to the model device.
            inputs = self.tokenizer(
                f"### 질문: {input_text}\n\n### 답변:",
                return_tensors="pt").to(device=self.model.device)

            # Model inference. (The original re-sent input_ids/attention_mask
            # through .to(device=...) although the BatchEncoding above had
            # already been moved — the redundant copies are removed.)
            gened = self.model.generate(
                generation_config=self.generation_config,
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # NOTE(review): pad_token_id=2 is hard-coded; presumably the
                # model's EOS/PAD id — confirm against the tokenizer config.
                pad_token_id=2,
                repetition_penalty=1.1,
            )

            # Decode the generated token ids (includes the prompt prefix).
            answer = self.tokenizer.decode(gened[0])
            self.logger.log_info(f"### MODEL ANSWER:\n{answer}")

            # Post-process: strip the prompt template from the generated text.
            output = self.post_process(answer)
            self.logger.log_info(f"### OUTPUT_TEXT: {output}")

            # BUG FIX: np.array(scalar_bytes, dtype=np.bytes_) produces a 0-d
            # array. Triton string (TYPE_STRING) outputs are exchanged as 1-D
            # object arrays of bytes — matching the input side, which reads
            # element [0] of a 1-D tensor above.
            output_array = np.array([output.encode('utf-8')], dtype=np.object_)
            output_tensor = pb_utils.Tensor(output_name, output_array)
            responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))

        return responses

    def post_process(self, text):
        """Extract the answer span between '### 답변:' and the next '### 질문:'.

        Falls back to returning *text* unchanged when the template markers
        are absent (IndexError from the split).
        """
        try:
            return str(text.split("### 답변:")[1].split("### 질문:")[0].strip())
        except IndexError:
            return text

    def finalize(self):
        """Cleanup at server shutdown (optional; nothing to release here)."""
        pass