Files
poker_task1/cross_validation/debug_emd.py
2025-09-26 17:08:27 +08:00

52 lines
1.8 KiB
Python

#!/usr/bin/env python3
"""
Debug the difference in turn/flop-stage EMD between the exported data and the generated data.
"""
import numpy as np
from cross_validation import DataValidator
# Debug driver: parse one exported turn record, regenerate the histogram for
# the same cards, and print both side by side (sums, first bins, non-zero bin
# positions, and the weighted mean bin index of each distribution).
validator = DataValidator()

print("解析Turn样本...")
turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1)

if turn_records:
    record = turn_records[0]
    player_cards = record.player_cards
    board_cards = record.board_cards
    src_hist = record.bins

    # Ask the generator for the same number of bins as the exported record so
    # the two histograms can be compared bin by bin.
    cur_hist = validator.generator.generate_turn_histogram(
        player_cards, board_cards, num_bins=len(src_hist)
    )

    # Normalise the exported histogram; an all-zero histogram is left as-is
    # to avoid dividing by zero.
    src_total = src_hist.sum()
    src_hist_norm = src_hist / src_total if src_total > 0 else src_hist
    cur_total = sum(cur_hist)

    print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}")
    print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}")
    print(f"原始直方图 - 和: {src_total:.3f}, 归一化后: {src_hist_norm.sum():.3f}")
    print(f"生成直方图 - 和: {cur_total:.3f}")

    print("\n前10个bin对比:")
    print("Bin 原始值 归一化 生成值")
    # zip() stops at the shortest input, so a generated histogram with fewer
    # bins than the exported one no longer raises IndexError here.
    for i, (raw, norm, cur) in enumerate(
        zip(src_hist[:10], src_hist_norm[:10], cur_hist[:10])
    ):
        print(f"{i:3d} {raw:8.3f} {norm:8.3f} {cur:8.3f}")

    # Positions of the non-zero bins in each histogram.
    src_nonzero = np.nonzero(src_hist_norm)[0]
    cur_nonzero = np.nonzero(cur_hist)[0]
    print("\n非零bins位置:")
    print(f"原始: {src_nonzero[:10]}...")
    print(f"生成: {cur_nonzero[:10]}...")

    # Weighted mean bin index of each distribution.
    print("\n分布特征:")
    if src_hist_norm.sum() == 0 or cur_total == 0:
        # np.average raises ZeroDivisionError when the weights sum to zero,
        # so report the situation instead of crashing the debug run.
        print("存在总和为零的直方图,无法计算分布重心")
    else:
        src_mean = np.average(range(len(src_hist_norm)), weights=src_hist_norm)
        cur_mean = np.average(range(len(cur_hist)), weights=cur_hist)
        print(f"原始分布重心: {src_mean:.2f}")
        print(f"生成分布重心: {cur_mean:.2f}")
        print(f"重心差异: {abs(src_mean - cur_mean):.2f}")
else:
    # Make the empty case visible; otherwise the script ends with no output
    # beyond the "parsing" banner and no hint about what went wrong.
    print("未解析到任何Turn记录,无法进行对比")