task5
This commit is contained in:
52
cross_validation/debug_emd.py
Normal file
52
cross_validation/debug_emd.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
调试
|
||||
turn/flop阶段EMD在导出的数据与生成的数据间的差异
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from cross_validation import DataValidator
|
||||
|
||||
validator = DataValidator()
|
||||
|
||||
print("解析Turn样本...")
|
||||
turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1)
|
||||
|
||||
if turn_records:
|
||||
record = turn_records[0]
|
||||
player_cards = record.player_cards
|
||||
board_cards = record.board_cards
|
||||
src_hist = record.bins
|
||||
|
||||
cur_hist = validator.generator.generate_turn_histogram(
|
||||
player_cards, board_cards, num_bins=len(src_hist)
|
||||
)
|
||||
|
||||
|
||||
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
||||
|
||||
print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}")
|
||||
print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}")
|
||||
print(f"原始直方图 - 和: {src_hist.sum():.3f}, 归一化后: {src_hist_norm.sum():.3f}")
|
||||
print(f"生成直方图 - 和: {sum(cur_hist):.3f}")
|
||||
|
||||
print("\n前10个bin对比:")
|
||||
print("Bin 原始值 归一化 生成值")
|
||||
for i in range(min(10, len(src_hist))):
|
||||
print(f"{i:3d} {src_hist[i]:8.3f} {src_hist_norm[i]:8.3f} {cur_hist[i]:8.3f}")
|
||||
|
||||
# 查看非零bin的分布
|
||||
src_nonzero = np.nonzero(src_hist_norm)[0]
|
||||
cur_nonzero = np.nonzero(cur_hist)[0]
|
||||
print(f"\n非零bins位置:")
|
||||
print(f"原始: {src_nonzero[:10]}...")
|
||||
print(f"生成: {cur_nonzero[:10]}...")
|
||||
|
||||
# 计算分布的统计特征
|
||||
src_mean = np.average(range(len(src_hist_norm)), weights=src_hist_norm)
|
||||
cur_mean = np.average(range(len(cur_hist)), weights=cur_hist)
|
||||
|
||||
print(f"\n分布特征:")
|
||||
print(f"原始分布重心: {src_mean:.2f}")
|
||||
print(f"生成分布重心: {cur_mean:.2f}")
|
||||
print(f"重心差异: {abs(src_mean - cur_mean):.2f}")
|
||||
Reference in New Issue
Block a user