Files
poker_task1/cross_validation/cross_validation.py
2025-09-26 17:08:27 +08:00

358 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
import numpy as np
from typing import List, Dict, Tuple
from scipy.stats import wasserstein_distance
import sys
from .parse_data import XTaskDataParser
from shortdeck.gen_hist import ShortDeckHistGenerator
class DataValidator:
    """Cross-check exported xtask data against the short-deck generator.

    Records are read through ``XTaskDataParser`` and re-computed with
    ``ShortDeckHistGenerator``. Every ``validate_*`` method prints a progress
    report and returns a summary dict that always carries a ``'success'`` key;
    when a stage could not run at all the dict instead holds ``'error'`` and
    ``'success': False``.
    """

    # Largest |exported EHS - recomputed EHS| that still counts as a match.
    RIVER_TOLERANCE = 1e-6
    # An EMD below this value (distances are measured over bin indices)
    # counts as a "low EMD" sample.
    LOW_EMD_THRESHOLD = 0.2
    # Upper bound on the mean EMD for a histogram stage to pass.
    MAX_MEAN_EMD = 0.5
    # Only the first few per-sample failures are printed, to keep logs short.
    MAX_REPORTED_ERRORS = 3

    def __init__(self, data_path: str = "ehs_data"):
        """Create the parser for ``data_path`` and a fresh generator."""
        self.parser = XTaskDataParser(data_path)
        self.generator = ShortDeckHistGenerator()
        print(" DataValidator初始化完成")
        print(f" 生成器短牌型大小: {len(self.generator.full_deck)}")

    @staticmethod
    def _cards_to_str(cards) -> str:
        """Join the ``str()`` form of each card with single spaces."""
        return " ".join(str(card) for card in cards)

    @staticmethod
    def _histogram_emd(src_hist, cur_hist) -> float:
        """Return the 1-D Wasserstein distance between two histograms.

        Each histogram is normalised to sum to 1 when its sum is positive and
        is then used as the weights over its own bin indices. An all-zero
        histogram is passed through unchanged, in which case
        ``wasserstein_distance`` is expected to reject the weights; callers
        count that as a per-sample error.
        """
        src_total = src_hist.sum()
        cur_total = cur_hist.sum()
        src_norm = src_hist / src_total if src_total > 0 else src_hist
        cur_norm = cur_hist / cur_total if cur_total > 0 else cur_hist
        return float(
            wasserstein_distance(
                np.arange(len(src_norm)),
                np.arange(len(cur_norm)),
                src_norm,
                cur_norm,
            )
        )

    def _print_histogram_sample(self, index, player_cards, board_cards,
                                src_hist, cur_hist, emd_dist) -> None:
        """Print the cards, both histogram summaries and the EMD of a sample."""
        player_str = self._cards_to_str(player_cards)
        board_str = self._cards_to_str(board_cards)
        print(f" 样本 {index+1}: [{player_str}] + [{board_str}]")
        print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
        print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
        print(f" 归一化后EMD距离: {emd_dist:.6f}")

    def _report_sample_error(self, index: int, errors: int, exc: Exception) -> None:
        """Print a per-sample failure unless too many were already shown."""
        if errors <= self.MAX_REPORTED_ERRORS:
            print(f" 样本 {index+1} 计算失败: {exc}")

    def _summarize_emd(self, total_samples: int, low_emd_count: int,
                       emd_distances: List[float], errors: int,
                       min_low_emd_rate: float) -> Dict:
        """Build the result dict shared by the turn and flop stages.

        The stage succeeds when the share of low-EMD samples exceeds
        ``min_low_emd_rate`` and the mean EMD is below ``MAX_MEAN_EMD``. With
        no successfully computed distance the mean is ``inf`` and the stage
        therefore fails.
        """
        low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
        mean_emd = float(np.mean(emd_distances)) if emd_distances else float('inf')
        return {
            'total_samples': total_samples,
            'low_emd_count': low_emd_count,
            'low_emd_rate': low_emd_rate,
            'mean_emd_distance': mean_emd,
            'emd_distances': emd_distances,
            'errors': errors,
            'success': bool(low_emd_rate > min_low_emd_rate and mean_emd < self.MAX_MEAN_EMD),
        }

    def validate_river_samples(self, max_samples: int = 20) -> Dict:
        """Re-compute river EHS for a random sample of exported records.

        Args:
            max_samples: Upper bound on the number of records drawn (without
                replacement). When fewer records exist, all of them are used.

        Returns:
            A dict with ``total_samples``, ``matches``, ``match_rate``,
            ``mean_difference`` / ``max_difference`` (taken over the
            non-matching samples only), ``errors`` and ``success``; or
            ``{'error': ..., 'success': False}`` when the stage cannot run.
        """
        try:
            print(" 解析导出的river数据...")
            print('='*60)
            river_records = self.parser.parse_river_ehs_with_cards()
            if not river_records:
                return {'error': '没有解析到river记录', 'success': False}

            # Draw indices instead of records: the sample size is clamped so
            # a small export no longer raises, and the record objects are
            # never coerced into a numpy array.
            sample_size = max(0, min(max_samples, len(river_records)))
            chosen = np.random.choice(len(river_records), size=sample_size, replace=False)
            sample_records = [river_records[int(idx)] for idx in chosen]
            print(f" 选择 {len(sample_records)} 个样本进行验证")

            matches = 0
            errors = 0
            differences = []
            for i, record in enumerate(sample_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_river_ehs = record.ehs
                    cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards)
                    ehs_difference = abs(src_river_ehs - cur_river_ehs)

                    player_str = self._cards_to_str(player_cards)
                    board_str = self._cards_to_str(board_cards)
                    print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
                    print(f" 原始EHS: {src_river_ehs:.6f}")
                    print(f" 重算EHS: {cur_river_ehs:.6f}")
                    print(f" 差异: {ehs_difference:.6f}")

                    # Allow for tiny numerical differences.
                    if ehs_difference < self.RIVER_TOLERANCE:
                        matches += 1
                    else:
                        differences.append(ehs_difference)
                except Exception as e:
                    # Best effort: one bad record must not abort the stage.
                    errors += 1
                    self._report_sample_error(i, errors, e)

            # Failed samples stay in the denominator, so they lower the rate.
            total_samples = len(sample_records)
            match_rate = matches / total_samples if total_samples > 0 else 0
            mean_diff = float(np.mean(differences)) if differences else 0
            max_diff = float(np.max(differences)) if differences else 0
            result = {
                'total_samples': total_samples,
                'matches': matches,
                'match_rate': match_rate,
                'mean_difference': mean_diff,
                'max_difference': max_diff,
                'errors': errors,
                'success': bool(match_rate > 0.8 and mean_diff < 0.05),
            }
            print(" River验证完成:")
            print('='*60)
            print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
            print(f" 平均差异: {mean_diff:.6f}")
            print(f" 最大差异: {max_diff:.6f}")
            return result
        except Exception as e:
            print(f" River验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_turn_samples(self, max_samples: int = 10) -> Dict:
        """Compare exported turn histograms with regenerated ones via EMD.

        Args:
            max_samples: Forwarded to the parser as ``max_records``.

        Returns:
            A dict with ``total_samples``, ``low_emd_count``,
            ``low_emd_rate``, ``mean_emd_distance``, ``emd_distances``,
            ``errors`` and ``success`` (low-EMD rate above 0.6 and mean EMD
            below 0.5); or ``{'error': ..., 'success': False}`` when the
            stage cannot run.
        """
        try:
            print(" 解析导出的Turn数据...")
            print('='*60)
            print('='*60)
            turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
            if not turn_records:
                return {'error': '没有解析到Turn记录', 'success': False}
            print(f" 解析到 {len(turn_records)} 个Turn样本")

            low_emd_count = 0
            emd_distances = []
            errors = 0
            for i, record in enumerate(turn_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_hist = np.array(record.bins)
                    cur_hist = np.array(
                        self.generator.generate_turn_histogram(
                            player_cards, board_cards, num_bins=len(src_hist)
                        )
                    )
                    emd_dist = self._histogram_emd(src_hist, cur_hist)
                    # Record every distance; without this the mean stays at
                    # inf and the stage can never pass.
                    emd_distances.append(emd_dist)
                    if emd_dist < self.LOW_EMD_THRESHOLD:
                        low_emd_count += 1
                    # Detailed output only for the first three samples.
                    if i < 3:
                        self._print_histogram_sample(
                            i, player_cards, board_cards, src_hist, cur_hist, emd_dist
                        )
                except Exception as e:
                    # Best effort: one bad record must not abort the stage.
                    errors += 1
                    self._report_sample_error(i, errors, e)

            result = self._summarize_emd(
                len(turn_records), low_emd_count, emd_distances, errors,
                min_low_emd_rate=0.6,
            )
            print(" Turn验证完成:")
            print(f" 低EMD数: {low_emd_count}/{result['total_samples']} ({result['low_emd_rate']:.1%})")
            print(f" 平均EMD: {result['mean_emd_distance']:.6f}")
            return result
        except Exception as e:
            print(f" Turn验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_flop_samples(self, max_samples: int = 5) -> Dict:
        """Compare exported flop histograms with regenerated ones via EMD.

        Args:
            max_samples: Forwarded to the parser as ``max_records``.

        Returns:
            The same keys as ``validate_turn_samples``; ``success`` requires
            a low-EMD rate above 0.4 and a mean EMD below 0.5. On a
            stage-level failure ``{'error': ..., 'success': False}``.
        """
        try:
            print(" 解析导出Flop数据...")
            print('='*60)
            print('='*60)
            print('='*60)
            flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)
            if not flop_records:
                return {'error': '没有解析到Flop记录', 'success': False}
            print(f" 解析到 {len(flop_records)} 个Flop样本")

            low_emd_count = 0
            emd_distances = []
            errors = 0
            for i, record in enumerate(flop_records):
                try:
                    print(f" 处理样本 {i+1}/{len(flop_records)}...")
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_hist = np.array(record.bins)
                    cur_hist = np.array(
                        self.generator.generate_flop_histogram(
                            player_cards, board_cards, num_bins=len(src_hist)
                        )
                    )
                    emd_dist = self._histogram_emd(src_hist, cur_hist)
                    # Keep the statistics in step with the turn stage so the
                    # summary reflects the samples that were actually compared.
                    emd_distances.append(emd_dist)
                    if emd_dist < self.LOW_EMD_THRESHOLD:
                        low_emd_count += 1
                    # Every flop sample gets the detailed output.
                    self._print_histogram_sample(
                        i, player_cards, board_cards, src_hist, cur_hist, emd_dist
                    )
                except Exception as e:
                    # Best effort: one bad record must not abort the stage.
                    errors += 1
                    self._report_sample_error(i, errors, e)

            result = self._summarize_emd(
                len(flop_records), low_emd_count, emd_distances, errors,
                min_low_emd_rate=0.4,
            )
            print(" Flop验证完成:")
            print(f" 低EMD数: {low_emd_count}/{result['total_samples']} ({result['low_emd_rate']:.1%})")
            print(f" 平均EMD: {result['mean_emd_distance']:.6f}")
            return result
        except Exception as e:
            print(f" Flop验证失败: {e}")
            return {'error': str(e), 'success': False}

    @staticmethod
    def _print_stage_report(title: str, stage: Dict, rate_label: str, rate_key: str,
                            mean_label: str, mean_key: str) -> None:
        """Print the final report section for one stage result dict."""
        print(f"\n {title}阶段:")
        if 'error' in stage:
            print(f" 错误: {stage['error']}")
            return
        status = " 通过" if stage['success'] else " 失败"
        print(f" 验证结果: {status}")
        print(f" 样本数量: {stage['total_samples']}")
        print(f" {rate_label}: {stage[rate_key]:.1%}")
        print(f" {mean_label}: {stage[mean_key]:.6f}")

    def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10, flop_samples: int = 5) -> Dict:
        """Run the river, turn and flop stages and print a combined report.

        Args:
            river_samples: ``max_samples`` for ``validate_river_samples``.
            turn_samples: ``max_samples`` for ``validate_turn_samples``.
            flop_samples: ``max_samples`` for ``validate_flop_samples``.

        Returns:
            A dict with ``results`` (per-stage result dicts keyed by
            ``'river'`` / ``'turn'`` / ``'flop'``), ``passed_stages``,
            ``total_stages`` and ``overall_success``.
        """
        print(" 导出数据EHS验证")
        print("*"*60)
        print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")

        # Run each stage in turn.
        results = {
            'river': self.validate_river_samples(river_samples),
            'turn': self.validate_turn_samples(turn_samples),
            'flop': self.validate_flop_samples(flop_samples),
        }

        print(f"\n{'='*60}")
        print(" 验证完毕")
        print(f"{'='*60}")
        self._print_stage_report('RIVER', results['river'], '匹配率', 'match_rate',
                                 '平均差异', 'mean_difference')
        self._print_stage_report('TURN', results['turn'], '低EMD率', 'low_emd_rate',
                                 '平均EMD', 'mean_emd_distance')
        self._print_stage_report('FLOP', results['flop'], '低EMD率', 'low_emd_rate',
                                 '平均EMD', 'mean_emd_distance')

        # Overall outcome: a stage counts as passed only when it reports
        # success; error results carry 'success': False.
        total_stages = len(results)
        passed_stages = sum(1 for stage in results.values() if stage.get('success', False))
        overall_rate = passed_stages / total_stages
        return {
            'results': results,
            'passed_stages': passed_stages,
            'total_stages': total_stages,
            # With three stages, 2/3 is below 0.7, so all three must pass.
            'overall_success': overall_rate >= 0.7,
        }