#!/usr/bin/env python3
import numpy as np
from typing import List, Dict, Tuple
from scipy.stats import wasserstein_distance
import sys
from .parse_data import XTaskDataParser
from shortdeck.gen_hist import ShortDeckHistGenerator


class DataValidator:
    """Cross-check exported xtask EHS data against ShortDeckHistGenerator.

    For each street, records are parsed from the exported data, the generator
    recomputes the same quantity from the record's cards, and the two are
    compared:

    * river: scalar EHS, compared by absolute difference;
    * turn / flop: histograms, each normalized to sum 1 and compared with the
      1-D Wasserstein (earth mover's) distance over bin indices.
    """

    # Absolute difference below which a recomputed river EHS counts as a match.
    RIVER_TOLERANCE = 1e-6
    # EMD between normalized histograms (measured in bin-index units) below
    # which a turn/flop sample counts as "low EMD".
    LOW_EMD_THRESHOLD = 0.2
    # Only the first few per-sample failures are printed to keep logs short.
    MAX_PRINTED_ERRORS = 3

    def __init__(self, data_path: str = "ehs_data"):
        """Create the exported-data parser for *data_path* and a generator."""
        self.parser = XTaskDataParser(data_path)
        self.generator = ShortDeckHistGenerator()
        print(" DataValidator初始化完成")
        print(f" 生成器短牌型大小: {len(self.generator.full_deck)}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _cards_str(cards) -> str:
        """Join cards with single spaces using each card's ``str`` form."""
        return " ".join(str(c) for c in cards)

    @staticmethod
    def _normalize(hist: np.ndarray) -> np.ndarray:
        """Scale *hist* to sum to 1; an all-zero histogram is returned as-is."""
        total = hist.sum()
        return hist / total if total > 0 else hist

    @classmethod
    def _histogram_emd(cls, src_hist: np.ndarray, cur_hist: np.ndarray):
        """Return the Wasserstein distance between the two normalized histograms.

        Bin ``k`` is placed at position ``k``, so the distance is expressed in
        bin-index units. scipy raises ValueError if either weight vector sums
        to zero; callers treat that as a per-sample failure.
        """
        src_norm = cls._normalize(src_hist)
        cur_norm = cls._normalize(cur_hist)
        return wasserstein_distance(
            range(len(src_norm)), range(len(cur_norm)), src_norm, cur_norm
        )

    def _compare_hist_records(self, records, generate, detail_limit=None,
                              show_progress=False) -> Tuple[int, List, int]:
        """Recompute each record's histogram with *generate* and measure its EMD.

        Args:
            records: parsed records exposing ``player_cards``, ``board_cards``
                and ``bins``.
            generate: generator method called as
                ``generate(player_cards, board_cards, num_bins=...)``.
            detail_limit: print per-sample details only for the first
                ``detail_limit`` samples; ``None`` prints every sample.
            show_progress: print a progress line before each sample.

        Returns:
            ``(low_emd_count, emd_distances, errors)``. A sample whose
            computation raises is counted in ``errors`` only.
        """
        low_emd_count = 0
        emd_distances = []
        errors = 0
        for i, record in enumerate(records):
            try:
                if show_progress:
                    print(f" 处理样本 {i+1}/{len(records)}...")
                player_cards = record.player_cards
                board_cards = record.board_cards
                src_hist = np.array(record.bins)
                cur_hist = np.array(
                    generate(player_cards, board_cards, num_bins=len(src_hist))
                )

                emd_dist = self._histogram_emd(src_hist, cur_hist)
                # Previously never recorded, which forced mean EMD to inf.
                emd_distances.append(emd_dist)
                if emd_dist < self.LOW_EMD_THRESHOLD:
                    low_emd_count += 1

                if detail_limit is None or i < detail_limit:
                    print(f" 样本 {i+1}: [{self._cards_str(player_cards)}] + [{self._cards_str(board_cards)}]")
                    print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
                    print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
                    print(f" 归一化后EMD距离: {emd_dist:.6f}")
            except Exception as e:
                errors += 1
                if errors <= self.MAX_PRINTED_ERRORS:
                    print(f" 样本 {i+1} 计算失败: {e}")
        return low_emd_count, emd_distances, errors

    @staticmethod
    def _summarize_hist(total_samples: int, low_emd_count: int, emd_distances: List,
                        errors: int, min_low_emd_rate: float,
                        max_mean_emd: float = 0.5) -> Dict:
        """Build the result dict for a histogram stage.

        ``success`` requires the low-EMD rate to exceed *min_low_emd_rate* and
        the mean EMD to be below *max_mean_emd*. With no distances the mean is
        ``inf``, so the stage fails.
        """
        low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
        mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
        return {
            'total_samples': total_samples,
            'low_emd_count': low_emd_count,
            'low_emd_rate': low_emd_rate,
            'mean_emd_distance': mean_emd,
            'emd_distances': emd_distances,
            'errors': errors,
            'success': low_emd_rate > min_low_emd_rate and mean_emd < max_mean_emd,
        }

    # ----------------------------------------------------------------- stages

    def validate_river_samples(self, max_samples: int = 20) -> Dict:
        """Recompute river EHS for up to *max_samples* random records and compare.

        Returns a dict with match statistics and ``success``. On failure it
        returns ``{'error': ..., 'success': False}``. ``mean_difference`` and
        ``max_difference`` are taken over every successfully compared sample.
        """
        try:
            print(" 解析导出的river数据...")
            print('='*60)
            river_records = self.parser.parse_river_ehs_with_cards()
            if not river_records:
                return {'error': '没有解析到river记录', 'success': False}

            # Sample indices rather than the records themselves. This avoids a
            # ValueError when fewer records than max_samples exist, and it never
            # coerces record objects into a numpy array.
            # Assumes river_records is an indexable sequence (e.g. a list).
            sample_size = max(0, min(max_samples, len(river_records)))
            picked = np.random.choice(len(river_records), size=sample_size, replace=False)
            sample_records = [river_records[k] for k in picked]
            print(f" 选择 {len(sample_records)} 个样本进行验证")

            matches = 0
            errors = 0
            differences = []
            for i, record in enumerate(sample_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_river_ehs = record.ehs
                    cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards)
                    ehs_difference = abs(src_river_ehs - cur_river_ehs)

                    print(f" 样本 {i+1}: [{self._cards_str(player_cards)}] + [{self._cards_str(board_cards)}]")
                    print(f" 原始EHS: {src_river_ehs:.6f}")
                    print(f" 重算EHS: {cur_river_ehs:.6f}")
                    print(f" 差异: {ehs_difference:.6f}")

                    # Record every difference so the mean reflects all samples,
                    # not only the mismatching ones.
                    differences.append(ehs_difference)
                    if ehs_difference < self.RIVER_TOLERANCE:
                        matches += 1
                except Exception as e:
                    errors += 1
                    if errors <= self.MAX_PRINTED_ERRORS:
                        print(f" 样本 {i+1} 计算失败: {e}")

            total_samples = len(sample_records)
            match_rate = matches / total_samples if total_samples > 0 else 0
            mean_diff = np.mean(differences) if differences else 0
            max_diff = np.max(differences) if differences else 0
            result = {
                'total_samples': total_samples,
                'matches': matches,
                'match_rate': match_rate,
                'mean_difference': mean_diff,
                'max_difference': max_diff,
                'errors': errors,
                'success': match_rate > 0.8 and mean_diff < 0.05,
            }
            print(" River验证完成:")
            print('='*60)
            print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
            print(f" 平均差异: {mean_diff:.6f}")
            print(f" 最大差异: {max_diff:.6f}")
            return result
        except Exception as e:
            print(f" River验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_turn_samples(self, max_samples: int = 10) -> Dict:
        """Compare up to *max_samples* exported turn histograms with regenerated ones.

        Details are printed for the first 3 samples only. The stage succeeds
        when the low-EMD rate is above 0.6 and the mean EMD is below 0.5.
        """
        try:
            print(" 解析导出的Turn数据...")
            print('='*60)
            print('='*60)
            turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
            if not turn_records:
                return {'error': '没有解析到Turn记录', 'success': False}
            print(f" 解析到 {len(turn_records)} 个Turn样本")

            low_emd_count, emd_distances, errors = self._compare_hist_records(
                turn_records, self.generator.generate_turn_histogram, detail_limit=3
            )
            result = self._summarize_hist(
                len(turn_records), low_emd_count, emd_distances, errors,
                min_low_emd_rate=0.6,
            )
            print(" Turn验证完成:")
            print(f" 低EMD数: {result['low_emd_count']}/{result['total_samples']} ({result['low_emd_rate']:.1%})")
            print(f" 平均EMD: {result['mean_emd_distance']:.6f}")
            return result
        except Exception as e:
            print(f" Turn验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_flop_samples(self, max_samples: int = 5) -> Dict:
        """Compare up to *max_samples* exported flop histograms with regenerated ones.

        Progress and details are printed for every sample. The stage succeeds
        when the low-EMD rate is above 0.4 and the mean EMD is below 0.5.
        """
        try:
            print(" 解析导出Flop数据...")
            print('='*60)
            print('='*60)
            print('='*60)
            flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)
            if not flop_records:
                return {'error': '没有解析到Flop记录', 'success': False}
            print(f" 解析到 {len(flop_records)} 个Flop样本")

            # EMD recording and low-EMD counting were previously commented out,
            # which made this stage fail unconditionally.
            low_emd_count, emd_distances, errors = self._compare_hist_records(
                flop_records, self.generator.generate_flop_histogram,
                detail_limit=None, show_progress=True,
            )
            result = self._summarize_hist(
                len(flop_records), low_emd_count, emd_distances, errors,
                min_low_emd_rate=0.4,
            )
            print(" Flop验证完成:")
            print(f" 低EMD数: {result['low_emd_count']}/{result['total_samples']} ({result['low_emd_rate']:.1%})")
            print(f" 平均EMD: {result['mean_emd_distance']:.6f}")
            return result
        except Exception as e:
            print(f" Flop验证失败: {e}")
            return {'error': str(e), 'success': False}

    def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10,
                            flop_samples: int = 5) -> Dict:
        """Run river, turn and flop validation in order and print a summary.

        Returns ``{'results', 'passed_stages', 'total_stages',
        'overall_success'}``. ``overall_success`` is True when at least 70% of
        stages report ``success``.
        """
        print(" 导出数据EHS验证")
        print("*"*60)
        print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")

        # Dict literals evaluate left to right, so stages run river → turn → flop.
        results = {
            'river': self.validate_river_samples(river_samples),
            'turn': self.validate_turn_samples(turn_samples),
            'flop': self.validate_flop_samples(flop_samples),
        }

        print(f"\n{'='*60}")
        print(" 验证完毕")
        print(f"{'='*60}")

        river = results['river']
        print("\n RIVER阶段:")
        if 'error' not in river:
            status = " 通过" if river['success'] else " 失败"
            print(f" 验证结果: {status}")
            print(f" 样本数量: {river['total_samples']}")
            print(f" 匹配率: {river['match_rate']:.1%}")
            print(f" 平均差异: {river['mean_difference']:.6f}")
        else:
            print(f" 错误: {river['error']}")

        # Turn and flop results share the same shape.
        for key, label in (('turn', 'TURN'), ('flop', 'FLOP')):
            stage = results[key]
            print(f"\n {label}阶段:")
            if 'error' not in stage:
                status = " 通过" if stage['success'] else " 失败"
                print(f" 验证结果: {status}")
                print(f" 样本数量: {stage['total_samples']}")
                print(f" 低EMD率: {stage['low_emd_rate']:.1%}")
                print(f" 平均EMD: {stage['mean_emd_distance']:.6f}")
            else:
                print(f" 错误: {stage['error']}")

        # Count passed stages once (previously counted twice, with a reset in between).
        total_stages = len(results)
        passed_stages = sum(1 for stage in results.values() if stage.get('success', False))
        overall_rate = passed_stages / total_stages

        return {
            'results': results,
            'passed_stages': passed_stages,
            'total_stages': total_stages,
            'overall_success': overall_rate >= 0.7,
        }