plt_hist
This commit is contained in:
@@ -6,6 +6,7 @@ import sys
|
|||||||
|
|
||||||
from .parse_data import XTaskDataParser
|
from .parse_data import XTaskDataParser
|
||||||
from shortdeck.gen_hist import ShortDeckHistGenerator
|
from shortdeck.gen_hist import ShortDeckHistGenerator
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
class DataValidator:
|
class DataValidator:
|
||||||
@@ -23,13 +24,12 @@ class DataValidator:
|
|||||||
try:
|
try:
|
||||||
print(" 解析导出的river数据...")
|
print(" 解析导出的river数据...")
|
||||||
print('='*60)
|
print('='*60)
|
||||||
river_records = self.parser.parse_river_ehs_with_cards()
|
river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
|
||||||
|
|
||||||
if not river_records:
|
if not river_records:
|
||||||
return {'error': '没有解析到river记录', 'success': False}
|
return {'error': '没有解析到river记录', 'success': False}
|
||||||
|
|
||||||
|
sample_records = river_records
|
||||||
sample_records = np.random.choice(river_records, size=max_samples, replace=False)
|
|
||||||
print(f" 选择 {len(sample_records)} 个样本进行验证")
|
print(f" 选择 {len(sample_records)} 个样本进行验证")
|
||||||
|
|
||||||
matches = 0
|
matches = 0
|
||||||
@@ -96,6 +96,21 @@ class DataValidator:
|
|||||||
print(f" River验证失败: {e}")
|
print(f" River验证失败: {e}")
|
||||||
return {'error': str(e), 'success': False}
|
return {'error': str(e), 'success': False}
|
||||||
|
|
||||||
|
def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards):
|
||||||
|
|
||||||
|
player_str = " ".join(str(c) for c in player_cards)
|
||||||
|
board_str = " ".join(str(c) for c in board_cards)
|
||||||
|
print("="*60)
|
||||||
|
print(f"样本 {i+1}: [{player_str}] + [{board_str}]")
|
||||||
|
print("bin src src_norm cur cur_norm")
|
||||||
|
src_hist = np.array(src_hist)
|
||||||
|
cur_hist = np.array(cur_hist)
|
||||||
|
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
||||||
|
cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist
|
||||||
|
for i in range(min(len(src_hist), 30)):
|
||||||
|
if src_hist[i] > 0 or cur_hist[i] > 0:
|
||||||
|
print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}")
|
||||||
|
|
||||||
def validate_turn_samples(self, max_samples: int = 10):
|
def validate_turn_samples(self, max_samples: int = 10):
|
||||||
|
|
||||||
# print(f"\n 验证turn_HIST样本 (最大样本数: {max_samples})")
|
# print(f"\n 验证turn_HIST样本 (最大样本数: {max_samples})")
|
||||||
@@ -141,18 +156,22 @@ class DataValidator:
|
|||||||
cur_hist_norm
|
cur_hist_norm
|
||||||
)
|
)
|
||||||
|
|
||||||
is_low_emd = emd_dist < 0.2
|
low_emd_count += 1 if emd_dist < 0.2 else 0
|
||||||
if is_low_emd:
|
errors += 1 if emd_dist >= 0.2 else 0
|
||||||
low_emd_count += 1
|
emd_distances.append(emd_dist)
|
||||||
|
|
||||||
# 显示详细信息(前3个样本)
|
self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
|
||||||
if i < 3:
|
|
||||||
player_str = " ".join(str(c) for c in player_cards)
|
|
||||||
board_str = " ".join(str(c) for c in board_cards)
|
|
||||||
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
|
# 画图显示
|
||||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
plt.plot(src_hist, label='src', marker='o')
|
||||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
plt.plot(cur_hist, label='cur', marker='x')
|
||||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
plt.title(f"turn_hist_emd={emd_dist:.6f}")
|
||||||
|
plt.xlabel("Bins")
|
||||||
|
plt.ylabel("Frequency")
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors += 1
|
errors += 1
|
||||||
@@ -228,11 +247,10 @@ class DataValidator:
|
|||||||
src_hist_norm,
|
src_hist_norm,
|
||||||
cur_hist_norm
|
cur_hist_norm
|
||||||
)
|
)
|
||||||
# emd_distances.append(emd_dist)
|
emd_distances.append(emd_dist)
|
||||||
|
|
||||||
# is_low_emd = emd_dist < 0.2 # EMD阈值
|
low_emd_count += 1 if emd_dist < 10 else 0
|
||||||
# if is_low_emd:
|
errors += 1 if emd_dist >= 10 else 0
|
||||||
# low_emd_count += 1
|
|
||||||
|
|
||||||
# 显示详细信息
|
# 显示详细信息
|
||||||
player_str = " ".join(str(c) for c in player_cards)
|
player_str = " ".join(str(c) for c in player_cards)
|
||||||
@@ -241,6 +259,21 @@ class DataValidator:
|
|||||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
||||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
||||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
||||||
|
print("bin src src_norm cur cur_norm")
|
||||||
|
for i in range(min(len(src_hist), 30)):
|
||||||
|
if src_hist[i] > 0 or cur_hist[i] > 0:
|
||||||
|
print(f"bin[{i}], {src_hist[i]:8.3f}, {src_hist_norm[i]:8.3f}, {cur_hist[i]:8.3f}, {cur_hist_norm[i]:8.3f}")
|
||||||
|
|
||||||
|
self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
|
||||||
|
|
||||||
|
# 画图显示
|
||||||
|
plt.plot(src_hist, label='src', marker='o')
|
||||||
|
plt.plot(cur_hist, label='cur', marker='x')
|
||||||
|
plt.title(f"flop_hist_emd={emd_dist:.6f}")
|
||||||
|
plt.xlabel("Bins")
|
||||||
|
plt.ylabel("Frequency")
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors += 1
|
errors += 1
|
||||||
@@ -276,7 +309,6 @@ class DataValidator:
|
|||||||
|
|
||||||
print(" 导出数据EHS验证")
|
print(" 导出数据EHS验证")
|
||||||
print("*"*60)
|
print("*"*60)
|
||||||
print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")
|
|
||||||
|
|
||||||
# 执行各阶段验证
|
# 执行各阶段验证
|
||||||
results = {}
|
results = {}
|
||||||
@@ -312,6 +344,7 @@ class DataValidator:
|
|||||||
print(f" 样本数量: {results['turn']['total_samples']}")
|
print(f" 样本数量: {results['turn']['total_samples']}")
|
||||||
print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}")
|
print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}")
|
||||||
print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}")
|
print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}")
|
||||||
|
print(f" 抽样EMD: {[emd for emd in results['turn']['emd_distances'][:5]]}")
|
||||||
if results['turn']['success']:
|
if results['turn']['success']:
|
||||||
passed_stages += 1
|
passed_stages += 1
|
||||||
else:
|
else:
|
||||||
@@ -325,6 +358,7 @@ class DataValidator:
|
|||||||
print(f" 样本数量: {results['flop']['total_samples']}")
|
print(f" 样本数量: {results['flop']['total_samples']}")
|
||||||
print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}")
|
print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}")
|
||||||
print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}")
|
print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}")
|
||||||
|
print(f" 抽样EMD: {[emd for emd in results['flop']['emd_distances'][:5]]}")
|
||||||
if results['flop']['success']:
|
if results['flop']['success']:
|
||||||
passed_stages += 1
|
passed_stages += 1
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
调试
|
|
||||||
turn/flop阶段EMD在导出的数据与生成的数据间的差异
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from cross_validation import DataValidator
|
|
||||||
|
|
||||||
validator = DataValidator()
|
|
||||||
|
|
||||||
print("解析Turn样本...")
|
|
||||||
turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1)
|
|
||||||
|
|
||||||
if turn_records:
|
|
||||||
record = turn_records[0]
|
|
||||||
player_cards = record.player_cards
|
|
||||||
board_cards = record.board_cards
|
|
||||||
src_hist = record.bins
|
|
||||||
|
|
||||||
cur_hist = validator.generator.generate_turn_histogram(
|
|
||||||
player_cards, board_cards, num_bins=len(src_hist)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
|
||||||
|
|
||||||
print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}")
|
|
||||||
print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}")
|
|
||||||
print(f"原始直方图 - 和: {src_hist.sum():.3f}, 归一化后: {src_hist_norm.sum():.3f}")
|
|
||||||
print(f"生成直方图 - 和: {sum(cur_hist):.3f}")
|
|
||||||
|
|
||||||
print("\n前10个bin对比:")
|
|
||||||
print("Bin 原始值 归一化 生成值")
|
|
||||||
for i in range(min(10, len(src_hist))):
|
|
||||||
print(f"{i:3d} {src_hist[i]:8.3f} {src_hist_norm[i]:8.3f} {cur_hist[i]:8.3f}")
|
|
||||||
|
|
||||||
# 查看非零bin的分布
|
|
||||||
src_nonzero = np.nonzero(src_hist_norm)[0]
|
|
||||||
cur_nonzero = np.nonzero(cur_hist)[0]
|
|
||||||
print(f"\n非零bins位置:")
|
|
||||||
print(f"原始: {src_nonzero[:10]}...")
|
|
||||||
print(f"生成: {cur_nonzero[:10]}...")
|
|
||||||
|
|
||||||
# 计算分布的统计特征
|
|
||||||
src_mean = np.average(range(len(src_hist_norm)), weights=src_hist_norm)
|
|
||||||
cur_mean = np.average(range(len(cur_hist)), weights=cur_hist)
|
|
||||||
|
|
||||||
print(f"\n分布特征:")
|
|
||||||
print(f"原始分布重心: {src_mean:.2f}")
|
|
||||||
print(f"生成分布重心: {cur_mean:.2f}")
|
|
||||||
print(f"重心差异: {abs(src_mean - cur_mean):.2f}")
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import random
|
||||||
from typing import List, Dict, Tuple, Optional
|
from typing import List, Dict, Tuple, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from poker.card import Card, ShortDeckRank, Suit
|
from poker.card import Card, ShortDeckRank, Suit
|
||||||
@@ -164,17 +165,19 @@ class XTaskDataParser:
|
|||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
self.decoder = OpenPQLDecoder()
|
self.decoder = OpenPQLDecoder()
|
||||||
|
|
||||||
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
|
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
|
||||||
filepath = f"{self.data_path}/{filename}"
|
filepath = f"{self.data_path}/{filename}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_data = np.load(filepath)
|
raw_data = np.load(filepath)
|
||||||
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
|
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
|
||||||
|
# 抽样
|
||||||
|
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||||
|
|
||||||
records = []
|
records = []
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
|
|
||||||
for i, row in enumerate(raw_data):
|
for i, row in enumerate(data_to_process):
|
||||||
try:
|
try:
|
||||||
board_id = int(row['board'])
|
board_id = int(row['board'])
|
||||||
player_id = int(row['player'])
|
player_id = int(row['player'])
|
||||||
@@ -245,10 +248,8 @@ class XTaskDataParser:
|
|||||||
|
|
||||||
records = []
|
records = []
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
|
# 抽样
|
||||||
# 限制处理数量以避免内存问题
|
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||||
# todo:抽样优化
|
|
||||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
|
||||||
|
|
||||||
for i, row in enumerate(data_to_process):
|
for i, row in enumerate(data_to_process):
|
||||||
try:
|
try:
|
||||||
@@ -305,8 +306,7 @@ class XTaskDataParser:
|
|||||||
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
|
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
|
||||||
records = []
|
records = []
|
||||||
decode_errors = 0
|
decode_errors = 0
|
||||||
|
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
|
||||||
|
|
||||||
for i, row in enumerate(data_to_process):
|
for i, row in enumerate(data_to_process):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -148,96 +148,3 @@ class ShortDeckHistGenerator:
|
|||||||
raise
|
raise
|
||||||
return histogram
|
return histogram
|
||||||
|
|
||||||
|
|
||||||
# # 从一副牌中抽样验证 或者 从解析的数据中抽样验证
|
|
||||||
# def generate_sample_data(self, num_samples: int = 3) -> Dict:
|
|
||||||
# print(f"\n 生成样本数据 (每阶段 {num_samples} 个样本)")
|
|
||||||
|
|
||||||
# results = {
|
|
||||||
# 'river': [],
|
|
||||||
# 'turn': [],
|
|
||||||
# 'flop': []
|
|
||||||
# }
|
|
||||||
|
|
||||||
# for i in range(num_samples):
|
|
||||||
# # River样本(7张牌)
|
|
||||||
# river_cards = np.random.choice(self.full_deck, size=7, replace=False).tolist()
|
|
||||||
# player_cards = river_cards[:2]
|
|
||||||
# board_cards = river_cards[2:7]
|
|
||||||
# river_ehs = self.generate_river_ehs(player_cards, board_cards)
|
|
||||||
# results['river'].append({
|
|
||||||
# 'player_cards': player_cards,
|
|
||||||
# 'board_cards': board_cards,
|
|
||||||
# 'ehs': river_ehs
|
|
||||||
# })
|
|
||||||
|
|
||||||
# # Turn样本(6张牌)
|
|
||||||
# turn_cards = np.random.choice(self.full_deck, size=6, replace=False).tolist()
|
|
||||||
# player_cards = turn_cards[:2]
|
|
||||||
# board_cards = turn_cards[2:6]
|
|
||||||
# turn_hist = self.generate_turn_histogram(player_cards, board_cards, num_bins=30)
|
|
||||||
# results['turn'].append({
|
|
||||||
# 'player_cards': player_cards,
|
|
||||||
# 'board_cards': board_cards,
|
|
||||||
# 'histogram': turn_hist,
|
|
||||||
# 'mean': float(np.mean(turn_hist)),
|
|
||||||
# 'std': float(np.std(turn_hist))
|
|
||||||
# })
|
|
||||||
|
|
||||||
# # Flop样本(5张牌)
|
|
||||||
# flop_cards = np.random.choice(self.full_deck, size=5, replace=False).tolist()
|
|
||||||
# player_cards = flop_cards[:2]
|
|
||||||
# board_cards = flop_cards[2:5]
|
|
||||||
# flop_hist = self.generate_flop_histogram(player_cards, board_cards, num_bins=465)
|
|
||||||
# results['flop'].append({
|
|
||||||
# 'player_cards': [str(c) for c in player_cards],
|
|
||||||
# 'board_cards': [str(c) for c in board_cards],
|
|
||||||
# 'histogram_stats': flop_hist,
|
|
||||||
# 'mean': float(np.mean(flop_hist)),
|
|
||||||
# 'std': float(np.std(flop_hist))
|
|
||||||
# })
|
|
||||||
|
|
||||||
# return results
|
|
||||||
|
|
||||||
|
|
||||||
# def main():
|
|
||||||
# """测试短牌型EHS直方图生成器"""
|
|
||||||
# print("短牌型EHS直方图生成器测试")
|
|
||||||
# print("="*60)
|
|
||||||
|
|
||||||
# generator = ShortDeckHistGenerator()
|
|
||||||
|
|
||||||
# # 测试River EHS
|
|
||||||
# print(f"\n测试River EHS计算...")
|
|
||||||
# test_cards = np.random.choice(generator.full_deck, size=7, replace=False).tolist()
|
|
||||||
# player_cards = test_cards[:2]
|
|
||||||
# board_cards = test_cards[2:7]
|
|
||||||
|
|
||||||
# player_str = " ".join(str(c) for c in player_cards)
|
|
||||||
# board_str = " ".join(str(c) for c in board_cards)
|
|
||||||
|
|
||||||
# river_ehs = generator.generate_river_ehs(player_cards, board_cards)
|
|
||||||
# print(f" 玩家底牌: [{player_str}]")
|
|
||||||
# print(f" 公共牌: [{board_str}]")
|
|
||||||
# print(f" River EHS: {river_ehs:.4f}")
|
|
||||||
|
|
||||||
# # 测试Turn直方图
|
|
||||||
# print(f"\n测试Turn直方图生成...")
|
|
||||||
# turn_cards = np.random.choice(generator.full_deck, size=6, replace=False).tolist()
|
|
||||||
# player_cards = turn_cards[:2]
|
|
||||||
# board_cards = turn_cards[2:6]
|
|
||||||
|
|
||||||
# turn_hist = generator.generate_turn_histogram(player_cards, board_cards, num_bins=30)
|
|
||||||
|
|
||||||
# # 测试Flop直方图
|
|
||||||
# print(f"\n测试Flop直方图生成...")
|
|
||||||
# flop_cards = np.random.choice(generator.full_deck, size=5, replace=False).tolist()
|
|
||||||
# player_cards = flop_cards[:2]
|
|
||||||
# board_cards = flop_cards[2:5]
|
|
||||||
|
|
||||||
# flop_hist = generator.generate_flop_histogram(player_cards, board_cards, num_bins=465)
|
|
||||||
# print(f" Flop直方图: mean={np.mean(flop_hist):.3f}, 非零bins={np.count_nonzero(flop_hist)}")
|
|
||||||
|
|
||||||
# # 生成样本数据
|
|
||||||
# print(f"\n生成样本数据...")
|
|
||||||
# sample_data = generator.generate_sample_data(num_samples=3)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user