From ec50d2897fb32f7986ba5c002d34f5250832ff06 Mon Sep 17 00:00:00 2001 From: jianghaiying Date: Tue, 28 Oct 2025 15:38:25 +0800 Subject: [PATCH] cross-validation --- .gitignore | 2 +- cross_validation/__init__.py | 11 +- cross_validation/cross_validation.py | 391 ------------------------ cross_validation/parse_data.py | 429 --------------------------- cross_validation/validator.py | 216 ++++++++++++++ task5_main.py | 28 +- 6 files changed, 223 insertions(+), 854 deletions(-) delete mode 100644 cross_validation/cross_validation.py delete mode 100644 cross_validation/parse_data.py create mode 100644 cross_validation/validator.py diff --git a/.gitignore b/.gitignore index 505a3b1..27661eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -# Python-generated files +ehs-data __pycache__/ *.py[oc] build/ diff --git a/cross_validation/__init__.py b/cross_validation/__init__.py index 69df6f0..9f1f007 100644 --- a/cross_validation/__init__.py +++ b/cross_validation/__init__.py @@ -1,8 +1,5 @@ -""" -Cross Validation module for EHS winrate data validation -""" +from .validator import cross_validate_main -from .cross_validation import DataValidator -from .parse_data import XTaskDataParser - -__all__ = ['DataValidator', 'XTaskDataParser'] \ No newline at end of file +__all__ = [ + 'cross_validate_main' +] \ No newline at end of file diff --git a/cross_validation/cross_validation.py b/cross_validation/cross_validation.py deleted file mode 100644 index eb9e6d9..0000000 --- a/cross_validation/cross_validation.py +++ /dev/null @@ -1,391 +0,0 @@ -#!/usr/bin/env python3 -import numpy as np -from typing import List, Dict, Tuple -from scipy.stats import wasserstein_distance -import sys - -from .parse_data import XTaskDataParser -from shortdeck.gen_hist import ShortDeckHistGenerator -import matplotlib.pyplot as plt - - -class DataValidator: - - - def __init__(self, data_path: str = "ehs_data"): - self.parser = XTaskDataParser(data_path) - self.generator = ShortDeckHistGenerator() - print(" DataValidator初始化完成") - print(f" 生成器短牌型大小: {len(self.generator.full_deck)}") - - def validate_river_samples(self, max_samples: int = 20) : - # print(f"\n 验证river_EHS样本 (最大样本数: {max_samples})") - - try: - print(" 解析导出的river数据...") - print('='*60) - river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples) - - if not river_records: - return {'error': '没有解析到river记录', 'success': False} - - sample_records = river_records - print(f" 选择 {len(sample_records)} 个样本进行验证") - - matches = 0 - errors = 0 - differences = [] - - for i, record in enumerate(sample_records): - try: - player_cards = record.player_cards - board_cards = record.board_cards - src_river_ehs = record.ehs - - - cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards) - - - ehs_difference = abs(src_river_ehs - cur_river_ehs) - - - player_str = " ".join(str(c) for c in player_cards) - board_str = " ".join(str(c) for c in board_cards) - print(f" 样本 {i+1}: [{player_str}] + [{board_str}]") - print(f" 原始EHS: {src_river_ehs:.6f}") - print(f" 重算EHS: {cur_river_ehs:.6f}") - print(f" 差异: {ehs_difference:.6f}") - - # 判断匹配 (允许小的数值差异) - tolerance = 1e-6 - if ehs_difference < tolerance: - matches += 1 - else: - differences.append(ehs_difference) - - except Exception as e: - errors += 1 - if errors <= 3: - print(f" 样本 {i+1} 计算失败: {e}") - - # 统计结果 - total_samples = len(sample_records) - match_rate = matches / total_samples if total_samples > 0 else 0 - mean_diff = np.mean(differences) if differences else 0 - max_diff = np.max(differences) if differences else 0 - - result = { - 'total_samples': total_samples, - 'matches': matches, - 'match_rate': match_rate, - 'mean_difference': mean_diff, - 'max_difference': max_diff, - 'errors': errors, - 'success': match_rate > 0.8 and mean_diff < 0.05 - } - - print(f" River验证完成:") - print('='*60) - print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})") - print(f" 平均差异: {mean_diff:.6f}") - print(f" 最大差异: {max_diff:.6f}") - - return result - - except Exception as e: - print(f" River验证失败: {e}") - return {'error': str(e), 'success': False} - - def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards): - - player_str = " ".join(str(c) for c in player_cards) - board_str = " ".join(str(c) for c in board_cards) - print("="*60) - print(f"样本 {i+1}: [{player_str}] + [{board_str}]") - print("bin src src_norm cur cur_norm") - src_hist = np.array(src_hist) - cur_hist = np.array(cur_hist) - src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist - cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist - for i in range(min(len(src_hist), 30)): - if src_hist[i] > 0 or cur_hist[i] > 0: - print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}") - - def validate_turn_samples(self, max_samples: int = 10): - - # print(f"\n 验证turn_HIST样本 (最大样本数: {max_samples})") - - try: - - print(" 解析导出的Turn数据...") - print('='*60) - print('='*60) - turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples) - - if not turn_records: - return {'error': '没有解析到Turn记录', 'success': False} - - print(f" 解析到 {len(turn_records)} 个Turn样本") - - - low_emd_count = 0 - emd_distances = [] - errors = 0 - - for i, record in enumerate(turn_records): - try: - - player_cards = record.player_cards - board_cards = record.board_cards - src_hist = np.array(record.bins) - - cur_hist = self.generator.generate_turn_histogram( - player_cards, board_cards, num_bins=len(src_hist) - ) - cur_hist = np.array(cur_hist) - - # 归一化 - src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist - cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist - - # 计算EMD距离 - emd_dist = wasserstein_distance( - range(len(src_hist_norm)), - range(len(cur_hist_norm)), - src_hist_norm, - cur_hist_norm - ) - - low_emd_count += 1 if emd_dist < 0.2 else 0 - errors += 1 if emd_dist >= 0.2 else 0 - emd_distances.append(emd_dist) - - self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards) - - - - # 画图显示 - plt.plot(src_hist, label='src', marker='o') - plt.plot(cur_hist, label='cur', marker='x') - plt.title(f"turn_hist_emd={emd_dist:.6f}") - plt.xlabel("Bins") - plt.ylabel("Frequency") - plt.legend() - plt.show() - - except Exception as e: - errors += 1 - if errors <= 3: - print(f" 样本 {i+1} 计算失败: {e}") - - # 统计结果 - total_samples = len(turn_records) - low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0 - mean_emd = np.mean(emd_distances) if emd_distances else float('inf') - - result = { - 'total_samples': total_samples, - 'low_emd_count': low_emd_count, - 'low_emd_rate': low_emd_rate, - 'mean_emd_distance': mean_emd, - 'emd_distances': emd_distances, - 'errors': errors, - 'success': low_emd_rate > 0.6 and mean_emd < 0.5 - } - - print(f" Turn验证完成:") - print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})") - print(f" 平均EMD: {mean_emd:.6f}") - - return result - - except Exception as e: - print(f" Turn验证失败: {e}") - return {'error': str(e), 'success': False} - - def validate_flop_samples(self, max_samples: int = 5): - # print(f"\n 验证Flop直方图样本 (最大样本数: {max_samples})") - - try: - - print(" 解析导出Flop数据...") - print('='*60) - print('='*60) - print('='*60) - flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples) - - if not flop_records: - return {'error': '没有解析到Flop记录', 'success': False} - - print(f" 解析到 {len(flop_records)} 个Flop样本") - - low_emd_count = 0 - emd_distances = [] - errors = 0 - - for i, record in enumerate(flop_records): - try: - print(f" 处理样本 {i+1}/{len(flop_records)}...") - - player_cards = record.player_cards - board_cards = record.board_cards - src_hist = np.array(record.bins) - - cur_hist = self.generator.generate_flop_histogram( - player_cards, board_cards, num_bins=len(src_hist) - ) - cur_hist = np.array(cur_hist) - - - src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist - cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist - - # 计算EMD距离 - emd_dist = wasserstein_distance( - range(len(src_hist_norm)), - range(len(cur_hist_norm)), - src_hist_norm, - cur_hist_norm - ) - emd_distances.append(emd_dist) - - low_emd_count += 1 if emd_dist < 10 else 0 - errors += 1 if emd_dist >= 10 else 0 - - # 显示详细信息 - player_str = " ".join(str(c) for c in player_cards) - board_str = " ".join(str(c) for c in board_cards) - print(f" 样本 {i+1}: [{player_str}] + [{board_str}]") - print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}") - print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}") - print(f" 归一化后EMD距离: {emd_dist:.6f}") - print("bin src src_norm cur cur_norm") - for i in range(min(len(src_hist), 30)): - if src_hist[i] > 0 or cur_hist[i] > 0: - print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}") - - self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards) - - # 画图显示 - plt.plot(src_hist, label='src', marker='o') - plt.plot(cur_hist, label='cur', marker='x') - plt.title(f"flop_hist_emd={emd_dist:.6f}") - plt.xlabel("Bins") - plt.ylabel("Frequency") - plt.legend() - plt.show() - - except Exception as e: - errors += 1 - if errors <= 3: - print(f" 样本 {i+1} 计算失败: {e}") - - # 统计结果 - total_samples = len(flop_records) - low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0 - mean_emd = np.mean(emd_distances) if emd_distances else float('inf') - - result = { - 'total_samples': total_samples, - 'low_emd_count': low_emd_count, - 'low_emd_rate': low_emd_rate, - 'mean_emd_distance': mean_emd, - 'emd_distances': emd_distances, - 'errors': errors, - 'success': low_emd_rate > 0.4 and mean_emd < 0.5 - } - - print(f" Flop验证完成:") - print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})") - print(f" 平均EMD: {mean_emd:.6f}") - - return result - - except Exception as e: - print(f" Flop验证失败: {e}") - return {'error': str(e), 'success': False} - - def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10, flop_samples: int = 5) -> Dict: - - print(" 导出数据EHS验证") - print("*"*60) - - # 执行各阶段验证 - results = {} - results['river'] = self.validate_river_samples(river_samples) - results['turn'] = self.validate_turn_samples(turn_samples) - results['flop'] = self.validate_flop_samples(flop_samples) - - print(f"\n{'='*60}") - print(" 验证完毕") - print(f"{'='*60}") - - passed_stages = 0 - total_stages = 3 - - # River结果 - print(f"\n RIVER阶段:") - if 'error' not in results['river']: - status = " 通过" if results['river']['success'] else " 失败" - print(f" 验证结果: {status}") - print(f" 样本数量: {results['river']['total_samples']}") - print(f" 匹配率: {results['river']['match_rate']:.1%}") - print(f" 平均差异: {results['river']['mean_difference']:.6f}") - if results['river']['success']: - passed_stages += 1 - else: - print(f" 错误: {results['river']['error']}") - - # Turn结果 - print(f"\n TURN阶段:") - if 'error' not in results['turn']: - status = " 通过" if results['turn']['success'] else " 失败" - print(f" 验证结果: {status}") - print(f" 样本数量: {results['turn']['total_samples']}") - print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}") - print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}") - print(f" 抽样EMD: {[emd for emd in results['turn']['emd_distances'][:5]]}") - if results['turn']['success']: - passed_stages += 1 - else: - print(f" 错误: {results['turn']['error']}") - - # Flop结果 - print(f"\n FLOP阶段:") - if 'error' not in results['flop']: - status = " 通过" if results['flop']['success'] else " 失败" - print(f" 验证结果: {status}") - print(f" 样本数量: {results['flop']['total_samples']}") - print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}") - print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}") - print(f" 抽样EMD: {[emd for emd in results['flop']['emd_distances'][:5]]}") - if results['flop']['success']: - passed_stages += 1 - else: - print(f" 错误: {results['flop']['error']}") - - # 总体结果 - passed_stages = 0 - total_stages = 3 - - if results.get('river') and results['river'].get('success', False): - passed_stages += 1 - if results.get('turn') and results['turn'].get('success', False): - passed_stages += 1 - if results.get('flop') and results['flop'].get('success', False): - passed_stages += 1 - - overall_rate = passed_stages / total_stages - # print(f"\n 总体验证通过率: {passed_stages}/{total_stages} ({overall_rate:.1%})") - - # if overall_rate >= 0.7: - # print(" 数据验证成功!短牌型生成器与解析数据基本一致。") - # else: - # print(" 验证存在问题,生成器可能与实际数据不匹配,需要调试。") - - return { - 'results': results, - 'passed_stages': passed_stages, - 'total_stages': total_stages, - 'overall_success': overall_rate >= 0.7 - } diff --git a/cross_validation/parse_data.py b/cross_validation/parse_data.py deleted file mode 100644 index 7729117..0000000 --- a/cross_validation/parse_data.py +++ /dev/null @@ -1,429 +0,0 @@ -import numpy as np -import random -from typing import List, Dict, Tuple, Optional -from dataclasses import dataclass -from poker.card import Card, ShortDeckRank, Suit - -@dataclass -class RiverEHSRecord: - """river_EHS """ - board_id: int - player_id: int - ehs: float - board_cards: List[Card] # 5张公共牌 - player_cards: List[Card] # 2张手牌 - -@dataclass -class TurnHistRecord: - """turn_hist""" - board_id: int - player_id: int - bins: np.ndarray - board_cards: List[Card] # 4张公共牌 - player_cards: List[Card] # 2张手牌 - -@dataclass -class FlopHistRecord: - """flop_hist""" - board_id: int - player_id: int - bins: np.ndarray - board_cards: List[Card] # 3张公共牌 - player_cards: List[Card] # 2张手牌 - -class OpenPQLDecoder: - """open-pql 完整解码""" - - def __init__(self): - # Card64 - self.OFFSET_SUIT = 16 # 每个suit占16位 - self.OFFSET_S = 0 # S: [15:0] - self.OFFSET_H = 16 # H: [31:16] - self.OFFSET_D = 32 # D: [47:32] - self.OFFSET_C = 48 # C: [63:48] - - # open-pql 到 短牌型(6-A,36张牌) - self.opql_to_poker_rank = { - # 0: Rank.TWO, - # 1: Rank.THREE, - # 2: Rank.FOUR, - # 3: Rank.FIVE, - 4: ShortDeckRank.SIX, - 5: ShortDeckRank.SEVEN, - 6: ShortDeckRank.EIGHT, - 7: ShortDeckRank.NINE, - 8: ShortDeckRank.TEN, - 9: ShortDeckRank.JACK, - 10: ShortDeckRank.QUEEN, - 11: ShortDeckRank.KING, - 12: ShortDeckRank.ACE, - } - - self.opql_to_poker_suit = { - 0: Suit.SPADES, # S -> SPADES - 1: Suit.HEARTS, # H -> HEARTS - 2: Suit.DIAMONDS, # D -> DIAMONDS - 3: Suit.CLUBS, # C -> CLUBS - } - - print("========== OpenPQLDecoder (短牌型36张牌) ===============") - - def u64_from_ranksuit(self, rank, suit) -> int: - """对应 Card64::u64_from_ranksuit_i8""" - return 1 << rank << (suit * self.OFFSET_SUIT) - - def decode_card64(self, card64_u64) -> List[Card]: - """从 Card64 的 u64 表示解码出所有牌""" - cards = [] - - # 按照 Card64 解析每个suit - # 每个suit16位 - for opql_suit in range(4): # 0,1,2,3 - suit_offset = opql_suit * self.OFFSET_SUIT - suit_bits = (card64_u64 >> suit_offset) & 0xFFFF - - # 检查这个suit的每个rank - for opql_rank in range(13): # 0-12 - if suit_bits & (1 << opql_rank): - poker_rank = self.opql_to_poker_rank[opql_rank] - poker_suit = self.opql_to_poker_suit[opql_suit] - cards.append(Card(poker_rank, poker_suit)) - - cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value)) - return cards - - def decode_hand_n_2(self, id_u16) -> List[Card]: - """解码 Hand<2> 的 u16 ID""" - card0_u8 = id_u16 & 0xFF # 低8位 - card1_u8 = (id_u16 >> 8) & 0xFF # 高8位 - - try: - card0 = self._card_from_u8(card0_u8) - card1 = self._card_from_u8(card1_u8) - - cards = [card0, card1] - cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value)) - - return cards - - except (ValueError, TypeError) as e: - raise ValueError(f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}") - - def _card_from_u8(self, v) -> Card: - """ - 对应 open-pql Card::from_u8() - 单张牌decode - """ - SHIFT_SUIT = 4 - opql_rank = v & 0b1111 # 低4位 - opql_suit = v >> SHIFT_SUIT # 高4位 - - if opql_rank not in self.opql_to_poker_rank: - raise ValueError(f"无效的rank: {opql_rank}") - if opql_suit not in self.opql_to_poker_suit: - raise ValueError(f"无效的suit: {opql_suit}") - - poker_rank = self.opql_to_poker_rank[opql_rank] - poker_suit = self.opql_to_poker_suit[opql_suit] - - return Card(poker_rank, poker_suit) - - def decode_board_id(self, board_id, num_cards) -> List[Card]: - """解码公共牌 ID""" - if num_cards == 2: - # 2张牌使用 Hand<2> 的编码方式 - return self.decode_hand_n_2(board_id) - else: - # 3, 4, 5张公共牌使用 Card64 的编码方式 - cards = self.decode_card64(board_id) - - if len(cards) != num_cards: - raise ValueError(f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}") - - return cards - - def decode_player_id(self, player_id) -> List[Card]: - """解码玩家手牌 ID (2张牌)""" - return self.decode_hand_n_2(player_id) - - def decode_board_unique_card(self, all_cards) -> List[Card]: - """验证并返回不重复的牌面""" - unique_cards = [] - for card in all_cards: - is_duplicate = False - for existing_card in unique_cards: - if card.rank == existing_card.rank and card.suit == existing_card.suit: - is_duplicate = True - break - if not is_duplicate: - unique_cards.append(card) - return unique_cards - - -class XTaskDataParser: - def __init__(self, data_path: str = "ehs_data"): - self.data_path = data_path - self.decoder = OpenPQLDecoder() - - def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]: - filepath = f"{self.data_path}/{filename}" - - try: - raw_data = np.load(filepath) - print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}") - # 抽样 - data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) - - records = [] - decode_errors = 0 - - for i, row in enumerate(data_to_process): - try: - board_id = int(row['board']) - player_id = int(row['player']) - ehs = float(row['ehs']) - - # 解码公共牌 (5张) - board_cards = self.decoder.decode_board_id(board_id, 5) - - # 解码玩家手牌 (2张) - player_cards = self.decoder.decode_player_id(player_id) - - # 验证牌面不重复 - all_cards = board_cards + player_cards - - unique_cards = self.decoder.decode_board_unique_card(all_cards) - if len(unique_cards) != 7: - print(f"记录 {i}: 存在重复牌面") - print(f" Board: {[str(c) for c in board_cards]}") - print(f" Player: {[str(c) for c in player_cards]}") - - # 创建完整记录 - record = RiverEHSRecord( - board_id=board_id, - player_id=player_id, - ehs=ehs, - board_cards=board_cards, - player_cards=player_cards - ) - - records.append(record) - - except Exception as e: - decode_errors += 1 - if decode_errors <= 3: - print(f"记录 {i} 解码失败: {e}") - print(f" board_id={row['board']:016x}, player_id={row['player']:04x}") - continue - - # 显示进度 - if (i + 1) % 10000 == 0: - success_rate = len(records) / (i + 1) * 100 - print(f" 已处理 {i+1:,}/{len(raw_data):,} 条记录,成功率 {success_rate:.1f}%") - - total_records = len(raw_data) - success_records = len(records) - success_rate = success_records / total_records * 100 - - print(f"river_EHS解析完成:") - print(f" 总记录数: {total_records:,}") - print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)") - print(f" 解码失败: {decode_errors:,}") - - return records - - except FileNotFoundError: - print(f"文件不存在: {filepath}") - raise - except Exception as e: - print(f"解析river_EHS数据失败: {e}") - raise - - def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> List[TurnHistRecord]: - filepath = f"{self.data_path}/{filename}" - - try: - raw_data = np.load(filepath) - print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录") - - records = [] - decode_errors = 0 - # 抽样 - data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) - - for i, row in enumerate(data_to_process): - try: - board_id = int(row['board']) - player_id = int(row['player']) - bins = row['bins'].copy() - - # 解码公共牌 (4张) - board_cards = self.decoder.decode_board_id(board_id, 4) - - # 解码玩家手牌 (2张) - player_cards = self.decoder.decode_player_id(player_id) - - # 验证牌面不重复 - all_cards = board_cards + player_cards - - unique_cards = self.decoder.decode_board_unique_card(all_cards) - if len(unique_cards) != 6: - decode_errors += 1 - print(f"记录 {i}: 存在重复牌面") - - record = TurnHistRecord( - board_id=board_id, - player_id=player_id, - bins=bins, # 存储30个bins的numpy数组 - board_cards=board_cards, - player_cards=player_cards - ) - - records.append(record) - - except Exception as e: - decode_errors += 1 - if decode_errors <= 3: - print(f"记录 {i} 解码失败: {e}") - continue - - if (i + 1) % 100 == 0: - print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...") - - print(f"turn_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录") - return records - - except Exception as e: - print(f" 解析turn_hist数据失败: {e}") - raise - - def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> List[FlopHistRecord]: - """解析flop_hist""" - filepath = f"{self.data_path}/{filename}" - - try: - raw_data = np.load(filepath) - print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录") - records = [] - decode_errors = 0 - data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) - - for i, row in enumerate(data_to_process): - try: - board_id = int(row['board']) - player_id = int(row['player']) - bins = row['bins'].copy() - - # 解码公共牌 (3张) - board_cards = self.decoder.decode_board_id(board_id, 3) - - # 解码玩家手牌 (2张) - player_cards = self.decoder.decode_player_id(player_id) - - # 验证牌面不重复 - all_cards = board_cards + player_cards - unique_cards = self.decoder.decode_board_unique_card(all_cards) - if len(unique_cards) != 5: - decode_errors += 1 - print(f"记录 {i}: 存在重复牌面") - - record = FlopHistRecord( - board_id=board_id, - player_id=player_id, - bins=bins, # 存储465个bins的numpy数组 - board_cards=board_cards, - player_cards=player_cards - ) - - records.append(record) - - except Exception as e: - decode_errors += 1 - if decode_errors <= 3: - print(f"记录 {i} 解码失败: {e}") - continue - - if (i + 1) % 50 == 0: - print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...") - - print(f"flop_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录") - return records - - except Exception as e: - print(f"解析flop_hist数据失败: {e}") - raise - -def analyze_parsed_data(): - - print("开始解析 xtask 数据并解码牌面...\n") - - try: - parser = XTaskDataParser(data_path="ehs_data") - - # 1. 解析river_EHS数据 - print("=" * 60) - print(" 解析river_EHS") - print("=" * 60) - river_records = parser.parse_river_ehs_with_cards() - - # 显示river_EHS样本 - print(f"\nriver_EHS数据样本 (前10条):") - for i, record in enumerate(river_records[:10]): - board_str = " ".join(str(c) for c in record.board_cards) - player_str = " ".join(str(c) for c in record.player_cards) - print(f" {i+1:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{record.ehs:.4f}") - - # 2. 解析turn_hist数据 - print(f"\n" + "=" * 60) - print(" 解析turn_hist数据") - print("=" * 60) - turn_records = parser.parse_turn_hist_with_cards(max_records=100) # 限制数量 - # 显示turn_hist数据样本 - print(f"\nturn_hist数据样本 (前5条):") - for i, record in enumerate(turn_records[:5]): - board_str = " ".join(str(c) for c in record.board_cards) - player_str = " ".join(str(c) for c in record.player_cards) - bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}" - print(f" {i+1}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}") - - # 3. 解析flop_hist数据 - print(f"\n" + "=" * 60) - print(" 解析flop_hist数据") - print("=" * 60) - flop_records = parser.parse_flop_hist_with_cards(max_records=50) # 限制数量 - - # 显示flop_数据样本 - print(f"\n flop_hist数据样本 (前5条):") - for i, record in enumerate(flop_records[:5]): - board_str = " ".join(str(c) for c in record.board_cards) - player_str = " ".join(str(c) for c in record.player_cards) - bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}" - print(f" {i+1}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}") - - # 4. 统计摘要 - print(f"\n" + "=" * 60) - print("解析统计摘要") - print("=" * 60) - print(f"river_EHS记录: {len(river_records):,} 条") - print(f"turn_hist记录: {len(turn_records):,} 条") - print(f"flop_hist记录: {len(flop_records):,} 条") - - # EHS 统计 - if river_records: - ehs_values = [r.ehs for r in river_records] - print(f"river_EHS统计:") - print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]") - print(f" 均值: {np.mean(ehs_values):.4f}") - print(f" 标准差: {np.std(ehs_values):.4f}") - - print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。") - - return { - 'river_records': river_records, - 'turn_records': turn_records, - 'flop_records': flop_records - } - - except Exception as e: - print(f" 数据解析失败: {e}") - return None diff --git a/cross_validation/validator.py b/cross_validation/validator.py new file mode 100644 index 0000000..6106bd5 --- /dev/null +++ b/cross_validation/validator.py @@ -0,0 +1,216 @@ +import itertools + +import numpy as np +import random +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path + +from poker import Suit, Card +from shortdeck import ShortDeckHandEvaluator as HE +from shortdeck import ShortDeckRank as SDR + +data_path = Path(".") / "ehs-data" +np_river = np.load(data_path / "river_ehs_sd.npy") +np_turn = np.load(data_path / "turn_hist_sd.npy") +np_flop = np.load(data_path / "flop_hist_sd.npy") + +cards = [Card(r, s) for r in SDR for s in Suit] + +CARD_BITS = 6 +class SuitMapping: + def __init__(self): + self.mapping = {} + self.suits = list(reversed(Suit)) + + def map_suit(self, s: Suit) -> Suit: + if s not in self.mapping: + self.mapping[s] = self.suits.pop() + + return self.mapping[s] + +class EhsCache: + def __init__(self): + self.cache = defaultdict(lambda: defaultdict(dict)) + + def _set_keys(self, flop, player): + suit_map = SuitMapping() + complex_cards = player+flop + iso_complex = to_iso(complex_cards, suit_map) + complex_key = cards_to_u32(iso_complex) + return complex_key + # 全部存下来计算耗时太大了。。嗯。。 + # todo:不存储直接抽样计算 + def store_river_ehs(self, player, board, ehs): + complex_key = self._set_keys(board[:3],player) + turn_idx = card_index(board[3]) + river_idx = card_index(board[4]) + self.cache[complex_key][turn_idx][river_idx] = ehs + + def get_turn_hist(self, player, flop, turn): + complex_key = self._set_keys(flop, player) + turn_idx = card_index(turn) + turn_hist = self.cache[complex_key][turn_idx] + return list(turn_hist.values()) if turn_hist else None + + def get_flop_hist(self, player, flop): + complex_key = self._set_keys(flop, player) + all_ehs = [] + player_data = self.cache[complex_key] + for turn_idx in player_data: + for river_idx in player_data[turn_idx]: + all_ehs.append(player_data[turn_idx][river_idx]) + return all_ehs if len(all_ehs) == 465 else None + +def get_rank_idx(rank: SDR) -> int: + rank_order = [SDR.SIX, SDR.SEVEN, SDR.EIGHT, SDR.NINE, SDR.TEN, + SDR.JACK,SDR.QUEEN, SDR.KING, SDR.ACE] + return rank_order.index(rank) + +def get_suit_idx(suit: Suit) -> int: + suit_order = [Suit.SPADES, Suit.HEARTS, Suit.DIAMONDS, Suit.CLUBS] + return suit_order.index(suit) + +def card_index(card: Card) -> int: + return (get_rank_idx(card.rank) + 4) * 4 + get_suit_idx(card.suit) + +Card.__eq__ = lambda a, b: (a.rank == b.rank) and (a.suit == b.suit) +Card.__hash__ = lambda a: hash((get_rank_idx(a.rank), get_suit_idx(a.suit))) + +def cards_to_u32(cards: list[Card]) -> int: + res = 0 + for i, card in enumerate(cards): + bits = card_index(card) & 0x3F + res |= bits << (i * CARD_BITS) + return res + +def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]: + def count_suit(card: Card) -> int: + return sum(1 for other in cards if other.suit == card.suit) + + sorted_cards = sorted( + cards, + key=lambda c: (count_suit(c), get_rank_idx(c.rank), get_suit_idx(c.suit)) + ) + res = [] + for card in sorted_cards: + mapped_suit = mapping.map_suit(card.suit) + res.append(Card(card.rank, mapped_suit)) + + return sorted(res, key=lambda c: (get_rank_idx(c.rank), get_suit_idx(c.suit))) + +def cards_to_u16(cards: list[Card]) -> int: + res = 0 + for i, card in enumerate(cards): + bits = card_index(card) & 0x3F + res |= bits << (i * CARD_BITS) + return res + +def calc_river_ehs(board: list[Card], player: list[Card]) -> float: + player_hand = [*board, *player] + player_ranking = HE.evaluate_hand(player_hand) + + acc = 0 + sum = 0 + for other in itertools.combinations(cards, 2): + if set(other) & set(player_hand): + continue + if set(other) & set(board): + continue + other_ranking = HE.evaluate_hand([*board, *other]) + if player_ranking == other_ranking: + acc += 1 + elif player_ranking > other_ranking: + acc += 2 + sum += 2 + return acc / sum + +def get_data(board: list[Card], player: list[Card]): + def _get_data(data, board: list[Card], player: list[Card]): + suit_map = SuitMapping() + + iso_board = to_iso(board, suit_map) + iso_player = to_iso(player, suit_map) + + mask_board = data["board"] == cards_to_u32(iso_board) + mask_player = data["player"] == cards_to_u16(iso_player) + return data[mask_board & mask_player][0][2] + match len(board): + case 3: + return _get_data(np_flop, board, player) + case 4: + return _get_data(np_turn, board, player) + case 5: + return _get_data(np_river, board, player) + case _: + raise NotImplementedError + +def euclidean_dist(left, right): + if isinstance(left, Iterable): + v1 = np.sort(np.array(left, dtype=np.float32)) + v2 = np.sort(np.array(right, dtype=np.float32)) + return np.linalg.norm(v2 - v1) + else: + return np.abs(left - right) ** 2 + +def compare_data(sampled, board, player): + err_count = 0 + d = euclidean_dist(get_data(board, player), sampled) + if not np.isclose(d, 0.0): + print(f"[{''.join(map(str, board))} {''.join(map(str, player))}]: {d}") + err_count += 1 + return err_count + +card_ehs = defaultdict(dict) + +def validate_river(): + validated_count = 0 + error_count = 0 + for river_combo in itertools.combinations(cards, 5): + board = list(river_combo) + unused_cards = [c for c in cards if c not in board] + for player_combo in itertools.combinations(unused_cards, 2): + player = list(player_combo) + ehs = calc_river_ehs(board, player) + ehs_stored.store_river_ehs(player, board, ehs) + error_count = compare_data(ehs, board, player) + validated_count += 1 + print(".", end="", flush=True) + print(f"river validate count {validated_count}") + print(f"Validated river hands: {validated_count}, Errors: {error_count}") + + +def validate_turn(): + validated_count = 0 + error_count = 0 + for turn_combo in itertools.combinations(cards, 4): + turn = list(turn_combo) + unused_cards = [c for c in cards if c not in turn] + for player_combo in itertools.combinations(unused_cards, 2): + player = list(player_combo) + turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3]) + error_count += compare_data(turn_hist, turn, player) + validated_count += 1 + print(".", end="", flush=True) + print(f"Validated turn hands: {validated_count}, Errors: {error_count}") + +def validate_flop(): + sample = random.sample(cards, 5) + flop = sample[2:] + player = sample[:2] + + flop_hist = ehs_stored.get_flop_hist(player, flop) + if flop_hist is None: + return + compare_data(flop_hist, flop, player) + + +ehs_stored = EhsCache() + +def cross_validate_main(): + validate_river() + validate_turn() + validate_flop() + +if __name__ == "__main__": + cross_validate_main() \ No newline at end of file diff --git a/task5_main.py b/task5_main.py index 99b9f6a..20b2867 100644 --- a/task5_main.py +++ b/task5_main.py @@ -1,32 +1,8 @@ -from cross_validation import DataValidator +from cross_validation import cross_validate_main import sys def main(): - """主函数""" - import argparse - - parser = argparse.ArgumentParser( - description='从xtask导出数据中抽取牌面进行EHS验证', - formatter_class=argparse.RawDescriptionHelpFormatter - ) - - parser.add_argument('--river-samples', type=int, default=10, help='River样本数 (默认: 10)') - parser.add_argument('--turn-samples', type=int, default=5, help='Turn样本数 (默认: 5)') - parser.add_argument('--flop-samples', type=int, default=3, help='Flop样本数 (默认: 3)') - parser.add_argument('--data-path', type=str, default='ehs_data', help='数据路径 (默认: ehs_data)') - - args = parser.parse_args() - - - validator = DataValidator(data_path=args.data_path) - results = validator.run_full_validation( - river_samples=args.river_samples, - turn_samples=args.turn_samples, - flop_samples=args.flop_samples - ) - - return 0 if results['overall_success'] else 1 - + return cross_validate_main() if __name__ == '__main__': sys.exit(main()) \ No newline at end of file