diff --git a/1/model.py b/1/model.py index b83692b..82c8d5a 100644 --- a/1/model.py +++ b/1/model.py @@ -2,27 +2,17 @@ import torch import numpy as np import json import triton_python_backend_utils as pb_utils -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +# AutoModelForCausalLM 대신 AutoModel을 가져옵니다. +from transformers import AutoModel, AutoTokenizer, GenerationConfig from peft import PeftModel, PeftConfig +# BitsAndBytesConfig가 import 되지 않아 임시로 주석 처리하거나, 필요하다면 설치 후 주석 해제해야 합니다. +# from bitsandbytes import BitsAndBytesConfig + class TritonPythonModel: def initialize(self, args): """ 모델이 로드될 때 딱 한 번만 호출됩니다. - `initialize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은 - 이 모델과 관련된 모든 상태를 초기화할 수 있습니다. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device - ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name """ self.logger = pb_utils.Logger @@ -30,6 +20,9 @@ class TritonPythonModel: self.model_name = args["model_name"] self.base_model_path = self._get_config_parameter("base_model_path") + + # CodeSage는 임베딩 모델이므로 LoRA 등의 어댑터 로드는 지원하지 않거나 일반적이지 않습니다. + # 기존 로직은 유지하되, 실제로 사용하지 않을 경우 config.pbtxt에서 해당 파라미터를 제거하는 것이 좋습니다. self.is_adapter_model = self._get_config_parameter("is_adapter_model").strip().lower() == "true" self.adapter_model_path = self._get_config_parameter("adapter_model_path") self.quantization = self._get_config_parameter("quantization") @@ -43,172 +36,149 @@ class TritonPythonModel: def load_model(self): """ - Load model + Load model: CodeSage에 맞게 AutoModel과 trust_remote_code=True를 사용하도록 수정 """ self.bnb_config = None - torch_dtype = torch.float16 # 기본 dtype (필요시 bfloat16 등으로 조절) + torch_dtype = torch.float16 # 기본 dtype - # 양자화 옵션 체크 - if self.quantization == "int4": - self.bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch_dtype - ) - elif self.quantization == "int8": - self.bnb_config = BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=True - ) - else: - self.bnb_config = None + # Note: CodeSage는 AutoModel을 사용하며, 일반적인 CausalLM 양자화 옵션이 적용되는지 + # 확인이 필요하지만, 일단 기존 로직은 유지합니다. + + # BitsAndBytesConfig가 정의되지 않았으므로 주석 처리합니다. + # if self.quantization == "int4": + # # ... (int4 config) + # pass + # elif self.quantization == "int8": + # # ... (int8 config) + # pass + # else: + # self.bnb_config = None if self.is_adapter_model: - # 어댑터 설정 정보 로드 + # CodeSage는 임베딩 모델이므로 어댑터 사용은 일반적이지 않으나, + # 기존 템플릿 로직은 유지합니다. peft_config = PeftConfig.from_pretrained(self.adapter_model_path) self.base_model_path = peft_config.base_model_name_or_path - # base 모델 로드 - base_model = AutoModelForCausalLM.from_pretrained( + # base 모델 로드: AutoModel로 변경 + base_model = AutoModel.from_pretrained( peft_config.base_model_name_or_path, torch_dtype=torch.float16, quantization_config=self.bnb_config if self.bnb_config else None, - device_map="auto" + device_map="auto", + trust_remote_code=True # 필수 옵션 추가 ) - # adapter 모델 로드 (base 모델 위에 덧씌움) + # adapter 모델 로드 self.model = PeftModel.from_pretrained(base_model, self.adapter_model_path) else: - # 일반 모델인 경우 로드 - self.model = AutoModelForCausalLM.from_pretrained( + # 일반 모델인 경우 로드: AutoModel로 변경 + self.model = AutoModel.from_pretrained( pretrained_model_name_or_path=self.base_model_path, - local_files_only=True, + # local_files_only=True, # CodeSage는 허브에서 로드될 수 있으므로 주석 처리 quantization_config=self.bnb_config if self.bnb_config else None, - device_map="auto" + device_map="auto", + trust_remote_code=True # 필수 옵션 추가 ) + + # 모델을 평가 모드로 설정 + self.model.eval() - # Tokenizer는 base model의 tokenizer 사용 - self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path) - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.supports_chat_template = self._check_chat_template_support() + # Tokenizer 로드: trust_remote_code=True와 add_eos_token=True 추가 + self.tokenizer = AutoTokenizer.from_pretrained( + self.base_model_path, + trust_remote_code=True, + add_eos_token=True # 필수 옵션 추가 + ) + # 임베딩 모델이므로 pad_token_id 설정은 불필요하거나, 주의가 필요함. 일단 주석 처리 + # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + # 임베딩 모델에서는 Chat Template 지원 체크는 불필요하므로 제거 + # self.supports_chat_template = self._check_chat_template_support() + self.supports_chat_template = False # 항상 False로 설정 - self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료") + self.logger.log_info(f"'{self.model_name}' 모델 초기화 완료 (Code Embedding Mode)") def execute(self, requests): """ Triton이 각 추론 요청에 대해 호출하는 실행 함수입니다. + Generation 대신 Embedding 생성을 수행하도록 수정합니다. """ responses = [] # 각 추론 요청을 순회하며 처리합니다. for request in requests: - # Triton 입력 파싱 + # Triton 입력 파싱: 텍스트 입력만 처리합니다. input_text = self._get_input_value(request, "text_input") - text = "" - conversation = "" - input_token_length = 0 # 입력 토큰 길이를 저장할 변수 - - # 입력 텍스트가 JSON 형식의 대화 기록인지 확인합니다. - try: - conversation = json.loads(input_text) - is_chat = True - self.logger.log_info(f"입력 conversation 출력:\n{conversation}") - except: - # JSON 파싱에 실패하면 일반 텍스트로 처리합니다. - text = input_text - is_chat = False - self.logger.log_info(f"입력 text 출력:\n{text}") - + # CodeSage는 대화 형식이 아닌 일반 텍스트 (코드)를 입력으로 받으므로 + # JSON 파싱 로직과 Chat 템플릿 로직을 제거합니다. + text = input_text + self.logger.log_info(f"입력 text 출력:\n{text}") + # 입력 텍스트를 토큰화합니다. - if self.supports_chat_template and is_chat: - self.logger.log_info(f"Chat 템플릿을 적용하여 토큰화합니다.") - inputs = self.tokenizer.apply_chat_template( - conversation, - tokenize=True, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True - ).to(device=self.model.device) - else: - self.logger.log_info(f"입력 텍스트를 토큰화합니다.") - inputs = self.tokenizer( - text, - return_tensors="pt").to(device=self.model.device) + # add_eos_token=True가 load_model에서 설정되었으므로 토큰화 시 자동으로 추가됩니다. + inputs = self.tokenizer( + text, + return_tensors="pt").to(device=self.model.device) input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] - input_token_length = inputs["input_ids"].shape[-1] + + # # CodeSage는 텍스트 생성이 아닌 임베딩 생성을 수행합니다. + # gened = self.model.generate(...) + # **임베딩 생성** + # CodeSage 모델은 (임베딩, 히든 스테이트, 어텐션)을 반환하며, 첫 번째 요소가 임베딩입니다. + with torch.no_grad(): + # inputs에는 input_ids와 attention_mask가 모두 포함되어 전달됩니다. + outputs = self.model(**inputs) + + # outputs[0]에는 임베딩 벡터가 포함되어 있습니다. + # 임베딩은 일반적으로 첫 번째 토큰 (CLS 토큰 또는 문맥 임베딩)을 사용합니다. + # CodeSage의 경우, 모델 카드 예시를 보면 outputs[0] 전체를 사용합니다. + # 여기서는 [batch_size, sequence_length, hidden_size] 형태의 임베딩 중 첫 번째 토큰 임베딩을 사용합니다. + # 임베딩 사용법은 모델의 목적에 따라 달라질 수 있습니다. CodeSage는 주로 전체 시퀀스 임베딩을 사용합니다. + # 여기서는 예시와 같이 첫 번째 요소 (last_hidden_state)를 가져옵니다. + # 임베딩 크기: [1, seq_len, hidden_size] + embeddings = outputs[0] + + # 임베딩을 NumPy 배열로 변환 + # CPU로 옮기고, NumPy로 변환 + embedding_np = embeddings.squeeze().cpu().numpy() + + # 출력 텐서 생성 (데이터 타입은 float32 또는 float16이 적합) + # CodeSage는 단일 문장 입력만 처리하므로, 배치 차원 없이 [seq_len, hidden_size]로 가정합니다. + # 실제 사용 목적에 따라 풀링 로직을 추가하여 [hidden_size] 벡터로 만들 수도 있습니다. + output_tensor = pb_utils.Tensor("embedding_output", embedding_np.astype(np.float32)) - # 언어 모델을 사용하여 텍스트를 생성합니다. - gened = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - generation_config=self._process_generation_config(request), - pad_token_id=self.tokenizer.pad_token_id, - ) - - # 생성된 토큰 시퀀스를 텍스트로 디코딩하고 입력 텍스트는 제외합니다. - generated_tokens = gened[0][input_token_length:] # 입력 토큰 이후부터 슬라이싱 - gened_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) - self.logger.log_info(f"모델이 생성한 토큰 시퀀스 (입력 텍스트 제외):\n{gened_text}") - - output = gened_text.strip() - - # 생성된 텍스트를 Triton 출력 텐서로 변환합니다. - output_tensor = pb_utils.Tensor("text_output", np.array(output.encode('utf-8'), dtype=np.bytes_)) + self.logger.log_info(f"모델이 생성한 임베딩 Shape:\n{embedding_np.shape}") # 응답 객체를 생성하고 출력 텐서를 추가합니다. responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor])) return responses - def _process_generation_config(self, request): - """ - 추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다. + # 임베딩 모델에서는 불필요하므로 제거합니다. + # def _process_generation_config(self, request): + # """ + # 추론 요청에서 생성 설정 관련 파라미터들을 추출하여 GenerationConfig 객체를 생성합니다. + # """ + # # ... (기존 로직) + # pass + + # 임베딩 모델에서는 불필요하므로 제거합니다. + # def _check_chat_template_support(self): + # """ + # 주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다. + # """ + # # ... (기존 로직) + # pass - Args: - request (pb_utils.InferenceRequest): Triton 추론 요청 객체. - - Returns: - transformers.GenerationConfig: GenerationConfig 객체. - """ - max_length = self._get_input_value(request, "max_length", default=20) - max_new_tokens = self._get_input_value(request, "max_new_tokens") - temperature = self._get_input_value(request, "temperature") - do_sample = self._get_input_value(request, "do_sample") - top_k = self._get_input_value(request, "top_k") - top_p = self._get_input_value(request, "top_p") - repetition_penalty = self._get_input_value(request, "repetition_penalty") - stream = self._get_input_value(request, "stream") - - generation_config = GenerationConfig( - max_length=max_length, - max_new_tokens=max_new_tokens, - temperature=temperature, - do_sample=do_sample, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - stream=stream, - ) - - self.logger.log_info(f"추론 요청 GenerationConfig:\n{generation_config}") - - return generation_config def _get_config_parameter(self, parameter_name): """ 모델 설정(config.pbtxt)에서 특정 파라미터의 문자열 값을 가져옵니다. - - Args: - parameter_name (str): 가져올 파라미터의 이름. - - Returns: - str or None: 파라미터의 'string_value' 또는 해당 파라미터가 없거나 'string_value' 키가 없는 경우 None. """ self.parameters = self.model_config.get('parameters', {}) parameter_dict = self.parameters.get(parameter_name) @@ -218,38 +188,9 @@ class TritonPythonModel: return None - def _check_chat_template_support(self): - """ - 주어진 허깅페이스 Transformer 모델이 Chat 템플릿을 지원하는지 확인하고 결과를 출력합니다. - - Returns: - bool: Chat 템플릿 지원 여부 (True 또는 False). - """ - try: - if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None: - self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 지원합니다.") - self.logger.log_info("Chat 템플릿 내용:") - self.logger.log_info(self.tokenizer.chat_template) - return True - else: - self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저는 Chat 템플릿을 직접적으로 지원하지 않거나, Chat 템플릿 정보가 없습니다.") - return False - except Exception as e: - self.logger.log_info(f"'{self.model_name}' 모델의 토크나이저를 로드하는 동안 오류가 발생했습니다: {e}") - return False - - def _get_input_value(self, request, input_name: str, default=None): """ Triton 추론 요청에서 특정 이름의 입력 텐서 값을 가져옵니다. - - Args: - request (pb_utils.InferenceRequest): Triton 추론 요청 객체. - input_name (str): 가져올 입력 텐서의 이름. - default (any, optional): 입력 텐서가 없을 경우 반환할 기본값. Defaults to None. - - Returns: - any: 디코딩된 입력 텐서의 값. 텐서가 없으면 기본값을 반환합니다. """ tensor_value = pb_utils.get_input_tensor_by_name(request, input_name) @@ -261,13 +202,6 @@ class TritonPythonModel: def _np_decoder(self, obj): """ NumPy 객체의 데이터 타입을 확인하고 Python 기본 타입으로 변환합니다. - - Args: - obj (numpy.ndarray element): 변환할 NumPy 배열의 요소. - - Returns: - any: 해당 NumPy 요소에 대응하는 Python 기본 타입 (str, int, float, bool). - bytes 타입인 경우 UTF-8로 디코딩합니다. """ if isinstance(obj, bytes): return obj.decode('utf-8') @@ -281,7 +215,5 @@ class TritonPythonModel: def finalize(self): """ 모델 실행이 완료된 후 Triton 서버가 종료될 때 한 번 호출되는 함수입니다. - `finalize` 함수를 구현하는 것은 선택 사항입니다. 이 함수를 통해 모델은 - 종료 전에 필요한 모든 정리 작업을 수행할 수 있습니다. """ - pass + pass \ No newline at end of file