plt_hist
This commit is contained in:
@@ -6,6 +6,7 @@ import sys
|
||||
|
||||
from .parse_data import XTaskDataParser
|
||||
from shortdeck.gen_hist import ShortDeckHistGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class DataValidator:
|
||||
@@ -23,13 +24,12 @@ class DataValidator:
|
||||
try:
|
||||
print(" 解析导出的river数据...")
|
||||
print('='*60)
|
||||
river_records = self.parser.parse_river_ehs_with_cards()
|
||||
river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
|
||||
|
||||
if not river_records:
|
||||
return {'error': '没有解析到river记录', 'success': False}
|
||||
|
||||
|
||||
sample_records = np.random.choice(river_records, size=max_samples, replace=False)
|
||||
sample_records = river_records
|
||||
print(f" 选择 {len(sample_records)} 个样本进行验证")
|
||||
|
||||
matches = 0
|
||||
@@ -95,6 +95,21 @@ class DataValidator:
|
||||
except Exception as e:
|
||||
print(f" River验证失败: {e}")
|
||||
return {'error': str(e), 'success': False}
|
||||
|
||||
def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards):
    """Print one sample's cards and a bin-by-bin comparison of two histograms.

    Args:
        i: Zero-based sample index; it is displayed as ``i + 1``.
        src_hist: Source histogram values (any sequence accepted by
            ``np.array``; presumably 1-D -- confirm against callers).
        cur_hist: Regenerated histogram values to compare against
            ``src_hist``.
        player_cards: Iterable of player cards; each is rendered with ``str``.
        board_cards: Iterable of board cards; each is rendered with ``str``.

    Returns:
        None. All output goes to stdout.
    """
    player_str = " ".join(str(c) for c in player_cards)
    board_str = " ".join(str(c) for c in board_cards)
    print("="*60)
    print(f"样本 {i+1}: [{player_str}] + [{board_str}]")
    print("bin src src_norm cur cur_norm")

    src_hist = np.array(src_hist)
    cur_hist = np.array(cur_hist)

    # Normalise each histogram to sum to 1; a histogram whose total is not
    # positive is left untouched to avoid dividing by zero.
    src_total = src_hist.sum()
    cur_total = cur_hist.sum()
    src_hist_norm = src_hist / src_total if src_total > 0 else src_hist
    cur_hist_norm = cur_hist / cur_total if cur_total > 0 else cur_hist

    # Show at most the first 30 bins, and never index past the shorter
    # histogram (the two may differ in length). A dedicated loop variable
    # keeps the sample index ``i`` intact.
    bin_count = min(len(src_hist), len(cur_hist), 30)
    for bin_idx in range(bin_count):
        # Skip bins that are empty in both histograms to keep output short.
        if src_hist[bin_idx] > 0 or cur_hist[bin_idx] > 0:
            print(f"bin[{bin_idx}], {src_hist[bin_idx]:8.3f}, {src_hist_norm[bin_idx]:8.3f}, {cur_hist[bin_idx]:8.3f}, {cur_hist_norm[bin_idx]:8.3f}")
|
||||
|
||||
def validate_turn_samples(self, max_samples: int = 10):
|
||||
|
||||
@@ -140,20 +155,24 @@ class DataValidator:
|
||||
src_hist_norm,
|
||||
cur_hist_norm
|
||||
)
|
||||
|
||||
is_low_emd = emd_dist < 0.2
|
||||
if is_low_emd:
|
||||
low_emd_count += 1
|
||||
|
||||
# 显示详细信息(前3个样本)
|
||||
if i < 3:
|
||||
player_str = " ".join(str(c) for c in player_cards)
|
||||
board_str = " ".join(str(c) for c in board_cards)
|
||||
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
|
||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
||||
|
||||
|
||||
low_emd_count += 1 if emd_dist < 0.2 else 0
|
||||
errors += 1 if emd_dist >= 0.2 else 0
|
||||
emd_distances.append(emd_dist)
|
||||
|
||||
self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
|
||||
|
||||
|
||||
|
||||
# 画图显示
|
||||
plt.plot(src_hist, label='src', marker='o')
|
||||
plt.plot(cur_hist, label='cur', marker='x')
|
||||
plt.title(f"turn_hist_emd={emd_dist:.6f}")
|
||||
plt.xlabel("Bins")
|
||||
plt.ylabel("Frequency")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors <= 3:
|
||||
@@ -228,11 +247,10 @@ class DataValidator:
|
||||
src_hist_norm,
|
||||
cur_hist_norm
|
||||
)
|
||||
# emd_distances.append(emd_dist)
|
||||
|
||||
# is_low_emd = emd_dist < 0.2 # EMD阈值
|
||||
# if is_low_emd:
|
||||
# low_emd_count += 1
|
||||
emd_distances.append(emd_dist)
|
||||
|
||||
low_emd_count += 1 if emd_dist < 10 else 0
|
||||
errors += 1 if emd_dist >= 10 else 0
|
||||
|
||||
# 显示详细信息
|
||||
player_str = " ".join(str(c) for c in player_cards)
|
||||
@@ -241,7 +259,22 @@ class DataValidator:
|
||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
||||
|
||||
print("bin src src_norm cur cur_norm")
|
||||
for i in range(min(len(src_hist), 30)):
|
||||
if src_hist[i] > 0 or cur_hist[i] > 0:
|
||||
print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}")
|
||||
|
||||
self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
|
||||
|
||||
# 画图显示
|
||||
plt.plot(src_hist, label='src', marker='o')
|
||||
plt.plot(cur_hist, label='cur', marker='x')
|
||||
plt.title(f"flop_hist_emd={emd_dist:.6f}")
|
||||
plt.xlabel("Bins")
|
||||
plt.ylabel("Frequency")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors <= 3:
|
||||
@@ -276,7 +309,6 @@ class DataValidator:
|
||||
|
||||
print(" 导出数据EHS验证")
|
||||
print("*"*60)
|
||||
print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")
|
||||
|
||||
# 执行各阶段验证
|
||||
results = {}
|
||||
@@ -312,6 +344,7 @@ class DataValidator:
|
||||
print(f" 样本数量: {results['turn']['total_samples']}")
|
||||
print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}")
|
||||
print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}")
|
||||
print(f" 抽样EMD: {[emd for emd in results['turn']['emd_distances'][:5]]}")
|
||||
if results['turn']['success']:
|
||||
passed_stages += 1
|
||||
else:
|
||||
@@ -325,6 +358,7 @@ class DataValidator:
|
||||
print(f" 样本数量: {results['flop']['total_samples']}")
|
||||
print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}")
|
||||
print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}")
|
||||
print(f" 抽样EMD: {[emd for emd in results['flop']['emd_distances'][:5]]}")
|
||||
if results['flop']['success']:
|
||||
passed_stages += 1
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user