import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit


@dataclass
class RiverEHSRecord:
    """One decoded ``river_EHS`` row: raw IDs, the EHS value and the cards."""

    board_id: int
    player_id: int
    ehs: float
    board_cards: List[Card]  # 5 community cards
    player_cards: List[Card]  # 2 hole cards


@dataclass
class TurnHistRecord:
    """One decoded ``turn_hist`` row: raw IDs, histogram bins and the cards."""

    board_id: int
    player_id: int
    bins: np.ndarray
    board_cards: List[Card]  # 4 community cards
    player_cards: List[Card]  # 2 hole cards


@dataclass
class FlopHistRecord:
    """One decoded ``flop_hist`` row: raw IDs, histogram bins and the cards."""

    board_id: int
    player_id: int
    bins: np.ndarray
    board_cards: List[Card]  # 3 community cards
    player_cards: List[Card]  # 2 hole cards


class OpenPQLDecoder:
    """Decode open-pql card encodings into short-deck ``Card`` objects.

    Two encodings are handled:

    * a 64-bit bitmask ("Card64") holding one 16-bit field per suit, where
      bit ``r`` of a field marks the presence of open-pql rank ``r``;
    * a 16-bit two-card ID ("Hand<2>") holding one byte per card, where the
      low nibble of a byte is the rank and the high nibble is the suit.
    """

    def __init__(self):
        # Card64 layout: every suit occupies its own 16-bit field.
        self.OFFSET_SUIT = 16
        self.OFFSET_S = 0  # S: bits [15:0]
        self.OFFSET_H = 16  # H: bits [31:16]
        self.OFFSET_D = 32  # D: bits [47:32]
        self.OFFSET_C = 48  # C: bits [63:48]

        # open-pql rank index -> short-deck rank (6..A, 36 cards).
        # Indices 0-3 (the ranks below six) are intentionally absent, so any
        # card using them is rejected by the decoders.
        self.opql_to_poker_rank = {
            4: ShortDeckRank.SIX,
            5: ShortDeckRank.SEVEN,
            6: ShortDeckRank.EIGHT,
            7: ShortDeckRank.NINE,
            8: ShortDeckRank.TEN,
            9: ShortDeckRank.JACK,
            10: ShortDeckRank.QUEEN,
            11: ShortDeckRank.KING,
            12: ShortDeckRank.ACE,
        }

        # open-pql suit index -> Suit.
        self.opql_to_poker_suit = {
            0: Suit.SPADES,  # S
            1: Suit.HEARTS,  # H
            2: Suit.DIAMONDS,  # D
            3: Suit.CLUBS,  # C
        }

        print("========== OpenPQLDecoder (短牌型36张牌) ===============")

    def u64_from_ranksuit(self, rank: int, suit: int) -> int:
        """Return the Card64 bitmask with only the (rank, suit) bit set.

        Counterpart of open-pql's ``Card64::u64_from_ranksuit_i8``.
        """
        return 1 << rank << (suit * self.OFFSET_SUIT)

    def decode_card64(self, card64_u64: int) -> List[Card]:
        """Decode every card present in a Card64 bitmask.

        Args:
            card64_u64: Integer bitmask, 16 bits per suit.

        Returns:
            The decoded cards sorted by ``(rank.numeric_value, suit.value)``.

        Raises:
            ValueError: If a bit is set for a rank that is not part of the
                short deck (open-pql rank index 0-3).
        """
        cards = []
        for opql_suit in range(4):
            suit_bits = (card64_u64 >> (opql_suit * self.OFFSET_SUIT)) & 0xFFFF
            # Only rank bits 0-12 are inspected; bits 13-15 of each suit
            # field are ignored (presumably unused padding - TODO confirm).
            for opql_rank in range(13):
                if not suit_bits & (1 << opql_rank):
                    continue
                poker_rank = self.opql_to_poker_rank.get(opql_rank)
                if poker_rank is None:
                    # A bare KeyError here would only say e.g. "0"; raise the
                    # same descriptive error _card_from_u8 uses instead.
                    raise ValueError(f"无效的rank: {opql_rank}")
                cards.append(Card(poker_rank, self.opql_to_poker_suit[opql_suit]))
        cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
        return cards

    def decode_hand_n_2(self, id_u16: int) -> List[Card]:
        """Decode a ``Hand<2>`` 16-bit ID into two sorted cards.

        The low byte encodes the first card and the high byte the second.

        Raises:
            ValueError: If either byte does not describe a valid card.
        """
        card0_u8 = id_u16 & 0xFF  # low 8 bits
        card1_u8 = (id_u16 >> 8) & 0xFF  # high 8 bits
        try:
            cards = [self._card_from_u8(card0_u8), self._card_from_u8(card1_u8)]
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}"
            ) from e
        cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
        return cards

    def _card_from_u8(self, v: int) -> Card:
        """Decode a single card byte (counterpart of open-pql ``Card::from_u8``).

        Raises:
            ValueError: If the rank or suit nibble is not in the lookup tables.
        """
        SHIFT_SUIT = 4
        opql_rank = v & 0b1111  # low 4 bits
        opql_suit = v >> SHIFT_SUIT  # high 4 bits

        if opql_rank not in self.opql_to_poker_rank:
            raise ValueError(f"无效的rank: {opql_rank}")
        if opql_suit not in self.opql_to_poker_suit:
            raise ValueError(f"无效的suit: {opql_suit}")

        return Card(
            self.opql_to_poker_rank[opql_rank],
            self.opql_to_poker_suit[opql_suit],
        )

    def decode_board_id(self, board_id: int, num_cards: int) -> List[Card]:
        """Decode a community-card ID that should contain ``num_cards`` cards.

        A two-card board uses the ``Hand<2>`` encoding; any other size is
        treated as a Card64 bitmask.

        Raises:
            ValueError: If the number of decoded cards differs from
                ``num_cards`` or an encoded card is invalid.
        """
        if num_cards == 2:
            return self.decode_hand_n_2(board_id)

        cards = self.decode_card64(board_id)
        if len(cards) != num_cards:
            raise ValueError(
                f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}"
            )
        return cards

    def decode_player_id(self, player_id: int) -> List[Card]:
        """Decode a player's two hole cards from a ``Hand<2>`` ID."""
        return self.decode_hand_n_2(player_id)

    def decode_board_unique_card(self, all_cards: List[Card]) -> List[Card]:
        """Return ``all_cards`` with duplicates removed, keeping first-seen order.

        Two cards are duplicates when both rank and suit compare equal.
        Attributes are compared pairwise rather than using a set, so ``Card``
        does not need to be hashable.
        """
        unique_cards: List[Card] = []
        for card in all_cards:
            already_seen = any(
                card.rank == seen.rank and card.suit == seen.suit
                for seen in unique_cards
            )
            if not already_seen:
                unique_cards.append(card)
        return unique_cards


class XTaskDataParser:
    """Load xtask ``.npy`` tables and decode their board/player IDs to cards."""

    def __init__(self, data_path: str = "ehs_data"):
        self.data_path = data_path
        self.decoder = OpenPQLDecoder()

    def _decode_record_cards(
        self, board_id: int, player_id: int, num_board_cards: int
    ) -> Tuple[List[Card], List[Card], bool]:
        """Decode one row's cards and report whether any card repeats.

        Returns:
            ``(board_cards, player_cards, has_duplicates)``.

        Raises:
            ValueError: Propagated from the decoder on an invalid ID.
        """
        board_cards = self.decoder.decode_board_id(board_id, num_board_cards)
        player_cards = self.decoder.decode_player_id(player_id)
        unique_cards = self.decoder.decode_board_unique_card(board_cards + player_cards)
        has_duplicates = len(unique_cards) != num_board_cards + 2
        return board_cards, player_cards, has_duplicates

    def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
        """Parse the river EHS table and decode every row's cards.

        Rows are read through the ``'board'``, ``'player'`` and ``'ehs'``
        fields. Rows that fail to decode are skipped and counted; rows whose
        board and hand share a card are reported but still returned.

        Raises:
            FileNotFoundError: If the file does not exist.
            Exception: Any other load/parse error is printed and re-raised.
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")

            records = []
            decode_errors = 0
            for i, row in enumerate(raw_data):
                try:
                    board_id = int(row['board'])
                    player_id = int(row['player'])
                    ehs = float(row['ehs'])

                    board_cards, player_cards, has_duplicates = self._decode_record_cards(
                        board_id, player_id, 5
                    )
                    if has_duplicates:
                        print(f"记录 {i}: 存在重复牌面")
                        print(f" Board: {[str(c) for c in board_cards]}")
                        print(f" Player: {[str(c) for c in player_cards]}")

                    records.append(
                        RiverEHSRecord(
                            board_id=board_id,
                            player_id=player_id,
                            ehs=ehs,
                            board_cards=board_cards,
                            player_cards=player_cards,
                        )
                    )
                except Exception as e:
                    decode_errors += 1
                    # Only the first few failures are printed to keep logs short.
                    if decode_errors <= 3:
                        print(f"记录 {i} 解码失败: {e}")
                        print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
                    continue

                # Progress report.
                if (i + 1) % 10000 == 0:
                    success_rate = len(records) / (i + 1) * 100
                    print(f" 已处理 {i+1:,}/{len(raw_data):,} 条记录,成功率 {success_rate:.1f}%")

            total_records = len(raw_data)
            success_records = len(records)
            # An empty table must not crash the summary with a division by zero.
            success_rate = success_records / total_records * 100 if total_records else 0.0
            print(f"river_EHS解析完成:")
            print(f" 总记录数: {total_records:,}")
            print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
            print(f" 解码失败: {decode_errors:,}")
            return records
        except FileNotFoundError:
            print(f"文件不存在: {filepath}")
            raise
        except Exception as e:
            print(f"解析river_EHS数据失败: {e}")
            raise

    def _parse_hist(
        self,
        filename: str,
        max_records: int,
        *,
        label: str,
        num_board_cards: int,
        record_cls: type,
        progress_every: int,
    ) -> list:
        """Shared loader for the histogram tables (turn and flop).

        Reads at most ``max_records`` leading rows, decodes the ``'board'``
        and ``'player'`` IDs, copies ``'bins'`` and builds ``record_cls``
        instances. Rows that fail to decode are skipped.
        """
        filepath = f"{self.data_path}/{filename}"
        raw_data = np.load(filepath)
        print(f"加载{label}: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")

        records = []
        decode_errors = 0
        # Cap the number of rows to bound memory use. TODO: sample instead of
        # taking the leading rows.
        data_to_process = raw_data[:max_records]

        for i, row in enumerate(data_to_process):
            try:
                board_id = int(row['board'])
                player_id = int(row['player'])
                # Copy so the record does not keep a view into the loaded array.
                bins = row['bins'].copy()

                board_cards, player_cards, has_duplicates = self._decode_record_cards(
                    board_id, player_id, num_board_cards
                )
                if has_duplicates:
                    decode_errors += 1
                    print(f"记录 {i}: 存在重复牌面")

                records.append(
                    record_cls(
                        board_id=board_id,
                        player_id=player_id,
                        bins=bins,
                        board_cards=board_cards,
                        player_cards=player_cards,
                    )
                )
            except Exception as e:
                decode_errors += 1
                # Only the first few failures are printed to keep logs short.
                if decode_errors <= 3:
                    print(f"记录 {i} 解码失败: {e}")
                continue

            if (i + 1) % progress_every == 0:
                print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")

        print(f"{label}解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
        return records

    def parse_turn_hist_with_cards(
        self, filename: str = "turn_hist.npy", max_records: int = 1000
    ) -> List[TurnHistRecord]:
        """Parse up to ``max_records`` rows of the turn histogram table.

        Each row yields a 4-card board, a 2-card hand and a copy of its
        ``bins`` array (presumably 30 bins - TODO confirm against the data).

        Raises:
            Exception: Any load/parse error is printed and re-raised.
        """
        try:
            return self._parse_hist(
                filename,
                max_records,
                label="turn_hist",
                num_board_cards=4,
                record_cls=TurnHistRecord,
                progress_every=100,
            )
        except Exception as e:
            print(f" 解析turn_hist数据失败: {e}")
            raise

    def parse_flop_hist_with_cards(
        self, filename: str = "flop_hist.npy", max_records: int = 500
    ) -> List[FlopHistRecord]:
        """Parse up to ``max_records`` rows of the flop histogram table.

        Each row yields a 3-card board, a 2-card hand and a copy of its
        ``bins`` array (presumably 465 bins - TODO confirm against the data).

        Raises:
            Exception: Any load/parse error is printed and re-raised.
        """
        try:
            return self._parse_hist(
                filename,
                max_records,
                label="flop_hist",
                num_board_cards=3,
                record_cls=FlopHistRecord,
                progress_every=50,
            )
        except Exception as e:
            print(f"解析flop_hist数据失败: {e}")
            raise


def analyze_parsed_data() -> Optional[Dict[str, list]]:
    """Parse all three tables from ``ehs_data`` and print samples and stats.

    Returns:
        A dict with keys ``'river_records'``, ``'turn_records'`` and
        ``'flop_records'``, or ``None`` if anything failed (the error is
        printed rather than propagated).
    """
    print("开始解析 xtask 数据并解码牌面...\n")
    try:
        parser = XTaskDataParser(data_path="ehs_data")

        # 1. River EHS table.
        print("=" * 60)
        print(" 解析river_EHS")
        print("=" * 60)
        river_records = parser.parse_river_ehs_with_cards()

        print(f"\nriver_EHS数据样本 (前10条):")
        for i, record in enumerate(river_records[:10]):
            board_str = " ".join(str(c) for c in record.board_cards)
            player_str = " ".join(str(c) for c in record.player_cards)
            print(f" {i+1:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{record.ehs:.4f}")

        # 2. Turn histogram table (row count capped).
        print(f"\n" + "=" * 60)
        print(" 解析turn_hist数据")
        print("=" * 60)
        turn_records = parser.parse_turn_hist_with_cards(max_records=100)

        print(f"\nturn_hist数据样本 (前5条):")
        for i, record in enumerate(turn_records[:5]):
            board_str = " ".join(str(c) for c in record.board_cards)
            player_str = " ".join(str(c) for c in record.player_cards)
            bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
            print(f" {i+1}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 3. Flop histogram table (row count capped).
        print(f"\n" + "=" * 60)
        print(" 解析flop_hist数据")
        print("=" * 60)
        flop_records = parser.parse_flop_hist_with_cards(max_records=50)

        print(f"\n flop_hist数据样本 (前5条):")
        for i, record in enumerate(flop_records[:5]):
            board_str = " ".join(str(c) for c in record.board_cards)
            player_str = " ".join(str(c) for c in record.player_cards)
            bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
            print(f" {i+1}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 4. Summary statistics.
        print(f"\n" + "=" * 60)
        print("解析统计摘要")
        print("=" * 60)
        print(f"river_EHS记录: {len(river_records):,} 条")
        print(f"turn_hist记录: {len(turn_records):,} 条")
        print(f"flop_hist记录: {len(flop_records):,} 条")

        # EHS statistics are only meaningful when at least one row decoded.
        if river_records:
            ehs_values = [r.ehs for r in river_records]
            print(f"river_EHS统计:")
            print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]")
            print(f" 均值: {np.mean(ehs_values):.4f}")
            print(f" 标准差: {np.std(ehs_values):.4f}")

        print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")

        return {
            'river_records': river_records,
            'turn_records': turn_records,
            'flop_records': flop_records,
        }
    except Exception as e:
        # Top-level boundary: report the failure and signal it with None.
        print(f" 数据解析失败: {e}")
        return None