392 lines
16 KiB
Python
392 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
import numpy as np
|
|
from typing import List, Dict, Tuple
|
|
from scipy.stats import wasserstein_distance
|
|
import sys
|
|
|
|
from .parse_data import XTaskDataParser
|
|
from shortdeck.gen_hist import ShortDeckHistGenerator
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class DataValidator:
    """Validate exported EHS / histogram data against freshly generated values.

    Exported records are read with ``XTaskDataParser`` and recomputed with
    ``ShortDeckHistGenerator``.  Each ``validate_*_samples`` method compares the
    two for one street (river EHS scalars, turn / flop histograms) and returns a
    summary dict; ``run_full_validation`` runs all three and prints a report.
    """

    def __init__(self, data_path: str = "ehs_data"):
        """Create the data parser and the short-deck generator.

        Args:
            data_path: Passed straight to ``XTaskDataParser``.
        """
        self.parser = XTaskDataParser(data_path)
        self.generator = ShortDeckHistGenerator()
        print(" DataValidator初始化完成")
        print(f" 生成器短牌型大小: {len(self.generator.full_deck)}")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cards_str(cards) -> str:
        """Return the ``str()`` of every card joined by single spaces."""
        return " ".join(str(c) for c in cards)

    @staticmethod
    def _normalize(hist: np.ndarray) -> np.ndarray:
        """Return *hist* divided by its sum, or *hist* itself if the sum is not positive."""
        total = hist.sum()
        return hist / total if total > 0 else hist

    @staticmethod
    def _emd_distance(src_norm: np.ndarray, cur_norm: np.ndarray) -> float:
        """Wasserstein-1 distance between two histograms placed at bin indices 0..n-1.

        The histogram values are used as the weights of each bin position.
        """
        return wasserstein_distance(
            range(len(src_norm)), range(len(cur_norm)), src_norm, cur_norm
        )

    @classmethod
    def _compare_histograms(cls, record, generate) -> Tuple[np.ndarray, np.ndarray, float]:
        """Regenerate *record*'s histogram and measure it against the exported one.

        Args:
            record: Parsed record exposing ``player_cards``, ``board_cards`` and ``bins``.
            generate: Generator method, called as
                ``generate(player_cards, board_cards, num_bins=len(bins))``.

        Returns:
            ``(src_hist, cur_hist, emd)``: the raw exported and regenerated
            histograms plus the EMD between their normalized forms.
        """
        src_hist = np.array(record.bins)
        cur_hist = np.array(
            generate(record.player_cards, record.board_cards, num_bins=len(src_hist))
        )
        emd = cls._emd_distance(cls._normalize(src_hist), cls._normalize(cur_hist))
        return src_hist, cur_hist, emd

    @staticmethod
    def _plot_histograms(src_hist, cur_hist, title: str) -> None:
        """Plot exported vs. regenerated histogram on a fresh figure and show it."""
        fig = plt.figure()
        plt.plot(src_hist, label='src', marker='o')
        plt.plot(cur_hist, label='cur', marker='x')
        plt.title(title)
        plt.xlabel("Bins")
        plt.ylabel("Frequency")
        plt.legend()
        plt.show()
        # Release the figure: with a non-interactive backend show() does not
        # block or close anything, so figures would otherwise pile up and the
        # next sample's lines would be drawn on top of this one.
        plt.close(fig)

    @staticmethod
    def _summarize_emd(total_samples: int, low_emd_count: int, high_emd_count: int,
                       emd_distances: List[float], errors: int,
                       min_low_rate: float, max_mean_emd: float) -> Dict:
        """Build the summary dict shared by the turn and flop validations.

        ``success`` requires the low-EMD rate to exceed *min_low_rate* and the
        mean EMD to be below *max_mean_emd*; with no EMDs the mean is ``inf``.
        """
        low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
        mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
        return {
            'total_samples': total_samples,
            'low_emd_count': low_emd_count,
            'low_emd_rate': low_emd_rate,
            'mean_emd_distance': mean_emd,
            'emd_distances': emd_distances,
            'high_emd_count': high_emd_count,
            # Number of samples whose recomputation raised (same meaning as river).
            'errors': errors,
            'success': bool(low_emd_rate > min_low_rate and mean_emd < max_mean_emd),
        }

    @staticmethod
    def _print_emd_summary(label: str, result: Dict) -> None:
        """Print the end-of-stage EMD summary for the turn / flop validations."""
        print(f" {label}验证完成:")
        print(f" 低EMD数: {result['low_emd_count']}/{result['total_samples']} ({result['low_emd_rate']:.1%})")
        print(f" 平均EMD: {result['mean_emd_distance']:.6f}")

    # ------------------------------------------------------------------
    # Per-street validation
    # ------------------------------------------------------------------

    def validate_river_samples(self, max_samples: int = 20, *, tolerance: float = 1e-6) -> Dict:
        """Recompute river EHS for exported records and compare with the stored value.

        Args:
            max_samples: Maximum number of records requested from the parser.
            tolerance: Absolute EHS difference below which a sample is a match.

        Returns:
            Summary dict with ``total_samples``, ``matches``, ``match_rate``,
            ``mean_difference`` / ``max_difference`` (computed over the
            non-matching samples only), ``errors`` (samples whose recomputation
            raised) and ``success``; or ``{'error': ..., 'success': False}`` if
            parsing or validation itself failed.
        """
        try:
            print(" 解析导出的river数据...")
            print('='*60)
            river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
            if not river_records:
                return {'error': '没有解析到river记录', 'success': False}

            print(f" 选择 {len(river_records)} 个样本进行验证")

            matches = 0
            errors = 0
            differences = []
            for i, record in enumerate(river_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_river_ehs = record.ehs
                    cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards)
                    ehs_difference = abs(src_river_ehs - cur_river_ehs)

                    player_str = self._cards_str(player_cards)
                    board_str = self._cards_str(board_cards)
                    print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
                    print(f" 原始EHS: {src_river_ehs:.6f}")
                    print(f" 重算EHS: {cur_river_ehs:.6f}")
                    print(f" 差异: {ehs_difference:.6f}")

                    # Allow a small numerical difference when deciding a match.
                    if ehs_difference < tolerance:
                        matches += 1
                    else:
                        differences.append(ehs_difference)
                except Exception as e:
                    errors += 1
                    # Only report the first few failures to keep output readable.
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            total_samples = len(river_records)
            match_rate = matches / total_samples if total_samples > 0 else 0
            mean_diff = np.mean(differences) if differences else 0
            max_diff = np.max(differences) if differences else 0

            result = {
                'total_samples': total_samples,
                'matches': matches,
                'match_rate': match_rate,
                'mean_difference': mean_diff,
                'max_difference': max_diff,
                'errors': errors,
                'success': bool(match_rate > 0.8 and mean_diff < 0.05),
            }

            print(f" River验证完成:")
            print('='*60)
            print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
            print(f" 平均差异: {mean_diff:.6f}")
            print(f" 最大差异: {max_diff:.6f}")
            return result

        except Exception as e:
            print(f" River验证失败: {e}")
            return {'error': str(e), 'success': False}

    def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards,
                            max_bins: int = 30):
        """Print one sample's cards and a per-bin table of raw / normalized histograms.

        Args:
            i: Zero-based sample index (printed as ``i + 1``).
            src_hist: Exported histogram.
            cur_hist: Regenerated histogram.
            player_cards: Hole cards, printed via ``str()``.
            board_cards: Board cards, printed via ``str()``.
            max_bins: At most this many leading bins are printed; bins that are
                zero in both histograms are skipped.
        """
        print("="*60)
        print(f"样本 {i+1}: [{self._cards_str(player_cards)}] + [{self._cards_str(board_cards)}]")
        print("bin src src_norm cur cur_norm")
        src_hist = np.array(src_hist)
        cur_hist = np.array(cur_hist)
        src_norm = self._normalize(src_hist)
        cur_norm = self._normalize(cur_hist)
        # Distinct loop name so the sample index ``i`` is not shadowed; bound by
        # both lengths so mismatched histograms cannot raise IndexError.
        for b in range(min(len(src_hist), len(cur_hist), max_bins)):
            if src_hist[b] > 0 or cur_hist[b] > 0:
                print(f"bin[{b}], {src_hist[b]:8.3f}, {src_norm[b]:8.3f}, {cur_hist[b]:8.3f}, {cur_norm[b]:8.3f}")

    def validate_turn_samples(self, max_samples: int = 10, *, emd_threshold: float = 0.2,
                              show_plots: bool = True) -> Dict:
        """Regenerate turn histograms and compare them with the exported ones by EMD.

        Args:
            max_samples: Maximum number of records requested from the parser.
            emd_threshold: Samples with EMD below this count as low-EMD,
                the rest as high-EMD.
            show_plots: Whether to plot every sample (``plt.show()`` per sample).

        Returns:
            Summary dict (see ``_summarize_emd``), or
            ``{'error': ..., 'success': False}`` if parsing/validation failed.
        """
        try:
            print(" 解析导出的Turn数据...")
            print('='*60)
            print('='*60)
            turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
            if not turn_records:
                return {'error': '没有解析到Turn记录', 'success': False}

            print(f" 解析到 {len(turn_records)} 个Turn样本")

            low_emd_count = 0
            high_emd_count = 0
            errors = 0
            emd_distances = []
            for i, record in enumerate(turn_records):
                try:
                    src_hist, cur_hist, emd_dist = self._compare_histograms(
                        record, self.generator.generate_turn_histogram
                    )
                    if emd_dist < emd_threshold:
                        low_emd_count += 1
                    else:
                        high_emd_count += 1
                    emd_distances.append(emd_dist)

                    self.print_sample_record(i, src_hist, cur_hist,
                                             record.player_cards, record.board_cards)
                    if show_plots:
                        self._plot_histograms(src_hist, cur_hist, f"turn_hist_emd={emd_dist:.6f}")
                except Exception as e:
                    errors += 1
                    # Only report the first few failures to keep output readable.
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            result = self._summarize_emd(len(turn_records), low_emd_count, high_emd_count,
                                         emd_distances, errors,
                                         min_low_rate=0.6, max_mean_emd=0.5)
            self._print_emd_summary("Turn", result)
            return result

        except Exception as e:
            print(f" Turn验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_flop_samples(self, max_samples: int = 5, *, emd_threshold: float = 10,
                              show_plots: bool = True) -> Dict:
        """Regenerate flop histograms and compare them with the exported ones by EMD.

        Args:
            max_samples: Maximum number of records requested from the parser.
            emd_threshold: Samples with EMD below this count as low-EMD.
                NOTE(review): the default of 10 is far looser than turn's 0.2
                while ``success`` still needs a mean EMD < 0.5 -- confirm intent.
            show_plots: Whether to plot every sample (``plt.show()`` per sample).

        Returns:
            Summary dict (see ``_summarize_emd``), or
            ``{'error': ..., 'success': False}`` if parsing/validation failed.
        """
        try:
            print(" 解析导出Flop数据...")
            print('='*60)
            print('='*60)
            print('='*60)
            flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)
            if not flop_records:
                return {'error': '没有解析到Flop记录', 'success': False}

            print(f" 解析到 {len(flop_records)} 个Flop样本")

            low_emd_count = 0
            high_emd_count = 0
            errors = 0
            emd_distances = []
            for i, record in enumerate(flop_records):
                try:
                    print(f" 处理样本 {i+1}/{len(flop_records)}...")
                    src_hist, cur_hist, emd_dist = self._compare_histograms(
                        record, self.generator.generate_flop_histogram
                    )
                    emd_distances.append(emd_dist)
                    if emd_dist < emd_threshold:
                        low_emd_count += 1
                    else:
                        high_emd_count += 1

                    player_str = self._cards_str(record.player_cards)
                    board_str = self._cards_str(record.board_cards)
                    print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
                    print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
                    print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
                    print(f" 归一化后EMD距离: {emd_dist:.6f}")

                    # Per-bin table, printed once and with the correct sample index.
                    self.print_sample_record(i, src_hist, cur_hist,
                                             record.player_cards, record.board_cards)
                    if show_plots:
                        self._plot_histograms(src_hist, cur_hist, f"flop_hist_emd={emd_dist:.6f}")
                except Exception as e:
                    errors += 1
                    # Only report the first few failures to keep output readable.
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            result = self._summarize_emd(len(flop_records), low_emd_count, high_emd_count,
                                         emd_distances, errors,
                                         min_low_rate=0.4, max_mean_emd=0.5)
            self._print_emd_summary("Flop", result)
            return result

        except Exception as e:
            print(f" Flop验证失败: {e}")
            return {'error': str(e), 'success': False}

    # ------------------------------------------------------------------
    # Full run
    # ------------------------------------------------------------------

    @staticmethod
    def _print_stage_result(label: str, result: Dict) -> None:
        """Print the report section for one stage's summary dict."""
        print(f"\n {label}阶段:")
        if 'error' in result:
            print(f" 错误: {result['error']}")
            return
        status = " 通过" if result['success'] else " 失败"
        print(f" 验证结果: {status}")
        print(f" 样本数量: {result['total_samples']}")
        if 'match_rate' in result:
            # River-style summary (scalar EHS comparison).
            print(f" 匹配率: {result['match_rate']:.1%}")
            print(f" 平均差异: {result['mean_difference']:.6f}")
        else:
            # Histogram-style summary (turn / flop EMD comparison).
            print(f" 低EMD率: {result['low_emd_rate']:.1%}")
            print(f" 平均EMD: {result['mean_emd_distance']:.6f}")
            print(f" 抽样EMD: {[float(emd) for emd in result['emd_distances'][:5]]}")

    def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10,
                            flop_samples: int = 5, show_plots: bool = True) -> Dict:
        """Run the river, turn and flop validations and print a combined report.

        Args:
            river_samples: ``max_samples`` for the river validation.
            turn_samples: ``max_samples`` for the turn validation.
            flop_samples: ``max_samples`` for the flop validation.
            show_plots: Forwarded to the turn and flop validations.

        Returns:
            Dict with the per-stage ``results``, ``passed_stages``,
            ``total_stages`` and ``overall_success`` (pass rate >= 0.7).
        """
        print(" 导出数据EHS验证")
        print("*"*60)

        results = {
            'river': self.validate_river_samples(river_samples),
            'turn': self.validate_turn_samples(turn_samples, show_plots=show_plots),
            'flop': self.validate_flop_samples(flop_samples, show_plots=show_plots),
        }

        print(f"\n{'='*60}")
        print(" 验证完毕")
        print(f"{'='*60}")

        for label, key in (("RIVER", 'river'), ("TURN", 'turn'), ("FLOP", 'flop')):
            self._print_stage_result(label, results[key])

        total_stages = len(results)
        passed_stages = sum(1 for stage in results.values() if stage.get('success', False))
        overall_rate = passed_stages / total_stages

        return {
            'results': results,
            'passed_stages': passed_stages,
            'total_stages': total_stages,
            'overall_success': overall_rate >= 0.7,
        }