This commit is contained in:
2025-09-26 16:55:57 +08:00
parent 57a7e9216e
commit 0597239207
12 changed files with 1412 additions and 4 deletions

View File

@@ -0,0 +1,429 @@
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit
@dataclass
class RiverEHSRecord:
"""river_EHS """
board_id: int
player_id: int
ehs: float
board_cards: List[Card] # 5张公共牌
player_cards: List[Card] # 2张手牌
@dataclass
class TurnHistRecord:
"""turn_hist"""
board_id: int
player_id: int
bins: np.ndarray
board_cards: List[Card] # 4张公共牌
player_cards: List[Card] # 2张手牌
@dataclass
class FlopHistRecord:
"""flop_hist"""
board_id: int
player_id: int
bins: np.ndarray
board_cards: List[Card] # 3张公共牌
player_cards: List[Card] # 2张手牌
class OpenPQLDecoder:
"""open-pql 完整解码"""
def __init__(self):
# Card64
self.OFFSET_SUIT = 16 # 每个suit占16位
self.OFFSET_S = 0 # S: [15:0]
self.OFFSET_H = 16 # H: [31:16]
self.OFFSET_D = 32 # D: [47:32]
self.OFFSET_C = 48 # C: [63:48]
# open-pql 到 短牌型(6-A36张牌)
self.opql_to_poker_rank = {
# 0: Rank.TWO,
# 1: Rank.THREE,
# 2: Rank.FOUR,
# 3: Rank.FIVE,
4: ShortDeckRank.SIX,
5: ShortDeckRank.SEVEN,
6: ShortDeckRank.EIGHT,
7: ShortDeckRank.NINE,
8: ShortDeckRank.TEN,
9: ShortDeckRank.JACK,
10: ShortDeckRank.QUEEN,
11: ShortDeckRank.KING,
12: ShortDeckRank.ACE,
}
self.opql_to_poker_suit = {
0: Suit.SPADES, # S -> SPADES
1: Suit.HEARTS, # H -> HEARTS
2: Suit.DIAMONDS, # D -> DIAMONDS
3: Suit.CLUBS, # C -> CLUBS
}
print("========== OpenPQLDecoder (短牌型36张牌) ===============")
def u64_from_ranksuit(self, rank, suit) -> int:
"""对应 Card64::u64_from_ranksuit_i8"""
return 1 << rank << (suit * self.OFFSET_SUIT)
def decode_card64(self, card64_u64) -> List[Card]:
"""从 Card64 的 u64 表示解码出所有牌"""
cards = []
# 按照 Card64 解析每个suit
# 每个suit16位
for opql_suit in range(4): # 0,1,2,3
suit_offset = opql_suit * self.OFFSET_SUIT
suit_bits = (card64_u64 >> suit_offset) & 0xFFFF
# 检查这个suit的每个rank
for opql_rank in range(13): # 0-12
if suit_bits & (1 << opql_rank):
poker_rank = self.opql_to_poker_rank[opql_rank]
poker_suit = self.opql_to_poker_suit[opql_suit]
cards.append(Card(poker_rank, poker_suit))
cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
return cards
def decode_hand_n_2(self, id_u16) -> List[Card]:
"""解码 Hand<2> 的 u16 ID"""
card0_u8 = id_u16 & 0xFF # 低8位
card1_u8 = (id_u16 >> 8) & 0xFF # 高8位
try:
card0 = self._card_from_u8(card0_u8)
card1 = self._card_from_u8(card1_u8)
cards = [card0, card1]
cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
return cards
except (ValueError, TypeError) as e:
raise ValueError(f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}")
def _card_from_u8(self, v) -> Card:
"""
对应 open-pql Card::from_u8()
单张牌decode
"""
SHIFT_SUIT = 4
opql_rank = v & 0b1111 # 低4位
opql_suit = v >> SHIFT_SUIT # 高4位
if opql_rank not in self.opql_to_poker_rank:
raise ValueError(f"无效的rank: {opql_rank}")
if opql_suit not in self.opql_to_poker_suit:
raise ValueError(f"无效的suit: {opql_suit}")
poker_rank = self.opql_to_poker_rank[opql_rank]
poker_suit = self.opql_to_poker_suit[opql_suit]
return Card(poker_rank, poker_suit)
def decode_board_id(self, board_id, num_cards) -> List[Card]:
"""解码公共牌 ID"""
if num_cards == 2:
# 2张牌使用 Hand<2> 的编码方式
return self.decode_hand_n_2(board_id)
else:
# 3, 4, 5张公共牌使用 Card64 的编码方式
cards = self.decode_card64(board_id)
if len(cards) != num_cards:
raise ValueError(f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}")
return cards
def decode_player_id(self, player_id) -> List[Card]:
"""解码玩家手牌 ID (2张牌)"""
return self.decode_hand_n_2(player_id)
def decode_board_unique_card(self, all_cards) -> List[Card]:
"""验证并返回不重复的牌面"""
unique_cards = []
for card in all_cards:
is_duplicate = False
for existing_card in unique_cards:
if card.rank == existing_card.rank and card.suit == existing_card.suit:
is_duplicate = True
break
if not is_duplicate:
unique_cards.append(card)
return unique_cards
class XTaskDataParser:
def __init__(self, data_path: str = "ehs_data"):
self.data_path = data_path
self.decoder = OpenPQLDecoder()
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
filepath = f"{self.data_path}/{filename}"
try:
raw_data = np.load(filepath)
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
records = []
decode_errors = 0
for i, row in enumerate(raw_data):
try:
board_id = int(row['board'])
player_id = int(row['player'])
ehs = float(row['ehs'])
# 解码公共牌 (5张)
board_cards = self.decoder.decode_board_id(board_id, 5)
# 解码玩家手牌 (2张)
player_cards = self.decoder.decode_player_id(player_id)
# 验证牌面不重复
all_cards = board_cards + player_cards
unique_cards = self.decoder.decode_board_unique_card(all_cards)
if len(unique_cards) != 7:
print(f"记录 {i}: 存在重复牌面")
print(f" Board: {[str(c) for c in board_cards]}")
print(f" Player: {[str(c) for c in player_cards]}")
# 创建完整记录
record = RiverEHSRecord(
board_id=board_id,
player_id=player_id,
ehs=ehs,
board_cards=board_cards,
player_cards=player_cards
)
records.append(record)
except Exception as e:
decode_errors += 1
if decode_errors <= 3:
print(f"记录 {i} 解码失败: {e}")
print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
continue
# 显示进度
if (i + 1) % 10000 == 0:
success_rate = len(records) / (i + 1) * 100
print(f" 已处理 {i+1:,}/{len(raw_data):,} 条记录,成功率 {success_rate:.1f}%")
total_records = len(raw_data)
success_records = len(records)
success_rate = success_records / total_records * 100
print(f"river_EHS解析完成:")
print(f" 总记录数: {total_records:,}")
print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
print(f" 解码失败: {decode_errors:,}")
return records
except FileNotFoundError:
print(f"文件不存在: {filepath}")
raise
except Exception as e:
print(f"解析river_EHS数据失败: {e}")
raise
def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> List[TurnHistRecord]:
filepath = f"{self.data_path}/{filename}"
try:
raw_data = np.load(filepath)
print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
records = []
decode_errors = 0
# 限制处理数量以避免内存问题
# todo:抽样优化
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
for i, row in enumerate(data_to_process):
try:
board_id = int(row['board'])
player_id = int(row['player'])
bins = row['bins'].copy()
# 解码公共牌 (4张)
board_cards = self.decoder.decode_board_id(board_id, 4)
# 解码玩家手牌 (2张)
player_cards = self.decoder.decode_player_id(player_id)
# 验证牌面不重复
all_cards = board_cards + player_cards
unique_cards = self.decoder.decode_board_unique_card(all_cards)
if len(unique_cards) != 6:
decode_errors += 1
print(f"记录 {i}: 存在重复牌面")
record = TurnHistRecord(
board_id=board_id,
player_id=player_id,
bins=bins, # 存储30个bins的numpy数组
board_cards=board_cards,
player_cards=player_cards
)
records.append(record)
except Exception as e:
decode_errors += 1
if decode_errors <= 3:
print(f"记录 {i} 解码失败: {e}")
continue
if (i + 1) % 100 == 0:
print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
print(f"turn_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
return records
except Exception as e:
print(f" 解析turn_hist数据失败: {e}")
raise
def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> List[FlopHistRecord]:
"""解析flop_hist"""
filepath = f"{self.data_path}/{filename}"
try:
raw_data = np.load(filepath)
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
records = []
decode_errors = 0
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
for i, row in enumerate(data_to_process):
try:
board_id = int(row['board'])
player_id = int(row['player'])
bins = row['bins'].copy()
# 解码公共牌 (3张)
board_cards = self.decoder.decode_board_id(board_id, 3)
# 解码玩家手牌 (2张)
player_cards = self.decoder.decode_player_id(player_id)
# 验证牌面不重复
all_cards = board_cards + player_cards
unique_cards = self.decoder.decode_board_unique_card(all_cards)
if len(unique_cards) != 5:
decode_errors += 1
print(f"记录 {i}: 存在重复牌面")
record = FlopHistRecord(
board_id=board_id,
player_id=player_id,
bins=bins, # 存储465个bins的numpy数组
board_cards=board_cards,
player_cards=player_cards
)
records.append(record)
except Exception as e:
decode_errors += 1
if decode_errors <= 3:
print(f"记录 {i} 解码失败: {e}")
continue
if (i + 1) % 50 == 0:
print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
print(f"flop_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
return records
except Exception as e:
print(f"解析flop_hist数据失败: {e}")
raise
def analyze_parsed_data():
print("开始解析 xtask 数据并解码牌面...\n")
try:
parser = XTaskDataParser(data_path="ehs_data")
# 1. 解析river_EHS数据
print("=" * 60)
print(" 解析river_EHS")
print("=" * 60)
river_records = parser.parse_river_ehs_with_cards()
# 显示river_EHS样本
print(f"\nriver_EHS数据样本 (前10条):")
for i, record in enumerate(river_records[:10]):
board_str = " ".join(str(c) for c in record.board_cards)
player_str = " ".join(str(c) for c in record.player_cards)
print(f" {i+1:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{record.ehs:.4f}")
# 2. 解析turn_hist数据
print(f"\n" + "=" * 60)
print(" 解析turn_hist数据")
print("=" * 60)
turn_records = parser.parse_turn_hist_with_cards(max_records=100) # 限制数量
# 显示turn_hist数据样本
print(f"\nturn_hist数据样本 (前5条):")
for i, record in enumerate(turn_records[:5]):
board_str = " ".join(str(c) for c in record.board_cards)
player_str = " ".join(str(c) for c in record.player_cards)
bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
print(f" {i+1}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")
# 3. 解析flop_hist数据
print(f"\n" + "=" * 60)
print(" 解析flop_hist数据")
print("=" * 60)
flop_records = parser.parse_flop_hist_with_cards(max_records=50) # 限制数量
# 显示flop_数据样本
print(f"\n flop_hist数据样本 (前5条):")
for i, record in enumerate(flop_records[:5]):
board_str = " ".join(str(c) for c in record.board_cards)
player_str = " ".join(str(c) for c in record.player_cards)
bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
print(f" {i+1}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")
# 4. 统计摘要
print(f"\n" + "=" * 60)
print("解析统计摘要")
print("=" * 60)
print(f"river_EHS记录: {len(river_records):,}")
print(f"turn_hist记录: {len(turn_records):,}")
print(f"flop_hist记录: {len(flop_records):,}")
# EHS 统计
if river_records:
ehs_values = [r.ehs for r in river_records]
print(f"river_EHS统计:")
print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]")
print(f" 均值: {np.mean(ehs_values):.4f}")
print(f" 标准差: {np.std(ehs_values):.4f}")
print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")
return {
'river_records': river_records,
'turn_records': turn_records,
'flop_records': flop_records
}
except Exception as e:
print(f" 数据解析失败: {e}")
return None