Files
poker_task1/cross_validation/parse_data.py
2025-09-28 18:09:17 +08:00

430 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit
@dataclass
class RiverEHSRecord:
    """One decoded row of the river_EHS table.

    Holds the raw ids exactly as read from the table next to the cards they
    decode to, plus the EHS value stored for that board/hand pair.
    """

    board_id: int
    player_id: int
    ehs: float
    board_cards: List[Card]  # 5 community cards
    player_cards: List[Card]  # 2 hole cards
@dataclass
class TurnHistRecord:
    """One decoded row of the turn_hist table.

    Holds the raw ids exactly as read from the table, the ``bins`` array
    stored for that row, and the cards the ids decode to.
    """

    board_id: int
    player_id: int
    bins: np.ndarray
    board_cards: List[Card]  # 4 community cards
    player_cards: List[Card]  # 2 hole cards
@dataclass
class FlopHistRecord:
    """One decoded row of the flop_hist table.

    Holds the raw ids exactly as read from the table, the ``bins`` array
    stored for that row, and the cards the ids decode to.
    """

    board_id: int
    player_id: int
    bins: np.ndarray
    board_cards: List[Card]  # 3 community cards
    player_cards: List[Card]  # 2 hole cards
class OpenPQLDecoder:
    """Decode open-pql card encodings into short-deck ``Card`` objects.

    Two encodings are handled:

    * Card64 -- an integer bit mask with one 16-bit lane per suit; bit
      ``rank`` inside a lane marks that rank as present.
    * Hand<2> -- a 16-bit value made of two 8-bit cards; each byte keeps the
      rank in its low nibble and the suit in its high nibble.

    Only open-pql ranks 4..12 (six through ace) are mapped. Any other rank is
    reported as ``ValueError`` by every decoding entry point.
    """

    def __init__(self):
        # Card64 layout: every suit occupies its own 16-bit lane.
        self.OFFSET_SUIT = 16
        self.OFFSET_S = 0  # S: bits [15:0]
        self.OFFSET_H = 16  # H: bits [31:16]
        self.OFFSET_D = 32  # D: bits [47:32]
        self.OFFSET_C = 48  # C: bits [63:48]
        # open-pql rank index -> short-deck rank (6 through A, 36 cards).
        # Indices 0..3 (two through five) are deliberately left unmapped.
        self.opql_to_poker_rank = {
            4: ShortDeckRank.SIX,
            5: ShortDeckRank.SEVEN,
            6: ShortDeckRank.EIGHT,
            7: ShortDeckRank.NINE,
            8: ShortDeckRank.TEN,
            9: ShortDeckRank.JACK,
            10: ShortDeckRank.QUEEN,
            11: ShortDeckRank.KING,
            12: ShortDeckRank.ACE,
        }
        # open-pql suit index -> project suit.
        self.opql_to_poker_suit = {
            0: Suit.SPADES,  # S -> SPADES
            1: Suit.HEARTS,  # H -> HEARTS
            2: Suit.DIAMONDS,  # D -> DIAMONDS
            3: Suit.CLUBS,  # C -> CLUBS
        }
        print("========== OpenPQLDecoder (短牌型36张牌) ===============")

    def u64_from_ranksuit(self, rank: int, suit: int) -> int:
        """Return the single-bit Card64 mask for ``rank`` within ``suit``.

        Mirrors ``Card64::u64_from_ranksuit_i8``: the rank selects the bit
        inside a lane and the suit selects the lane.
        """
        return 1 << rank << (suit * self.OFFSET_SUIT)

    def decode_card64(self, card64_u64: int) -> List[Card]:
        """Decode every card present in a Card64 bit mask.

        Returns:
            The decoded cards sorted by ``(rank.numeric_value, suit.value)``.

        Raises:
            ValueError: a bit is set for a rank index that has no short-deck
                mapping (open-pql ranks 0..3).
        """
        cards = []
        for opql_suit in range(4):
            # Isolate the 16-bit lane that belongs to this suit.
            suit_bits = (card64_u64 >> (opql_suit * self.OFFSET_SUIT)) & 0xFFFF
            # Only bits 0..12 are inspected; bits 13..15 of a lane are ignored.
            for opql_rank in range(13):
                if not suit_bits & (1 << opql_rank):
                    continue
                # Fail with the same ValueError as _card_from_u8 instead of
                # leaking a bare KeyError out of the lookup table.
                if opql_rank not in self.opql_to_poker_rank:
                    raise ValueError(f"无效的rank: {opql_rank}")
                cards.append(
                    Card(
                        self.opql_to_poker_rank[opql_rank],
                        self.opql_to_poker_suit[opql_suit],
                    )
                )
        cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
        return cards

    def decode_hand_n_2(self, id_u16: int) -> List[Card]:
        """Decode a Hand<2> 16-bit id into two cards.

        The low byte is the first card and the high byte the second.

        Returns:
            Both cards sorted by ``(rank.numeric_value, suit.value)``.

        Raises:
            ValueError: either byte does not decode to a valid card; the
                underlying error is attached as ``__cause__``.
        """
        card0_u8 = id_u16 & 0xFF  # low 8 bits
        card1_u8 = (id_u16 >> 8) & 0xFF  # high 8 bits
        try:
            cards = [self._card_from_u8(card0_u8), self._card_from_u8(card1_u8)]
            cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}"
            ) from e
        return cards

    def _card_from_u8(self, v: int) -> Card:
        """Decode one 8-bit card (counterpart of open-pql ``Card::from_u8``).

        Raises:
            ValueError: the rank or suit nibble has no mapping.
        """
        SHIFT_SUIT = 4
        opql_rank = v & 0b1111  # low nibble
        opql_suit = v >> SHIFT_SUIT  # high nibble
        if opql_rank not in self.opql_to_poker_rank:
            raise ValueError(f"无效的rank: {opql_rank}")
        if opql_suit not in self.opql_to_poker_suit:
            raise ValueError(f"无效的suit: {opql_suit}")
        return Card(
            self.opql_to_poker_rank[opql_rank],
            self.opql_to_poker_suit[opql_suit],
        )

    def decode_board_id(self, board_id: int, num_cards: int) -> List[Card]:
        """Decode a board id that is expected to hold ``num_cards`` cards.

        A two-card board uses the Hand<2> encoding; every other size is
        treated as a Card64 mask.

        Raises:
            ValueError: the id is invalid, or a Card64 mask decodes to a
                different number of cards than ``num_cards``.
        """
        if num_cards == 2:
            return self.decode_hand_n_2(board_id)
        cards = self.decode_card64(board_id)
        if len(cards) != num_cards:
            raise ValueError(f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}")
        return cards

    def decode_player_id(self, player_id: int) -> List[Card]:
        """Decode a player's two hole cards from a Hand<2> id."""
        return self.decode_hand_n_2(player_id)

    def decode_board_unique_card(self, all_cards) -> List[Card]:
        """Return ``all_cards`` with repeated (rank, suit) pairs removed.

        The first occurrence of each card is kept and input order is
        preserved, so callers can compare the result length with the input
        length to detect duplicates.
        """
        unique_cards = []
        for card in all_cards:
            # Compare by equality only, so no hashability of rank/suit is assumed.
            already_seen = any(
                card.rank == kept.rank and card.suit == kept.suit
                for kept in unique_cards
            )
            if not already_seen:
                unique_cards.append(card)
        return unique_cards
class XTaskDataParser:
    """Load xtask ``.npy`` tables and decode their board/player ids into cards.

    Each ``parse_*`` method loads one array from ``data_path``, draws a random
    sample of at most ``max_records`` rows, decodes every sampled row and
    returns the rows that decoded cleanly. Rows that fail to decode, or whose
    board and hand share a card, are counted as errors and left out.
    """

    def __init__(self, data_path: str = "ehs_data"):
        self.data_path = data_path
        self.decoder = OpenPQLDecoder()

    @staticmethod
    def _sample_rows(raw_data, max_records: int) -> list:
        """Draw up to ``max_records`` rows from ``raw_data`` without replacement."""
        total = len(raw_data)
        # Sample indices rather than list(raw_data): the same rows are chosen
        # for a given random state, without copying the whole array first.
        chosen = random.sample(range(total), min(max_records, total))
        return [raw_data[idx] for idx in chosen]

    def _decode_cards(self, board_id: int, player_id: int, num_board_cards: int) -> Tuple[List[Card], List[Card], bool]:
        """Decode both ids; the flag is True when board and hand share a card."""
        board_cards = self.decoder.decode_board_id(board_id, num_board_cards)
        player_cards = self.decoder.decode_player_id(player_id)
        unique_cards = self.decoder.decode_board_unique_card(board_cards + player_cards)
        return board_cards, player_cards, len(unique_cards) != num_board_cards + 2

    def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
        """Sample and decode rows of the river_EHS table.

        Args:
            filename: file name inside ``data_path``.
            max_records: upper bound on the number of rows sampled.

        Returns:
            One ``RiverEHSRecord`` per sampled row that decoded cleanly.

        Raises:
            FileNotFoundError: the file does not exist (re-raised after logging).
            Exception: any other load/parse failure (re-raised after logging).
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
            data_to_process = self._sample_rows(raw_data, max_records)
            records = []
            decode_errors = 0
            for i, row in enumerate(data_to_process):
                try:
                    board_id = int(row['board'])
                    player_id = int(row['player'])
                    ehs = float(row['ehs'])
                    # 5 board cards + 2 hole cards
                    board_cards, player_cards, has_duplicates = self._decode_cards(board_id, player_id, 5)
                    if has_duplicates:
                        # An impossible deal: report it and keep it out of the result.
                        decode_errors += 1
                        print(f"记录 {i}: 存在重复牌面")
                        print(f" Board: {[str(c) for c in board_cards]}")
                        print(f" Player: {[str(c) for c in player_cards]}")
                        continue
                    records.append(
                        RiverEHSRecord(
                            board_id=board_id,
                            player_id=player_id,
                            ehs=ehs,
                            board_cards=board_cards,
                            player_cards=player_cards,
                        )
                    )
                except Exception as e:
                    decode_errors += 1
                    # Only the first few failures are printed to keep output readable.
                    if decode_errors <= 3:
                        print(f"记录 {i} 解码失败: {e}")
                        print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
                    continue
                if (i + 1) % 10000 == 0:
                    success_rate = len(records) / (i + 1) * 100
                    print(f" 已处理 {i+1:,}/{len(data_to_process):,} 条记录,成功率 {success_rate:.1f}%")
            # Rates are relative to the rows actually processed (the sample),
            # not to the size of the file; an empty sample reports 0%.
            total_records = len(data_to_process)
            success_records = len(records)
            success_rate = success_records / total_records * 100 if total_records else 0.0
            print("river_EHS解析完成:")
            print(f" 总记录数: {total_records:,}")
            print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
            print(f" 解码失败: {decode_errors:,}")
            return records
        except FileNotFoundError:
            print(f"文件不存在: {filepath}")
            raise
        except Exception as e:
            print(f"解析river_EHS数据失败: {e}")
            raise

    def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> List[TurnHistRecord]:
        """Sample and decode rows of the turn_hist table.

        Args:
            filename: file name inside ``data_path``.
            max_records: upper bound on the number of rows sampled.

        Returns:
            One ``TurnHistRecord`` per sampled row that decoded cleanly; each
            record owns a copy of the row's ``bins`` array.

        Raises:
            Exception: any load/parse failure (re-raised after logging).
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
            records = []
            decode_errors = 0
            data_to_process = self._sample_rows(raw_data, max_records)
            for i, row in enumerate(data_to_process):
                try:
                    board_id = int(row['board'])
                    player_id = int(row['player'])
                    bins = row['bins'].copy()
                    # 4 board cards + 2 hole cards
                    board_cards, player_cards, has_duplicates = self._decode_cards(board_id, player_id, 4)
                    if has_duplicates:
                        # Counted as an error, so it must not be returned as a success.
                        decode_errors += 1
                        print(f"记录 {i}: 存在重复牌面")
                        continue
                    records.append(
                        TurnHistRecord(
                            board_id=board_id,
                            player_id=player_id,
                            bins=bins,
                            board_cards=board_cards,
                            player_cards=player_cards,
                        )
                    )
                except Exception as e:
                    decode_errors += 1
                    if decode_errors <= 3:
                        print(f"记录 {i} 解码失败: {e}")
                    continue
                if (i + 1) % 100 == 0:
                    print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
            print(f"turn_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
            return records
        except Exception as e:
            print(f" 解析turn_hist数据失败: {e}")
            raise

    def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> List[FlopHistRecord]:
        """Sample and decode rows of the flop_hist table.

        Args:
            filename: file name inside ``data_path``.
            max_records: upper bound on the number of rows sampled.

        Returns:
            One ``FlopHistRecord`` per sampled row that decoded cleanly; each
            record owns a copy of the row's ``bins`` array.

        Raises:
            Exception: any load/parse failure (re-raised after logging).
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
            records = []
            decode_errors = 0
            data_to_process = self._sample_rows(raw_data, max_records)
            for i, row in enumerate(data_to_process):
                try:
                    board_id = int(row['board'])
                    player_id = int(row['player'])
                    bins = row['bins'].copy()
                    # 3 board cards + 2 hole cards
                    board_cards, player_cards, has_duplicates = self._decode_cards(board_id, player_id, 3)
                    if has_duplicates:
                        # Counted as an error, so it must not be returned as a success.
                        decode_errors += 1
                        print(f"记录 {i}: 存在重复牌面")
                        continue
                    records.append(
                        FlopHistRecord(
                            board_id=board_id,
                            player_id=player_id,
                            bins=bins,
                            board_cards=board_cards,
                            player_cards=player_cards,
                        )
                    )
                except Exception as e:
                    decode_errors += 1
                    if decode_errors <= 3:
                        print(f"记录 {i} 解码失败: {e}")
                    continue
                if (i + 1) % 50 == 0:
                    print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
            print(f"flop_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
            return records
        except Exception as e:
            print(f"解析flop_hist数据失败: {e}")
            raise
def analyze_parsed_data():
    """Parse the three xtask tables, print samples and a summary.

    Builds an ``XTaskDataParser`` on ``ehs_data``, parses river_EHS (default
    sample size), turn_hist (100 rows) and flop_hist (50 rows), prints the
    first few decoded records of each plus EHS statistics.

    Returns:
        A dict with keys ``river_records``, ``turn_records`` and
        ``flop_records``, or ``None`` when any step raises.
    """

    def cards_text(cards):
        # Space-separated string form of a list of cards.
        return " ".join(str(card) for card in cards)

    def spread_text(values):
        # Mean / standard deviation summary of a bins array.
        return f"mean={values.mean():.3f}, std={values.std():.3f}"

    rule = "=" * 60
    print("开始解析 xtask 数据并解码牌面...\n")
    try:
        parser = XTaskDataParser(data_path="ehs_data")

        # 1. river_EHS
        print(rule)
        print(" 解析river_EHS")
        print(rule)
        river_records = parser.parse_river_ehs_with_cards()
        print("\nriver_EHS数据样本 (前10条):")
        for position, rec in enumerate(river_records[:10], start=1):
            board_str = cards_text(rec.board_cards)
            player_str = cards_text(rec.player_cards)
            print(f" {position:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{rec.ehs:.4f}")

        # 2. turn_hist (sample size capped)
        print("\n" + rule)
        print(" 解析turn_hist数据")
        print(rule)
        turn_records = parser.parse_turn_hist_with_cards(max_records=100)
        print("\nturn_hist数据样本 (前5条):")
        for position, rec in enumerate(turn_records[:5], start=1):
            board_str = cards_text(rec.board_cards)
            player_str = cards_text(rec.player_cards)
            bins_stats = spread_text(rec.bins)
            print(f" {position}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 3. flop_hist (sample size capped)
        print("\n" + rule)
        print(" 解析flop_hist数据")
        print(rule)
        flop_records = parser.parse_flop_hist_with_cards(max_records=50)
        print("\n flop_hist数据样本 (前5条):")
        for position, rec in enumerate(flop_records[:5], start=1):
            board_str = cards_text(rec.board_cards)
            player_str = cards_text(rec.player_cards)
            bins_stats = spread_text(rec.bins)
            print(f" {position}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 4. summary
        print("\n" + rule)
        print("解析统计摘要")
        print(rule)
        print(f"river_EHS记录: {len(river_records):,}")
        print(f"turn_hist记录: {len(turn_records):,}")
        print(f"flop_hist记录: {len(flop_records):,}")

        # EHS statistics are only meaningful when at least one record exists.
        if river_records:
            ehs_values = [rec.ehs for rec in river_records]
            lowest = min(ehs_values)
            highest = max(ehs_values)
            print("river_EHS统计:")
            print(f" 范围: [{lowest:.4f}, {highest:.4f}]")
            print(f" 均值: {np.mean(ehs_values):.4f}")
            print(f" 标准差: {np.std(ehs_values):.4f}")

        print("\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")
        return {
            'river_records': river_records,
            'turn_records': turn_records,
            'flop_records': flop_records,
        }
    except Exception as e:
        # Top-level boundary: report the failure and signal it with None.
        print(f" 数据解析失败: {e}")
        return None