diff --git a/cross_validation/cross_validation.py b/cross_validation/cross_validation.py index 15d8bf7..eb9e6d9 100644 --- a/cross_validation/cross_validation.py +++ b/cross_validation/cross_validation.py @@ -6,6 +6,7 @@ import sys from .parse_data import XTaskDataParser from shortdeck.gen_hist import ShortDeckHistGenerator +import matplotlib.pyplot as plt class DataValidator: @@ -23,13 +24,12 @@ class DataValidator: try: print(" 解析导出的river数据...") print('='*60) - river_records = self.parser.parse_river_ehs_with_cards() + river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples) if not river_records: return {'error': '没有解析到river记录', 'success': False} - - sample_records = np.random.choice(river_records, size=max_samples, replace=False) + sample_records = river_records print(f" 选择 {len(sample_records)} 个样本进行验证") matches = 0 @@ -95,6 +95,21 @@ class DataValidator: except Exception as e: print(f" River验证失败: {e}") return {'error': str(e), 'success': False} + + def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards): + + player_str = " ".join(str(c) for c in player_cards) + board_str = " ".join(str(c) for c in board_cards) + print("="*60) + print(f"样本 {i+1}: [{player_str}] + [{board_str}]") + print("bin src src_norm cur cur_norm") + src_hist = np.array(src_hist) + cur_hist = np.array(cur_hist) + src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist + cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist + for i in range(min(len(src_hist), 30)): + if src_hist[i] > 0 or cur_hist[i] > 0: + print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}") def validate_turn_samples(self, max_samples: int = 10): @@ -140,20 +155,24 @@ class DataValidator: src_hist_norm, cur_hist_norm ) - - is_low_emd = emd_dist < 0.2 - if is_low_emd: - low_emd_count += 1 - - # 显示详细信息(前3个样本) - if i < 3: - player_str = " ".join(str(c) for c in player_cards) - board_str = " ".join(str(c) for c in board_cards) - print(f" 样本 {i+1}: [{player_str}] + [{board_str}]") - print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}") - print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}") - print(f" 归一化后EMD距离: {emd_dist:.6f}") - + + low_emd_count += 1 if emd_dist < 0.2 else 0 + errors += 1 if emd_dist >= 0.2 else 0 + emd_distances.append(emd_dist) + + self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards) + + + + # 画图显示 + plt.plot(src_hist, label='src', marker='o') + plt.plot(cur_hist, label='cur', marker='x') + plt.title(f"turn_hist_emd={emd_dist:.6f}") + plt.xlabel("Bins") + plt.ylabel("Frequency") + plt.legend() + plt.show() + except Exception as e: errors += 1 if errors <= 3: @@ -228,11 +247,10 @@ class DataValidator: src_hist_norm, cur_hist_norm ) - # emd_distances.append(emd_dist) - - # is_low_emd = emd_dist < 0.2 # EMD阈值 - # if is_low_emd: - # low_emd_count += 1 + emd_distances.append(emd_dist) + + low_emd_count += 1 if emd_dist < 10 else 0 + errors += 1 if emd_dist >= 10 else 0 # 显示详细信息 player_str = " ".join(str(c) for c in player_cards) @@ -241,7 +259,22 @@ class DataValidator: print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}") print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}") print(f" 归一化后EMD距离: {emd_dist:.6f}") - + print("bin src src_norm cur cur_norm") + for i in range(min(len(src_hist), 30)): + if src_hist[i] > 0 or cur_hist[i] > 0: + print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}") + + self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards) + + # 画图显示 + plt.plot(src_hist, label='src', marker='o') + plt.plot(cur_hist, label='cur', marker='x') + plt.title(f"flop_hist_emd={emd_dist:.6f}") + plt.xlabel("Bins") + plt.ylabel("Frequency") + plt.legend() + plt.show() + except Exception as e: errors += 1 if errors <= 3: @@ -276,7 +309,6 @@ class DataValidator: print(" 导出数据EHS验证") print("*"*60) - print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性") # 执行各阶段验证 results = {} @@ -312,6 +344,7 @@ class DataValidator: print(f" 样本数量: {results['turn']['total_samples']}") print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}") print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}") + print(f" 抽样EMD: {[emd for emd in results['turn']['emd_distances'][:5]]}") if results['turn']['success']: passed_stages += 1 else: @@ -325,6 +358,7 @@ class DataValidator: print(f" 样本数量: {results['flop']['total_samples']}") print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}") print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}") + print(f" 抽样EMD: {[emd for emd in results['flop']['emd_distances'][:5]]}") if results['flop']['success']: passed_stages += 1 else: diff --git a/cross_validation/debug_emd.py b/cross_validation/debug_emd.py deleted file mode 100644 index 06612a6..0000000 --- a/cross_validation/debug_emd.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -""" -调试 -turn/flop阶段EMD在导出的数据与生成的数据间的差异 -""" - -import numpy as np -from cross_validation import DataValidator - -validator = DataValidator() - -print("解析Turn样本...") -turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1) - -if turn_records: - record = turn_records[0] - player_cards = record.player_cards - board_cards = record.board_cards - src_hist = record.bins - - cur_hist = validator.generator.generate_turn_histogram( - player_cards, board_cards, num_bins=len(src_hist) - ) - - - src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist - - print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}") - print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}") - print(f"原始直方图 - 和: {src_hist.sum():.3f}, 归一化后: {src_hist_norm.sum():.3f}") - print(f"生成直方图 - 和: {sum(cur_hist):.3f}") - - print("\n前10个bin对比:") - print("Bin 原始值 归一化 生成值") - for i in range(min(10, len(src_hist))): - print(f"{i:3d} {src_hist[i]:8.3f} {src_hist_norm[i]:8.3f} {cur_hist[i]:8.3f}") - - # 查看非零bin的分布 - src_nonzero = np.nonzero(src_hist_norm)[0] - cur_nonzero = np.nonzero(cur_hist)[0] - print(f"\n非零bins位置:") - print(f"原始: {src_nonzero[:10]}...") - print(f"生成: {cur_nonzero[:10]}...") - - # 计算分布的统计特征 - src_mean = np.average(range(len(src_hist_norm)), weights=src_hist_norm) - cur_mean = np.average(range(len(cur_hist)), weights=cur_hist) - - print(f"\n分布特征:") - print(f"原始分布重心: {src_mean:.2f}") - print(f"生成分布重心: {cur_mean:.2f}") - print(f"重心差异: {abs(src_mean - cur_mean):.2f}") \ No newline at end of file diff --git a/cross_validation/parse_data.py b/cross_validation/parse_data.py index 0d15107..7729117 100644 --- a/cross_validation/parse_data.py +++ b/cross_validation/parse_data.py @@ -1,4 +1,5 @@ import numpy as np +import random from typing import List, Dict, Tuple, Optional from dataclasses import dataclass from poker.card import Card, ShortDeckRank, Suit @@ -164,17 +165,19 @@ class XTaskDataParser: self.data_path = data_path self.decoder = OpenPQLDecoder() - def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]: + def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]: filepath = f"{self.data_path}/{filename}" try: raw_data = np.load(filepath) print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}") - + # 抽样 + data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) + records = [] decode_errors = 0 - - for i, row in enumerate(raw_data): + + for i, row in enumerate(data_to_process): try: board_id = int(row['board']) player_id = int(row['player']) @@ -245,10 +248,8 @@ class XTaskDataParser: records = [] decode_errors = 0 - - # 限制处理数量以避免内存问题 - # todo:抽样优化 - data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data + # 抽样 + data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) for i, row in enumerate(data_to_process): try: @@ -305,8 +306,7 @@ class XTaskDataParser: print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录") records = [] decode_errors = 0 - - data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data + data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data))) for i, row in enumerate(data_to_process): try: diff --git a/shortdeck/gen_hist.py b/shortdeck/gen_hist.py index e8a2574..179631a 100644 --- a/shortdeck/gen_hist.py +++ b/shortdeck/gen_hist.py @@ -147,97 +147,4 @@ class ShortDeckHistGenerator: print(f" Turn计算: generate nums{len(histogram)} ,num_bins={num_bins}") raise return histogram - - -# # 从一副牌中抽样验证 或者 从解析的数据中抽样验证 -# def generate_sample_data(self, num_samples: int = 3) -> Dict: -# print(f"\n 生成样本数据 (每阶段 {num_samples} 个样本)") - -# results = { -# 'river': [], -# 'turn': [], -# 'flop': [] -# } - -# for i in range(num_samples): -# # River样本(7张牌) -# river_cards = np.random.choice(self.full_deck, size=7, replace=False).tolist() -# player_cards = river_cards[:2] -# board_cards = river_cards[2:7] -# river_ehs = self.generate_river_ehs(player_cards, board_cards) -# results['river'].append({ -# 'player_cards': player_cards, -# 'board_cards': board_cards, -# 'ehs': river_ehs -# }) - -# # Turn样本(6张牌) -# turn_cards = np.random.choice(self.full_deck, size=6, replace=False).tolist() -# player_cards = turn_cards[:2] -# board_cards = turn_cards[2:6] -# turn_hist = self.generate_turn_histogram(player_cards, board_cards, num_bins=30) -# results['turn'].append({ -# 'player_cards': player_cards, -# 'board_cards': board_cards, -# 'histogram': turn_hist, -# 'mean': float(np.mean(turn_hist)), -# 'std': float(np.std(turn_hist)) -# }) - -# # Flop样本(5张牌) -# flop_cards = np.random.choice(self.full_deck, size=5, replace=False).tolist() -# player_cards = flop_cards[:2] -# board_cards = flop_cards[2:5] -# flop_hist = self.generate_flop_histogram(player_cards, board_cards, num_bins=465) -# results['flop'].append({ -# 'player_cards': [str(c) for c in player_cards], -# 'board_cards': [str(c) for c in board_cards], -# 'histogram_stats': flop_hist, -# 'mean': float(np.mean(flop_hist)), -# 'std': float(np.std(flop_hist)) -# }) - -# return results - - -# def main(): -# """测试短牌型EHS直方图生成器""" -# print("短牌型EHS直方图生成器测试") -# print("="*60) - -# generator = ShortDeckHistGenerator() - -# # 测试River EHS -# print(f"\n测试River EHS计算...") -# test_cards = np.random.choice(generator.full_deck, size=7, replace=False).tolist() -# player_cards = test_cards[:2] -# board_cards = test_cards[2:7] - -# player_str = " ".join(str(c) for c in player_cards) -# board_str = " ".join(str(c) for c in board_cards) - -# river_ehs = generator.generate_river_ehs(player_cards, board_cards) -# print(f" 玩家底牌: [{player_str}]") -# print(f" 公共牌: [{board_str}]") -# print(f" River EHS: {river_ehs:.4f}") - -# # 测试Turn直方图 -# print(f"\n测试Turn直方图生成...") -# turn_cards = np.random.choice(generator.full_deck, size=6, replace=False).tolist() -# player_cards = turn_cards[:2] -# board_cards = turn_cards[2:6] - -# turn_hist = generator.generate_turn_histogram(player_cards, board_cards, num_bins=30) - -# # 测试Flop直方图 -# print(f"\n测试Flop直方图生成...") -# flop_cards = np.random.choice(generator.full_deck, size=5, replace=False).tolist() -# player_cards = flop_cards[:2] -# board_cards = flop_cards[2:5] - -# flop_hist = generator.generate_flop_histogram(player_cards, board_cards, num_bins=465) -# print(f" Flop直方图: mean={np.mean(flop_hist):.3f}, 非零bins={np.count_nonzero(flop_hist)}") - -# # 生成样本数据 -# print(f"\n生成样本数据...") -# sample_data = generator.generate_sample_data(num_samples=3) + \ No newline at end of file