Compare commits
No commits in common. "main" and "refs/deployment/triton" have entirely different histories.
main
...
refs/deplo
36
.gitattributes
vendored
36
.gitattributes
vendored
@ -1,36 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gguf filter=lfs diff=lfs merge=lfs -text
|
|
||||||
242
1/model.py
Normal file
242
1/model.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
import uuid
|
||||||
|
import transformers
|
||||||
|
from typing import List, Dict, Any, Union, Tuple
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
GenerationConfig,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
|
from peft import PeftModel, PeftConfig
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
def initialize(self, args: Dict[str, str]):
|
||||||
|
"""
|
||||||
|
모델 초기화: 라이브러리 버전 확인 및 모델/토크나이저 로드
|
||||||
|
"""
|
||||||
|
self.logger = pb_utils.Logger
|
||||||
|
self.model_config = json.loads(args["model_config"])
|
||||||
|
self.model_name = args["model_name"]
|
||||||
|
|
||||||
|
# 1. 라이브러리 버전 로그 추가
|
||||||
|
# GGUF 로드를 위해서는 최소 4.40.0 이상을 권장합니다.
|
||||||
|
transformers_version = transformers.__version__
|
||||||
|
self.logger.log_info(f"================ {self.model_name} Setup ================")
|
||||||
|
self.logger.log_info(f"Transformers Version: {transformers_version}")
|
||||||
|
self.logger.log_info(f"Torch Version: {torch.__version__}")
|
||||||
|
|
||||||
|
# 설정 파라미터 로드
|
||||||
|
self.base_model_path = self._get_config_param("base_model_path")
|
||||||
|
self.gguf_filename = self._get_config_param("gguf_filename")
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
self.logger.log_info(f"Base Model Path: {self.base_model_path}")
|
||||||
|
self.logger.log_info(f"GGUF Filename: {self.gguf_filename}")
|
||||||
|
self.logger.log_info(f"Device: {self.device}")
|
||||||
|
|
||||||
|
# 2. 모델 및 토크나이저 로드 실행
|
||||||
|
self._load_model_and_tokenizer()
|
||||||
|
self.logger.log_info(f"Model initialized successfully.")
|
||||||
|
|
||||||
|
def _load_model_and_tokenizer(self):
|
||||||
|
"""
|
||||||
|
config.pbtxt의 파라미터를 사용하여 GGUF 모델을 로드합니다.
|
||||||
|
Transformers 라이브러리가 GGUF를 읽어 fp16으로 역양자화합니다.
|
||||||
|
"""
|
||||||
|
# 1. config.pbtxt에서 설정값 읽기
|
||||||
|
load_path = self.base_model_path # /cheetah/input/model/groupuser/Qwen3-4B-Instruct-2507-mahjong-alpha
|
||||||
|
gguf_file = self._get_config_param("gguf_filename") # Qwen3-4B-Instruct-2507-mahjong-alpha.gguf
|
||||||
|
|
||||||
|
self.logger.log_info(f"Loading GGUF from: {load_path}/{gguf_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 2. Tokenizer 로드 (GGUF 파일 내의 토크나이저 메타데이터 참조)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
load_path,
|
||||||
|
gguf_file=gguf_file,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Model 로드 (GGUF -> PyTorch fp16 변환)
|
||||||
|
# 주의: GGUF 로드 시 bnb_config(int4/8)와 중복 사용은 불가능할 수 있습니다.
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
load_path,
|
||||||
|
gguf_file=gguf_file,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
local_files_only=True,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# 패딩 토큰 설정
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
self.supports_chat_template = (
|
||||||
|
hasattr(self.tokenizer, "chat_template") and
|
||||||
|
self.tokenizer.chat_template is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.log_info("GGUF Model and Tokenizer loaded successfully via Transformers.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.log_error(f"Failed to load GGUF model: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _get_bnb_config(self) -> Union[BitsAndBytesConfig, None]:
|
||||||
|
if self.quantization == "int4":
|
||||||
|
return BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.float16
|
||||||
|
)
|
||||||
|
elif self.quantization == "int8":
|
||||||
|
return BitsAndBytesConfig(
|
||||||
|
load_in_8bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""Triton Inference Request 처리 메인 루프"""
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
for request in requests:
|
||||||
|
# [ID 생성 로직] - 로그 추적용으로 유지 (Response에는 포함 X)
|
||||||
|
request_id = request.request_id()
|
||||||
|
if not request_id:
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 입력 데이터 파싱
|
||||||
|
input_data, is_chat = self._parse_input(request)
|
||||||
|
|
||||||
|
# [LOGGING] Request ID 포함하여 로그 출력
|
||||||
|
log_input_str = json.dumps(input_data, ensure_ascii=False) if isinstance(input_data, (list, dict)) else str(input_data)
|
||||||
|
self.logger.log_info(f"\n[RID: {request_id}] >>> [{'CHAT' if is_chat else 'TEXT'}][Input]: {log_input_str}")
|
||||||
|
|
||||||
|
# 2. Generation Config 생성
|
||||||
|
gen_config = self._create_generation_config(request)
|
||||||
|
|
||||||
|
# 3. 토크나이징
|
||||||
|
inputs = self._tokenize(input_data, is_chat)
|
||||||
|
|
||||||
|
# 4. 모델 추론 (Generate)
|
||||||
|
output_text = self._generate(inputs, gen_config)
|
||||||
|
|
||||||
|
# [LOGGING] Request ID 포함하여 결과 출력
|
||||||
|
self.logger.log_info(f"\n[RID: {request_id}] <<< [Output]: {output_text}")
|
||||||
|
|
||||||
|
# 5. 응답 생성
|
||||||
|
responses.append(self._create_response(output_text, request_id))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.log_error(f"[RID: {request_id}] Error during execution: {e}")
|
||||||
|
err_tensor = pb_utils.Tensor("text_output", np.array([str(e).encode('utf-8')], dtype=np.bytes_))
|
||||||
|
responses.append(pb_utils.InferenceResponse(output_tensors=[err_tensor]))
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def _parse_input(self, request) -> Tuple[Union[str, List[Dict]], bool]:
|
||||||
|
input_text = self._get_input_scalar(request, "text_input")
|
||||||
|
try:
|
||||||
|
conversation = json.loads(input_text)
|
||||||
|
if isinstance(conversation, list):
|
||||||
|
return conversation, True
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
return input_text, False
|
||||||
|
|
||||||
|
def _tokenize(self, input_data, is_chat: bool):
|
||||||
|
if self.supports_chat_template and is_chat:
|
||||||
|
return self.tokenizer.apply_chat_template(
|
||||||
|
input_data,
|
||||||
|
tokenize=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True
|
||||||
|
).to(self.device)
|
||||||
|
else:
|
||||||
|
if is_chat:
|
||||||
|
input_data = str(input_data)
|
||||||
|
return self.tokenizer(input_data, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
|
def _generate(self, inputs, gen_config: GenerationConfig) -> str:
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
input_len = input_ids.shape[-1]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model.generate(
|
||||||
|
**inputs,
|
||||||
|
generation_config=gen_config,
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_tokens = outputs[0][input_len:]
|
||||||
|
decoded_output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
return decoded_output.strip()
|
||||||
|
|
||||||
|
def _create_generation_config(self, request) -> GenerationConfig:
|
||||||
|
def get_param(name, default=None, cast_type=None):
|
||||||
|
val = self._get_input_scalar(request, name, default)
|
||||||
|
if val is not None and cast_type:
|
||||||
|
return cast_type(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
return GenerationConfig(
|
||||||
|
max_length=get_param("max_length", 1024, int),
|
||||||
|
max_new_tokens=get_param("max_new_tokens", 256, int),
|
||||||
|
temperature=get_param("temperature", 1.0, float),
|
||||||
|
do_sample=get_param("do_sample", False, bool),
|
||||||
|
top_k=get_param("top_k", 50, int),
|
||||||
|
top_p=get_param("top_p", 1.0, float),
|
||||||
|
repetition_penalty=get_param("repetition_penalty", 1.0, float),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_response(self, output_text: str, request_id: str):
|
||||||
|
"""생성된 텍스트를 Triton Response 객체로 변환"""
|
||||||
|
output_tensor = pb_utils.Tensor(
|
||||||
|
"text_output",
|
||||||
|
np.array([output_text.encode('utf-8')], dtype=np.bytes_)
|
||||||
|
)
|
||||||
|
return pb_utils.InferenceResponse(output_tensors=[output_tensor])
|
||||||
|
|
||||||
|
def _get_config_param(self, key: str, default: str = None) -> str:
|
||||||
|
params = self.model_config.get('parameters', {})
|
||||||
|
if key in params:
|
||||||
|
return params[key].get('string_value', default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
def _get_input_scalar(self, request, name: str, default=None):
|
||||||
|
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
||||||
|
if tensor is None:
|
||||||
|
return default
|
||||||
|
return self._np_decoder(tensor.as_numpy()[0])
|
||||||
|
|
||||||
|
def _np_decoder(self, obj):
|
||||||
|
if isinstance(obj, bytes):
|
||||||
|
return obj.decode('utf-8')
|
||||||
|
if np.issubdtype(obj, np.integer):
|
||||||
|
return int(obj)
|
||||||
|
if np.issubdtype(obj, np.floating):
|
||||||
|
return round(float(obj), 3)
|
||||||
|
if isinstance(obj, np.bool_):
|
||||||
|
return bool(obj)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
self.logger.log_info(f"Finalizing model {self.model_name}")
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
BIN
Qwen3-4B-Instruct-2507-mahjong-alpha.gguf
(Stored with Git LFS)
BIN
Qwen3-4B-Instruct-2507-mahjong-alpha.gguf
(Stored with Git LFS)
Binary file not shown.
251
README.md
251
README.md
@ -1,251 +0,0 @@
|
|||||||
---
|
|
||||||
license: apache-2.0
|
|
||||||
datasets:
|
|
||||||
- pjura/mahjong_board_states
|
|
||||||
language:
|
|
||||||
- zh
|
|
||||||
base_model:
|
|
||||||
- unsloth/Qwen3-4B-Instruct-2507
|
|
||||||
tags:
|
|
||||||
- riichi-mahjong
|
|
||||||
- game-ai
|
|
||||||
- qwen
|
|
||||||
- qwen3
|
|
||||||
- mahjong
|
|
||||||
- discard-recommendation
|
|
||||||
- gguf
|
|
||||||
pipeline_tag: text-generation
|
|
||||||
---
|
|
||||||
|
|
||||||
# Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
|
|
||||||
`Qwen3-4B-Instruct-2507-mahjong-alpha` 是一个基于 `unsloth/Qwen3-4B-Instruct-2507` 进行 QLoRA 微调的立直麻将垂直模型,面向四麻弃牌建议任务。
|
|
||||||
|
|
||||||
模型可根据输入的场次信息、手牌、副露、牌河、牌效与防守信息,输出当前最应打出的一张牌。
|
|
||||||
|
|
||||||
当前版本主要面向工具集成场景,推理输出为单张牌文本,不包含解释信息。
|
|
||||||
|
|
||||||
## 模型特点
|
|
||||||
|
|
||||||
- **任务**:四麻立直麻将弃牌建议
|
|
||||||
- **基座模型**:`unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- **微调方式**:`QLoRA`
|
|
||||||
- **训练框架**:`Unsloth`
|
|
||||||
- **发布格式**:`GGUF (F16)`
|
|
||||||
- **推理方式**:`llama.cpp`
|
|
||||||
- **维护者**:`TTDXQ`
|
|
||||||
|
|
||||||
## 适用范围
|
|
||||||
|
|
||||||
本模型面向四麻场景,不含赤宝牌。当前版本专注于"弃牌建议"这一单一任务,不提供完整对局规划,也不提供役种、打点或详细攻防解释。
|
|
||||||
|
|
||||||
## 使用限制
|
|
||||||
|
|
||||||
- 仅支持弃牌建议
|
|
||||||
- 不支持完整对局规划
|
|
||||||
- 不支持役种、打点、进攻防守解释
|
|
||||||
- 不保证竞赛或实战效果
|
|
||||||
- 仅供研究与学习使用
|
|
||||||
|
|
||||||
## 禁止用途
|
|
||||||
|
|
||||||
禁止将本模型用于:
|
|
||||||
|
|
||||||
- 作弊
|
|
||||||
- 外挂
|
|
||||||
- 代打
|
|
||||||
- 真钱赌博辅助
|
|
||||||
|
|
||||||
## 模型输入输出
|
|
||||||
|
|
||||||
### 输入格式
|
|
||||||
|
|
||||||
模型输入为结构化自然语言局面描述。示例:
|
|
||||||
|
|
||||||
```text
|
|
||||||
[情景分析]
|
|
||||||
- 牌局: 东一局,你是庄家 (第1巡,牌墙余69张)。
|
|
||||||
- 状态: 当前排名 1/4 (与一位差 0)。
|
|
||||||
- 宝牌: 5万
|
|
||||||
- 各玩家分数: 你有 25分, 下家: 25分, 对家: 25分, 上家: 25分。
|
|
||||||
- 你的手牌: 1万 5万 7万 3筒 5筒 6筒 8筒 8筒 3索 5索 8索 南 白 发
|
|
||||||
- 牌效: 5 向听,进张 82 张。
|
|
||||||
- 防御:
|
|
||||||
最安全牌放铳率:11.3%
|
|
||||||
平均放铳率:18.5%
|
|
||||||
最危险牌放铳率:25.9%
|
|
||||||
场上已见牌信息
|
|
||||||
各玩家副露信息:本家副露:无, 下家副露:无, 对家副露:无, 上家副露:无
|
|
||||||
各玩家牌河信息:本家:无, 下家:无, 对家:无, 上家:无
|
|
||||||
|
|
||||||
[任务]
|
|
||||||
根据当前情景,选择一张最应该打出的手牌。
|
|
||||||
```
|
|
||||||
|
|
||||||
### 输出格式
|
|
||||||
|
|
||||||
模型输出严格为"单张牌文本",不带"打"字,不带解释。例如:
|
|
||||||
|
|
||||||
```text
|
|
||||||
白
|
|
||||||
```
|
|
||||||
|
|
||||||
## 使用方法
|
|
||||||
|
|
||||||
### llama.cpp 推理
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llama-server -m Qwen3-4B-Instruct-2507-mahjong-alpha.gguf -c 2048
|
|
||||||
```
|
|
||||||
|
|
||||||
### Python 推理示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
"TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha"
|
|
||||||
)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
"TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 准备输入
|
|
||||||
input_text = "[情景分析]\n- 牌局: 东一局,你是庄家 (第1巡,牌墙余69张)。\n..."
|
|
||||||
|
|
||||||
# 推理
|
|
||||||
inputs = tokenizer(input_text, return_tensors="pt")
|
|
||||||
outputs = model.generate(**inputs, max_new_tokens=10)
|
|
||||||
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
||||||
print(result) # 输出: 白
|
|
||||||
```
|
|
||||||
|
|
||||||
## 数据集
|
|
||||||
|
|
||||||
训练数据使用 `pjura/mahjong_board_states` 的 2018 年部分数据。该数据集来源于天风麻将的游玩记录,每条数据包含 511 个数据点,涵盖游戏基础信息、宝牌指示牌、视角玩家手牌、玩家副露、牌河信息、玩家舍牌、弃牌决策等。
|
|
||||||
|
|
||||||
### 数据处理
|
|
||||||
|
|
||||||
将原始数据转换为便于阅读的自然语言描述形式,并根据数据计算出巡目数、实际宝牌、简易放铳参考等信息。根据巡目调整样本比例:
|
|
||||||
|
|
||||||
- 1~3 巡:15%
|
|
||||||
- 4~6 巡:20%
|
|
||||||
- 7~12 巡:35%
|
|
||||||
|
|
||||||
最终使用 `192000` 条样本,未混入通用指令数据或自建数据。
|
|
||||||
|
|
||||||
- 训练集:`192000`
|
|
||||||
- 验证集:`2000`
|
|
||||||
- 测试集:`2019 年数据按需抽取`
|
|
||||||
- 训练 / 验证 / 测试:完全互不重叠
|
|
||||||
|
|
||||||
### 数据集引用
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@dataset{mahjong_board_states,
|
|
||||||
title = {MahJong Board States Dataset},
|
|
||||||
author = {Patrick Jura},
|
|
||||||
year = {2024},
|
|
||||||
url = {https://huggingface.co/datasets/pjura/mahjong_board_states}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 训练信息
|
|
||||||
|
|
||||||
### 模型配置
|
|
||||||
- 基础模型:`unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- 训练加载精度:`4bit`
|
|
||||||
- 微调方式:`QLoRA`
|
|
||||||
- 训练框架:`Unsloth`
|
|
||||||
- Max sequence length:`2048`
|
|
||||||
|
|
||||||
### LoRA 参数
|
|
||||||
- Rank:`128`
|
|
||||||
- Alpha:`256`
|
|
||||||
- 目标模块:全部
|
|
||||||
|
|
||||||
### 训练超参数
|
|
||||||
- Learning rate:`1e-4`
|
|
||||||
- LR scheduler:`cosine`
|
|
||||||
- Batch size:`64`
|
|
||||||
- 单卡批次:`2`
|
|
||||||
- 梯度累积步数:`32`
|
|
||||||
- Training steps:`3000`
|
|
||||||
- Warmup steps:`300`
|
|
||||||
- Random seed:`3407`
|
|
||||||
- 加载最优检查点:是
|
|
||||||
|
|
||||||
### 训练时间
|
|
||||||
- 总时长:约 16.44 小时
|
|
||||||
|
|
||||||
## 评测结果
|
|
||||||
|
|
||||||
### 与数据库弃牌动作对比
|
|
||||||
|
|
||||||
推理参数:Temperature=0.1, Top_P=0.1
|
|
||||||
|
|
||||||
**评测指标说明**:
|
|
||||||
- 得分:满分 500 分(每个样本正确得 1 分,错误得 0 分)
|
|
||||||
- 样本全对率:3 次测试均与测试集结果一致的样本占全部样本的比例
|
|
||||||
- 样本零分率:3 次测试均与测试集结果不符的样本占全部样本的比例
|
|
||||||
|
|
||||||
#### 牌效测试
|
|
||||||
|
|
||||||
| 模型 | 方法 | 得分 | 样本全对率 | 样本零分率 |
|
|
||||||
|------|------|------|------------|------------|
|
|
||||||
| Qwen3-4B | 提示词工程 | 50.21 | 6.60% | 86.13% |
|
|
||||||
| Qwen3-4B | 微调 | 229.66 | 45.87% | 53.93% |
|
|
||||||
| DeepSeek-V3.2 | 提示词工程 | 181.66 | 21.40% | 46.33% |
|
|
||||||
|
|
||||||
#### 防守测试
|
|
||||||
|
|
||||||
| 模型 | 方法 | 得分 | 样本全对率 | 样本零分率 |
|
|
||||||
|------|------|------|------------|------------|
|
|
||||||
| Qwen3-4B | 提示词工程 | 53.55 | 6.17% | 84.43% |
|
|
||||||
| Qwen3-4B | 微调 | 239.89 | 47.93% | 52.00% |
|
|
||||||
| DeepSeek-V3.2 | 提示词工程 | 172.00 | 16.00% | 46.80% |
|
|
||||||
|
|
||||||
#### 综合测试
|
|
||||||
|
|
||||||
| 模型 | 方法 | 得分 | 样本全对率 | 样本零分率 |
|
|
||||||
|------|------|------|------------|------------|
|
|
||||||
| Qwen3-4B | 提示词工程 | 53.44 | 0.60% | 84.40% |
|
|
||||||
| Qwen3-4B | 微调 | 233.33 | 46.53% | 53.20% |
|
|
||||||
| DeepSeek-V3.2 | 提示词工程 | 179.44 | 18.07% | 44.93% |
|
|
||||||
|
|
||||||
### 与 Mortal 对比
|
|
||||||
|
|
||||||
推理参数:Temperature=0.6, Top_P=0.95
|
|
||||||
|
|
||||||
#### 测试1:全部巡目数据
|
|
||||||
|
|
||||||
- 样本数:3000
|
|
||||||
- Top-1 准确率:**50.73%**
|
|
||||||
- Top-3 准确率:**83.37%**
|
|
||||||
|
|
||||||
#### 测试2:去除早巡数据
|
|
||||||
|
|
||||||
- 有效样本数:3000
|
|
||||||
- Top-1 准确率:**48.70%**
|
|
||||||
- Top-3 准确率:**79.20%**
|
|
||||||
|
|
||||||
> 注:Mortal 是当前开源最强的立直麻将 AI 之一
|
|
||||||
|
|
||||||
## 仓库链接
|
|
||||||
|
|
||||||
- GitHub:https://github.com/ttdxq/Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
- Hugging Face:https://huggingface.co/TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
本模型遵循 Apache License 2.0 许可证。
|
|
||||||
|
|
||||||
训练数据来自 `pjura/mahjong_board_states`,其许可证为 `CC BY 4.0`,使用时请保留相应署名与引用。
|
|
||||||
|
|
||||||
## Acknowledgements
|
|
||||||
|
|
||||||
感谢以下开源资源:
|
|
||||||
|
|
||||||
- `unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- `pjura/mahjong_board_states`
|
|
||||||
- `Mortal`
|
|
||||||
257
README_en.md
257
README_en.md
@ -1,257 +0,0 @@
|
|||||||
---
|
|
||||||
license: apache-2.0
|
|
||||||
datasets:
|
|
||||||
- pjura/mahjong_board_states
|
|
||||||
language:
|
|
||||||
- zh
|
|
||||||
- en
|
|
||||||
base_model:
|
|
||||||
- unsloth/Qwen3-4B-Instruct-2507
|
|
||||||
tags:
|
|
||||||
- riichi-mahjong
|
|
||||||
- game-ai
|
|
||||||
- qwen
|
|
||||||
- qwen3
|
|
||||||
- mahjong
|
|
||||||
- discard-recommendation
|
|
||||||
- gguf
|
|
||||||
pipeline_tag: text-generation
|
|
||||||
---
|
|
||||||
|
|
||||||
# Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
|
|
||||||
[中文](./README.md)
|
|
||||||
|
|
||||||
`Qwen3-4B-Instruct-2507-mahjong-alpha` is a Riichi Mahjong domain model fine-tuned from `unsloth/Qwen3-4B-Instruct-2507` with QLoRA.
|
|
||||||
|
|
||||||
It is designed for 4-player Riichi Mahjong discard recommendation: given round information, hand tiles, calls, visible tiles, tile-efficiency, and defense signals, the model outputs the single best discard tile for the current state.
|
|
||||||
|
|
||||||
The current version is mainly intended for tool integration. The output is a single tile text only, without explanation.
|
|
||||||
|
|
||||||
## Model Features
|
|
||||||
|
|
||||||
- **Task**: 4-player Riichi Mahjong discard recommendation
|
|
||||||
- **Base model**: `unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- **Fine-tuning**: `QLoRA`
|
|
||||||
- **Training framework**: `Unsloth`
|
|
||||||
- **Release format**: `GGUF (F16)`
|
|
||||||
- **Inference**: `llama.cpp`
|
|
||||||
- **Maintainer**: `TTDXQ`
|
|
||||||
|
|
||||||
## Scope
|
|
||||||
|
|
||||||
This model targets 4-player Riichi Mahjong without red dora. The current version focuses only on discard recommendation. It does not provide full-game planning, yaku/score analysis, or detailed offense-defense explanations.
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- Discard recommendation only
|
|
||||||
- No full-game planning
|
|
||||||
- No yaku, point calculation, or detailed strategic explanation
|
|
||||||
- Not guaranteed for competitive or real-match performance
|
|
||||||
- For research and learning purposes only
|
|
||||||
|
|
||||||
## Prohibited Uses
|
|
||||||
|
|
||||||
This model must not be used for:
|
|
||||||
|
|
||||||
- cheating
|
|
||||||
- game automation or plug-ins
|
|
||||||
- account boosting or ghost-playing
|
|
||||||
- real-money gambling assistance
|
|
||||||
|
|
||||||
## Input and Output
|
|
||||||
|
|
||||||
### Input Format
|
|
||||||
|
|
||||||
The model input is a structured natural-language game-state description. Example:
|
|
||||||
|
|
||||||
```text
|
|
||||||
[情景分析]
|
|
||||||
- 牌局: 东一局,你是庄家 (第1巡,牌墙余69张)。
|
|
||||||
- 状态: 当前排名 1/4 (与一位差 0)。
|
|
||||||
- 宝牌: 5万
|
|
||||||
- 各玩家分数: 你有 25分, 下家: 25分, 对家: 25分, 上家: 25分。
|
|
||||||
- 你的手牌: 1万 5万 7万 3筒 5筒 6筒 8筒 8筒 3索 5索 8索 南 白 发
|
|
||||||
- 牌效: 5 向听,进张 82 张。
|
|
||||||
- 防御:
|
|
||||||
最安全牌放铳率:11.3%
|
|
||||||
平均放铳率:18.5%
|
|
||||||
最危险牌放铳率:25.9%
|
|
||||||
场上已见牌信息
|
|
||||||
各玩家副露信息:本家副露:无, 下家副露:无, 对家副露:无, 上家副露:无
|
|
||||||
各玩家牌河信息:本家:无, 下家:无, 对家:无, 上家:无
|
|
||||||
|
|
||||||
[任务]
|
|
||||||
根据当前情景,选择一张最应该打出的手牌。
|
|
||||||
```
|
|
||||||
|
|
||||||
### Output Format
|
|
||||||
|
|
||||||
The output is strictly a single tile text without any prefix like "discard" and without explanation. Example:
|
|
||||||
|
|
||||||
```text
|
|
||||||
白
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### llama.cpp Inference
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llama-server -m Qwen3-4B-Instruct-2507-mahjong-alpha.gguf -c 2048
|
|
||||||
```
|
|
||||||
|
|
||||||
### Python Inference Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
"TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha"
|
|
||||||
)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
"TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare input
|
|
||||||
input_text = """[情景分析]
|
|
||||||
- 牌局: 东一局,你是庄家 (第1巡,牌墙余69张)。
|
|
||||||
- 状态: 当前排名 1/4 (与一位差 0)。
|
|
||||||
..."""
|
|
||||||
|
|
||||||
# Inference
|
|
||||||
inputs = tokenizer(input_text, return_tensors="pt")
|
|
||||||
outputs = model.generate(**inputs, max_new_tokens=10)
|
|
||||||
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
||||||
print(result) # Output: 白
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset
|
|
||||||
|
|
||||||
The training data uses the 2018 subset of `pjura/mahjong_board_states`. This dataset originates from Tenhou.net gameplay records, with each record containing 511 data points covering game basics, dora indicators, player hand tiles, calls, discard piles, and discard decisions.
|
|
||||||
|
|
||||||
### Data Processing
|
|
||||||
|
|
||||||
The raw data was converted into human-readable natural language descriptions, with calculated turn numbers, actual dora, and simplified risk assessment. Sample distribution by turn:
|
|
||||||
|
|
||||||
- Turns 1-3: 15%
|
|
||||||
- Turns 4-6: 20%
|
|
||||||
- Turns 7-12: 35%
|
|
||||||
|
|
||||||
A total of `192000` samples were used, with no general instruction data or self-built data mixed in.
|
|
||||||
|
|
||||||
- Train: `192000`
|
|
||||||
- Validation: `2000`
|
|
||||||
- Test: sampled as needed from 2019 data
|
|
||||||
- Train / validation / test are fully non-overlapping
|
|
||||||
|
|
||||||
### Dataset Citation
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@dataset{mahjong_board_states,
|
|
||||||
title = {MahJong Board States Dataset},
|
|
||||||
author = {Patrick Jura},
|
|
||||||
year = {2024},
|
|
||||||
url = {https://huggingface.co/datasets/pjura/mahjong_board_states}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Training Details
|
|
||||||
|
|
||||||
### Model Configuration
|
|
||||||
- Base Model: `unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- Training Precision: `4bit`
|
|
||||||
- Fine-tuning Method: `QLoRA`
|
|
||||||
- Framework: `Unsloth`
|
|
||||||
- Max Sequence Length: `2048`
|
|
||||||
|
|
||||||
### LoRA Parameters
|
|
||||||
- Rank: `128`
|
|
||||||
- Alpha: `256`
|
|
||||||
- Target Modules: All
|
|
||||||
|
|
||||||
### Training Hyperparameters
|
|
||||||
- Learning Rate: `1e-4`
|
|
||||||
- LR Scheduler: `cosine`
|
|
||||||
- Batch Size: `64`
|
|
||||||
- Per-device Batch: `2`
|
|
||||||
- Gradient Accumulation Steps: `32`
|
|
||||||
- Training Steps: `3000`
|
|
||||||
- Warmup Steps: `300`
|
|
||||||
- Random Seed: `3407`
|
|
||||||
- Load Best Checkpoint: Yes
|
|
||||||
|
|
||||||
### Training Time
|
|
||||||
- Total Duration: ~16.44 hours
|
|
||||||
|
|
||||||
## Evaluation Results
|
|
||||||
|
|
||||||
### Comparison with Dataset Actions
|
|
||||||
|
|
||||||
Inference parameters: Temperature=0.1, Top_P=0.1
|
|
||||||
|
|
||||||
**Metrics explanation**:
|
|
||||||
- Score: Max 500 points (1 point per correct sample, 0 for incorrect)
|
|
||||||
- Full-match rate: Samples where all 3 tests matched the dataset
|
|
||||||
- Zero-score rate: Samples where all 3 tests disagreed with the dataset
|
|
||||||
|
|
||||||
#### Tile-Efficiency Test
|
|
||||||
|
|
||||||
| Model | Method | Score | Full-match Rate | Zero-score Rate |
|
|
||||||
|-------|--------|-------|----------------|-----------------|
|
|
||||||
| Qwen3-4B | Prompt Engineering | 50.21 | 6.60% | 86.13% |
|
|
||||||
| Qwen3-4B | Fine-tuned | 229.66 | 45.87% | 53.93% |
|
|
||||||
| DeepSeek-V3.2 | Prompt Engineering | 181.66 | 21.40% | 46.33% |
|
|
||||||
|
|
||||||
#### Defense Test
|
|
||||||
|
|
||||||
| Model | Method | Score | Full-match Rate | Zero-score Rate |
|
|
||||||
|-------|--------|-------|----------------|-----------------|
|
|
||||||
| Qwen3-4B | Prompt Engineering | 53.55 | 6.17% | 84.43% |
|
|
||||||
| Qwen3-4B | Fine-tuned | 239.89 | 47.93% | 52.00% |
|
|
||||||
| DeepSeek-V3.2 | Prompt Engineering | 172.00 | 16.00% | 46.80% |
|
|
||||||
|
|
||||||
#### Comprehensive Test
|
|
||||||
|
|
||||||
| Model | Method | Score | Full-match Rate | Zero-score Rate |
|
|
||||||
|-------|--------|-------|----------------|-----------------|
|
|
||||||
| Qwen3-4B | Prompt Engineering | 53.44 | 0.60% | 84.40% |
|
|
||||||
| Qwen3-4B | Fine-tuned | 233.33 | 46.53% | 53.20% |
|
|
||||||
| DeepSeek-V3.2 | Prompt Engineering | 179.44 | 18.07% | 44.93% |
|
|
||||||
|
|
||||||
### Comparison with Mortal
|
|
||||||
|
|
||||||
Inference parameters: Temperature=0.6, Top_P=0.95
|
|
||||||
|
|
||||||
#### Test 1: All Turn Data
|
|
||||||
|
|
||||||
- Samples: 3000
|
|
||||||
- Top-1 Accuracy: **50.73%**
|
|
||||||
- Top-3 Accuracy: **83.37%**
|
|
||||||
|
|
||||||
#### Test 2: Excluding Early Turns
|
|
||||||
|
|
||||||
- Valid Samples: 3000
|
|
||||||
- Top-1 Accuracy: **48.70%**
|
|
||||||
- Top-3 Accuracy: **79.20%**
|
|
||||||
|
|
||||||
> Note: Mortal is one of the strongest open-source Riichi Mahjong AIs currently available
|
|
||||||
|
|
||||||
## Repository Links
|
|
||||||
|
|
||||||
- GitHub: https://github.com/ttdxq/Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
- Hugging Face: https://huggingface.co/TTDXQ/Qwen3-4B-Instruct-2507-mahjong-alpha
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
This model is licensed under Apache License 2.0.
|
|
||||||
|
|
||||||
The training data comes from `pjura/mahjong_board_states`, which is licensed under `CC BY 4.0`. Please preserve the required attribution and citation when using it.
|
|
||||||
|
|
||||||
## Acknowledgements
|
|
||||||
|
|
||||||
Thanks to the following open-source resources:
|
|
||||||
|
|
||||||
- `unsloth/Qwen3-4B-Instruct-2507`
|
|
||||||
- `pjura/mahjong_board_states`
|
|
||||||
- `Mortal`
|
|
||||||
135
config.pbtxt
Normal file
135
config.pbtxt
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
# Triton Backend for TransformerLLM.
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: 0
|
||||||
|
|
||||||
|
# Triton should expect as input a single string
|
||||||
|
# input of variable length named 'text_input'
|
||||||
|
input [
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "text_input"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max_length"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max_new_tokens"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "do_sample"
|
||||||
|
data_type: TYPE_BOOL
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top_k"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top_p"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temperature"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repetition_penalty"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stream"
|
||||||
|
data_type: TYPE_BOOL
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
optional: true
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Triton should expect to respond with a single string
|
||||||
|
# output of variable length named 'text_output'
|
||||||
|
output [
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "text_output"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ 1 ]
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
parameters: [
|
||||||
|
{
|
||||||
|
key: "base_model_path",
|
||||||
|
value: {string_value: "/cheetah/input/model/groupuser/Qwen3-4B-Instruct-2507-mahjong-alpha"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "gguf_filename",
|
||||||
|
value: {string_value: "Qwen3-4B-Instruct-2507-mahjong-alpha.gguf"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "is_adapter_model",
|
||||||
|
value: {string_value: "false"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "adapter_model_path",
|
||||||
|
value: {string_value: ""}
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
key: "quantization",
|
||||||
|
value: {string_value: "none"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
kind: KIND_AUTO
|
||||||
|
count: 1
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user