This commit is contained in:
2025-09-28 16:44:00 +08:00
parent 0597239207
commit fc085eb77e
4 changed files with 69 additions and 180 deletions

View File

@@ -6,6 +6,7 @@ import sys
from .parse_data import XTaskDataParser from .parse_data import XTaskDataParser
from shortdeck.gen_hist import ShortDeckHistGenerator from shortdeck.gen_hist import ShortDeckHistGenerator
import matplotlib.pyplot as plt
class DataValidator: class DataValidator:
@@ -23,13 +24,12 @@ class DataValidator:
try: try:
print(" 解析导出的river数据...") print(" 解析导出的river数据...")
print('='*60) print('='*60)
river_records = self.parser.parse_river_ehs_with_cards() river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
if not river_records: if not river_records:
return {'error': '没有解析到river记录', 'success': False} return {'error': '没有解析到river记录', 'success': False}
sample_records = river_records
sample_records = np.random.choice(river_records, size=max_samples, replace=False)
print(f" 选择 {len(sample_records)} 个样本进行验证") print(f" 选择 {len(sample_records)} 个样本进行验证")
matches = 0 matches = 0
@@ -95,6 +95,21 @@ class DataValidator:
except Exception as e: except Exception as e:
print(f" River验证失败: {e}") print(f" River验证失败: {e}")
return {'error': str(e), 'success': False} return {'error': str(e), 'success': False}
def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards, max_bins=30):
    """Print a bin-by-bin comparison of two histograms for one sample.

    Args:
        i: Zero-based sample index; the header line shows it one-based.
        src_hist: Source histogram values (sequence or array of numbers).
        cur_hist: Histogram values to compare against ``src_hist``.
        player_cards: Player cards; each one is rendered with ``str()``.
        board_cards: Board cards; each one is rendered with ``str()``.
        max_bins: Maximum number of leading bins to inspect (default 30).

    Each histogram is normalized by its own sum; a histogram whose sum is
    not positive is shown un-normalized. Only bins where at least one of
    the two raw values is positive are printed. If the histograms differ in
    length, only the bins present in both are compared.
    """
    player_str = " ".join(str(c) for c in player_cards)
    board_str = " ".join(str(c) for c in board_cards)
    print("=" * 60)
    print(f"样本 {i+1}: [{player_str}] + [{board_str}]")
    print("bin src src_norm cur cur_norm")

    src = np.asarray(src_hist, dtype=float)
    cur = np.asarray(cur_hist, dtype=float)
    src_total = src.sum()
    cur_total = cur.sum()
    src_norm = src / src_total if src_total > 0 else src
    cur_norm = cur / cur_total if cur_total > 0 else cur

    limit = max(int(max_bins), 0)
    # zip() stops at the shorter histogram, so a bin-count mismatch between
    # the two inputs can no longer raise IndexError. A dedicated loop name
    # keeps the sample index `i` intact.
    rows = zip(src[:limit], src_norm[:limit], cur[:limit], cur_norm[:limit])
    for bin_idx, (src_val, src_frac, cur_val, cur_frac) in enumerate(rows):
        if src_val > 0 or cur_val > 0:
            print(f"bin[{bin_idx}], {src_val:8.3f}, {src_frac:8.3f}, {cur_val:8.3f}, {cur_frac:8.3f}")
def validate_turn_samples(self, max_samples: int = 10): def validate_turn_samples(self, max_samples: int = 10):
@@ -140,20 +155,24 @@ class DataValidator:
src_hist_norm, src_hist_norm,
cur_hist_norm cur_hist_norm
) )
is_low_emd = emd_dist < 0.2 low_emd_count += 1 if emd_dist < 0.2 else 0
if is_low_emd: errors += 1 if emd_dist >= 0.2 else 0
low_emd_count += 1 emd_distances.append(emd_dist)
# 显示详细信息前3个样本 self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
if i < 3:
player_str = " ".join(str(c) for c in player_cards)
board_str = " ".join(str(c) for c in board_cards)
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]") # 画图显示
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}") plt.plot(src_hist, label='src', marker='o')
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}") plt.plot(cur_hist, label='cur', marker='x')
print(f" 归一化后EMD距离: {emd_dist:.6f}") plt.title(f"turn_hist_emd={emd_dist:.6f}")
plt.xlabel("Bins")
plt.ylabel("Frequency")
plt.legend()
plt.show()
except Exception as e: except Exception as e:
errors += 1 errors += 1
if errors <= 3: if errors <= 3:
@@ -228,11 +247,10 @@ class DataValidator:
src_hist_norm, src_hist_norm,
cur_hist_norm cur_hist_norm
) )
# emd_distances.append(emd_dist) emd_distances.append(emd_dist)
# is_low_emd = emd_dist < 0.2 # EMD阈值 low_emd_count += 1 if emd_dist < 10 else 0
# if is_low_emd: errors += 1 if emd_dist >= 10 else 0
# low_emd_count += 1
# 显示详细信息 # 显示详细信息
player_str = " ".join(str(c) for c in player_cards) player_str = " ".join(str(c) for c in player_cards)
@@ -241,7 +259,22 @@ class DataValidator:
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}") print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}") print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
print(f" 归一化后EMD距离: {emd_dist:.6f}") print(f" 归一化后EMD距离: {emd_dist:.6f}")
print("bin src src_norm cur cur_norm")
for i in range(min(len(src_hist), 30)):
if src_hist[i] > 0 or cur_hist[i] > 0:
print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}")
self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
# 画图显示
plt.plot(src_hist, label='src', marker='o')
plt.plot(cur_hist, label='cur', marker='x')
plt.title(f"flop_hist_emd={emd_dist:.6f}")
plt.xlabel("Bins")
plt.ylabel("Frequency")
plt.legend()
plt.show()
except Exception as e: except Exception as e:
errors += 1 errors += 1
if errors <= 3: if errors <= 3:
@@ -276,7 +309,6 @@ class DataValidator:
print(" 导出数据EHS验证") print(" 导出数据EHS验证")
print("*"*60) print("*"*60)
print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")
# 执行各阶段验证 # 执行各阶段验证
results = {} results = {}
@@ -312,6 +344,7 @@ class DataValidator:
print(f" 样本数量: {results['turn']['total_samples']}") print(f" 样本数量: {results['turn']['total_samples']}")
print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}") print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}")
print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}") print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}")
print(f" 抽样EMD: {[emd for emd in results['turn']['emd_distances'][:5]]}")
if results['turn']['success']: if results['turn']['success']:
passed_stages += 1 passed_stages += 1
else: else:
@@ -325,6 +358,7 @@ class DataValidator:
print(f" 样本数量: {results['flop']['total_samples']}") print(f" 样本数量: {results['flop']['total_samples']}")
print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}") print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}")
print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}") print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}")
print(f" 抽样EMD: {[emd for emd in results['flop']['emd_distances'][:5]]}")
if results['flop']['success']: if results['flop']['success']:
passed_stages += 1 passed_stages += 1
else: else:

View File

@@ -1,52 +0,0 @@
#!/usr/bin/env python3
"""
Debug helper.

Investigates the turn/flop-stage EMD discrepancy between the exported data
and the regenerated data. Only a single turn sample is inspected here: its
stored histogram is printed next to the histogram regenerated for the same
cards, together with the non-zero bin positions and each distribution's
centroid.
"""
import numpy as np
from cross_validation import DataValidator


def _centroid(hist):
    """Return the weighted mean bin index of ``hist``, or NaN if it has no mass."""
    weights = np.asarray(hist, dtype=float)
    # np.average raises ZeroDivisionError when the weights sum to zero.
    if weights.sum() == 0:
        return float("nan")
    return float(np.average(np.arange(len(weights)), weights=weights))


validator = DataValidator()

print("解析Turn样本...")
turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1)

if not turn_records:
    # Say so explicitly instead of exiting without any output.
    print("没有解析到turn记录")
else:
    record = turn_records[0]
    player_cards = record.player_cards
    board_cards = record.board_cards
    src_hist = record.bins

    # Regenerate the histogram with the same bin count as the stored one.
    cur_hist = validator.generator.generate_turn_histogram(
        player_cards, board_cards, num_bins=len(src_hist)
    )

    src_total = src_hist.sum()
    src_hist_norm = src_hist / src_total if src_total > 0 else src_hist

    print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}")
    print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}")
    print(f"原始直方图 - 和: {src_hist.sum():.3f}, 归一化后: {src_hist_norm.sum():.3f}")
    print(f"生成直方图 - 和: {sum(cur_hist):.3f}")

    print("\n前10个bin对比:")
    print("Bin 原始值 归一化 生成值")
    # Bound by both lengths so a bin-count mismatch cannot raise IndexError.
    for i in range(min(10, len(src_hist), len(cur_hist))):
        print(f"{i:3d} {src_hist[i]:8.3f} {src_hist_norm[i]:8.3f} {cur_hist[i]:8.3f}")

    # Positions of the non-zero bins in each distribution.
    src_nonzero = np.nonzero(src_hist_norm)[0]
    cur_nonzero = np.nonzero(cur_hist)[0]
    print(f"\n非零bins位置:")
    print(f"原始: {src_nonzero[:10]}...")
    print(f"生成: {cur_nonzero[:10]}...")

    # Summary statistic: centroid (weighted mean bin index) of each distribution.
    src_mean = _centroid(src_hist_norm)
    cur_mean = _centroid(cur_hist)
    print(f"\n分布特征:")
    print(f"原始分布重心: {src_mean:.2f}")
    print(f"生成分布重心: {cur_mean:.2f}")
    print(f"重心差异: {abs(src_mean - cur_mean):.2f}")

View File

@@ -1,4 +1,5 @@
import numpy as np import numpy as np
import random
from typing import List, Dict, Tuple, Optional from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit from poker.card import Card, ShortDeckRank, Suit
@@ -164,17 +165,19 @@ class XTaskDataParser:
self.data_path = data_path self.data_path = data_path
self.decoder = OpenPQLDecoder() self.decoder = OpenPQLDecoder()
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]: def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
filepath = f"{self.data_path}/{filename}" filepath = f"{self.data_path}/{filename}"
try: try:
raw_data = np.load(filepath) raw_data = np.load(filepath)
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}") print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
# 抽样
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
records = [] records = []
decode_errors = 0 decode_errors = 0
for i, row in enumerate(raw_data): for i, row in enumerate(data_to_process):
try: try:
board_id = int(row['board']) board_id = int(row['board'])
player_id = int(row['player']) player_id = int(row['player'])
@@ -245,10 +248,8 @@ class XTaskDataParser:
records = [] records = []
decode_errors = 0 decode_errors = 0
# 抽样
# 限制处理数量以避免内存问题 data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
# todo:抽样优化
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
for i, row in enumerate(data_to_process): for i, row in enumerate(data_to_process):
try: try:
@@ -305,8 +306,7 @@ class XTaskDataParser:
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录") print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
records = [] records = []
decode_errors = 0 decode_errors = 0
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
for i, row in enumerate(data_to_process): for i, row in enumerate(data_to_process):
try: try:

View File

@@ -147,97 +147,4 @@ class ShortDeckHistGenerator:
print(f" Turn计算: generate nums{len(histogram)} ,num_bins={num_bins}") print(f" Turn计算: generate nums{len(histogram)} ,num_bins={num_bins}")
raise raise
return histogram return histogram
# # 从一副牌中抽样验证 或者 从解析的数据中抽样验证
# def generate_sample_data(self, num_samples: int = 3) -> Dict:
# print(f"\n 生成样本数据 (每阶段 {num_samples} 个样本)")
# results = {
# 'river': [],
# 'turn': [],
# 'flop': []
# }
# for i in range(num_samples):
# # River样本7张牌
# river_cards = np.random.choice(self.full_deck, size=7, replace=False).tolist()
# player_cards = river_cards[:2]
# board_cards = river_cards[2:7]
# river_ehs = self.generate_river_ehs(player_cards, board_cards)
# results['river'].append({
# 'player_cards': player_cards,
# 'board_cards': board_cards,
# 'ehs': river_ehs
# })
# # Turn样本6张牌
# turn_cards = np.random.choice(self.full_deck, size=6, replace=False).tolist()
# player_cards = turn_cards[:2]
# board_cards = turn_cards[2:6]
# turn_hist = self.generate_turn_histogram(player_cards, board_cards, num_bins=30)
# results['turn'].append({
# 'player_cards': player_cards,
# 'board_cards': board_cards,
# 'histogram': turn_hist,
# 'mean': float(np.mean(turn_hist)),
# 'std': float(np.std(turn_hist))
# })
# # Flop样本5张牌
# flop_cards = np.random.choice(self.full_deck, size=5, replace=False).tolist()
# player_cards = flop_cards[:2]
# board_cards = flop_cards[2:5]
# flop_hist = self.generate_flop_histogram(player_cards, board_cards, num_bins=465)
# results['flop'].append({
# 'player_cards': [str(c) for c in player_cards],
# 'board_cards': [str(c) for c in board_cards],
# 'histogram_stats': flop_hist,
# 'mean': float(np.mean(flop_hist)),
# 'std': float(np.std(flop_hist))
# })
# return results
# def main():
# """测试短牌型EHS直方图生成器"""
# print("短牌型EHS直方图生成器测试")
# print("="*60)
# generator = ShortDeckHistGenerator()
# # 测试River EHS
# print(f"\n测试River EHS计算...")
# test_cards = np.random.choice(generator.full_deck, size=7, replace=False).tolist()
# player_cards = test_cards[:2]
# board_cards = test_cards[2:7]
# player_str = " ".join(str(c) for c in player_cards)
# board_str = " ".join(str(c) for c in board_cards)
# river_ehs = generator.generate_river_ehs(player_cards, board_cards)
# print(f" 玩家底牌: [{player_str}]")
# print(f" 公共牌: [{board_str}]")
# print(f" River EHS: {river_ehs:.4f}")
# # 测试Turn直方图
# print(f"\n测试Turn直方图生成...")
# turn_cards = np.random.choice(generator.full_deck, size=6, replace=False).tolist()
# player_cards = turn_cards[:2]
# board_cards = turn_cards[2:6]
# turn_hist = generator.generate_turn_histogram(player_cards, board_cards, num_bins=30)
# # 测试Flop直方图
# print(f"\n测试Flop直方图生成...")
# flop_cards = np.random.choice(generator.full_deck, size=5, replace=False).tolist()
# player_cards = flop_cards[:2]
# board_cards = flop_cards[2:5]
# flop_hist = generator.generate_flop_histogram(player_cards, board_cards, num_bins=465)
# print(f" Flop直方图: mean={np.mean(flop_hist):.3f}, 非零bins={np.count_nonzero(flop_hist)}")
# # 生成样本数据
# print(f"\n生成样本数据...")
# sample_data = generator.generate_sample_data(num_samples=3)