cross-validation

This commit is contained in:
2025-10-28 15:38:25 +08:00
parent fc085eb77e
commit ec50d2897f
6 changed files with 223 additions and 854 deletions

View File

@@ -1,8 +1,5 @@
"""
Cross Validation module for EHS winrate data validation
"""
from .validator import cross_validate_main
from .cross_validation import DataValidator
from .parse_data import XTaskDataParser
__all__ = ['DataValidator', 'XTaskDataParser']
__all__ = [
'cross_validate_main'
]

View File

@@ -1,391 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from typing import List, Dict, Tuple
from scipy.stats import wasserstein_distance
import sys
from .parse_data import XTaskDataParser
from shortdeck.gen_hist import ShortDeckHistGenerator
import matplotlib.pyplot as plt
class DataValidator:
    """Compare exported EHS data with values recomputed by ``ShortDeckHistGenerator``.

    River samples are compared value by value; turn and flop samples are
    compared as histograms through the Wasserstein (earth mover's) distance.
    Each ``validate_*`` method prints a report and returns a summary dict that
    always carries a ``'success'`` key. When a stage cannot run at all the
    dict holds ``'error'`` instead of the statistics.
    """

    def __init__(self, data_path: str = "ehs_data"):
        """Build the parser for *data_path* and the local histogram generator."""
        self.parser = XTaskDataParser(data_path)
        self.generator = ShortDeckHistGenerator()
        print(" DataValidator初始化完成")
        print(f" 生成器短牌型大小: {len(self.generator.full_deck)}")

    @staticmethod
    def _normalize(hist):
        """Return *hist* scaled to sum to 1, or unchanged when its sum is not positive."""
        total = hist.sum()
        return hist / total if total > 0 else hist

    @staticmethod
    def _emd(src_hist_norm, cur_hist_norm):
        """Return the Wasserstein distance between two histograms over their bin indices."""
        return wasserstein_distance(
            range(len(src_hist_norm)),
            range(len(cur_hist_norm)),
            src_hist_norm,
            cur_hist_norm
        )

    @staticmethod
    def _plot_histograms(src_hist, cur_hist, title):
        """Show the source and recomputed histograms on a figure of their own."""
        # A dedicated figure keeps curves of earlier samples from piling up
        # when the backend does not block in show().
        plt.figure()
        plt.plot(src_hist, label='src', marker='o')
        plt.plot(cur_hist, label='cur', marker='x')
        plt.title(title)
        plt.xlabel("Bins")
        plt.ylabel("Frequency")
        plt.legend()
        plt.show()
        plt.close()

    def validate_river_samples(self, max_samples: int = 20) -> Dict:
        """Recompute the river EHS of sampled records and compare with the stored value.

        Args:
            max_samples: Maximum number of records requested from the parser.

        Returns:
            A dict with ``total_samples``, ``matches``, ``match_rate``,
            ``mean_difference``, ``max_difference``, ``errors`` and
            ``success``; or ``{'error': ..., 'success': False}`` when no
            record was parsed or an unexpected exception occurred.
        """
        try:
            print(" 解析导出的river数据...")
            print('='*60)
            river_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
            if not river_records:
                return {'error': '没有解析到river记录', 'success': False}
            print(f" 选择 {len(river_records)} 个样本进行验证")

            # Differences below this tolerance count as a match.
            tolerance = 1e-6
            matches = 0
            errors = 0
            differences = []  # only the differences of non-matching samples
            for i, record in enumerate(river_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_river_ehs = record.ehs
                    cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards)
                    ehs_difference = abs(src_river_ehs - cur_river_ehs)
                    player_str = " ".join(str(c) for c in player_cards)
                    board_str = " ".join(str(c) for c in board_cards)
                    print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
                    print(f" 原始EHS: {src_river_ehs:.6f}")
                    print(f" 重算EHS: {cur_river_ehs:.6f}")
                    print(f" 差异: {ehs_difference:.6f}")
                    if ehs_difference < tolerance:
                        matches += 1
                    else:
                        differences.append(ehs_difference)
                except Exception as e:
                    errors += 1
                    # Only the first few failures are printed to keep the log short.
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            # Summary statistics.
            total_samples = len(river_records)
            match_rate = matches / total_samples if total_samples > 0 else 0
            mean_diff = np.mean(differences) if differences else 0
            max_diff = np.max(differences) if differences else 0
            result = {
                'total_samples': total_samples,
                'matches': matches,
                'match_rate': match_rate,
                'mean_difference': mean_diff,
                'max_difference': max_diff,
                'errors': errors,
                'success': match_rate > 0.8 and mean_diff < 0.05
            }
            print(f" River验证完成:")
            print('='*60)
            print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
            print(f" 平均差异: {mean_diff:.6f}")
            print(f" 最大差异: {max_diff:.6f}")
            return result
        except Exception as e:
            print(f" River验证失败: {e}")
            return {'error': str(e), 'success': False}

    def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards):
        """Print the raw and normalised values of the non-empty bins of one sample.

        Args:
            i: Zero-based sample index (printed one-based).
            src_hist: Histogram read from the exported data.
            cur_hist: Histogram recomputed locally.
            player_cards: Hole cards of the sample.
            board_cards: Board cards of the sample.
        """
        player_str = " ".join(str(c) for c in player_cards)
        board_str = " ".join(str(c) for c in board_cards)
        print("="*60)
        print(f"样本 {i+1}: [{player_str}] + [{board_str}]")
        print("bin src src_norm cur cur_norm")
        src_hist = np.array(src_hist)
        cur_hist = np.array(cur_hist)
        src_hist_norm = self._normalize(src_hist)
        cur_hist_norm = self._normalize(cur_hist)
        # At most the first 30 bins are listed.
        for bin_idx in range(min(len(src_hist), 30)):
            if src_hist[bin_idx] > 0 or cur_hist[bin_idx] > 0:
                print(f"bin[{bin_idx}], {src_hist[bin_idx]:8.3f}, {src_hist_norm[bin_idx]:8.3f}, {cur_hist[bin_idx]:8.3f}, {cur_hist_norm[bin_idx]:8.3f}")

    def validate_turn_samples(self, max_samples: int = 10) -> Dict:
        """Recompute turn histograms of sampled records and compare them by EMD.

        Args:
            max_samples: Maximum number of records requested from the parser.

        Returns:
            A dict with ``total_samples``, ``low_emd_count``, ``low_emd_rate``,
            ``mean_emd_distance``, ``emd_distances``, ``errors`` and
            ``success``; or ``{'error': ..., 'success': False}``.
            ``errors`` counts both samples at or above the EMD threshold and
            samples whose processing raised.
        """
        try:
            print(" 解析导出的Turn数据...")
            print('='*60)
            print('='*60)
            turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
            if not turn_records:
                return {'error': '没有解析到Turn记录', 'success': False}
            print(f" 解析到 {len(turn_records)} 个Turn样本")

            low_emd_count = 0
            emd_distances = []
            errors = 0
            for i, record in enumerate(turn_records):
                try:
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_hist = np.array(record.bins)
                    cur_hist = np.array(self.generator.generate_turn_histogram(
                        player_cards, board_cards, num_bins=len(src_hist)
                    ))
                    emd_dist = self._emd(self._normalize(src_hist), self._normalize(cur_hist))
                    # Samples below 0.2 count as "low EMD"; the rest are tallied as errors.
                    if emd_dist < 0.2:
                        low_emd_count += 1
                    elif emd_dist >= 0.2:
                        errors += 1
                    emd_distances.append(emd_dist)
                    self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
                    self._plot_histograms(src_hist, cur_hist, f"turn_hist_emd={emd_dist:.6f}")
                except Exception as e:
                    errors += 1
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            # Summary statistics.
            total_samples = len(turn_records)
            low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
            mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
            result = {
                'total_samples': total_samples,
                'low_emd_count': low_emd_count,
                'low_emd_rate': low_emd_rate,
                'mean_emd_distance': mean_emd,
                'emd_distances': emd_distances,
                'errors': errors,
                'success': low_emd_rate > 0.6 and mean_emd < 0.5
            }
            print(f" Turn验证完成:")
            print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
            print(f" 平均EMD: {mean_emd:.6f}")
            return result
        except Exception as e:
            print(f" Turn验证失败: {e}")
            return {'error': str(e), 'success': False}

    def validate_flop_samples(self, max_samples: int = 5) -> Dict:
        """Recompute flop histograms of sampled records and compare them by EMD.

        Args:
            max_samples: Maximum number of records requested from the parser.

        Returns:
            A dict with the same keys as :meth:`validate_turn_samples`, or
            ``{'error': ..., 'success': False}``.
        """
        try:
            print(" 解析导出Flop数据...")
            print('='*60)
            print('='*60)
            print('='*60)
            flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)
            if not flop_records:
                return {'error': '没有解析到Flop记录', 'success': False}
            print(f" 解析到 {len(flop_records)} 个Flop样本")

            low_emd_count = 0
            emd_distances = []
            errors = 0
            for i, record in enumerate(flop_records):
                try:
                    print(f" 处理样本 {i+1}/{len(flop_records)}...")
                    player_cards = record.player_cards
                    board_cards = record.board_cards
                    src_hist = np.array(record.bins)
                    cur_hist = np.array(self.generator.generate_flop_histogram(
                        player_cards, board_cards, num_bins=len(src_hist)
                    ))
                    emd_dist = self._emd(self._normalize(src_hist), self._normalize(cur_hist))
                    emd_distances.append(emd_dist)
                    # Samples below 10 count as "low EMD"; the rest are tallied as errors.
                    if emd_dist < 10:
                        low_emd_count += 1
                    elif emd_dist >= 10:
                        errors += 1
                    # Per-sample details.
                    player_str = " ".join(str(c) for c in player_cards)
                    board_str = " ".join(str(c) for c in board_cards)
                    print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
                    print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
                    print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
                    print(f" 归一化后EMD距离: {emd_dist:.6f}")
                    # The bin table is printed by the helper only; the sample
                    # index ``i`` is never reused for bins here.
                    self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
                    self._plot_histograms(src_hist, cur_hist, f"flop_hist_emd={emd_dist:.6f}")
                except Exception as e:
                    errors += 1
                    if errors <= 3:
                        print(f" 样本 {i+1} 计算失败: {e}")

            # Summary statistics.
            total_samples = len(flop_records)
            low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
            mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
            result = {
                'total_samples': total_samples,
                'low_emd_count': low_emd_count,
                'low_emd_rate': low_emd_rate,
                'mean_emd_distance': mean_emd,
                'emd_distances': emd_distances,
                'errors': errors,
                'success': low_emd_rate > 0.4 and mean_emd < 0.5
            }
            print(f" Flop验证完成:")
            print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
            print(f" 平均EMD: {mean_emd:.6f}")
            return result
        except Exception as e:
            print(f" Flop验证失败: {e}")
            return {'error': str(e), 'success': False}

    @staticmethod
    def _report_hist_stage(label, stage):
        """Print the report of one histogram stage (``TURN`` or ``FLOP``)."""
        print(f"\n {label}阶段:")
        if 'error' in stage:
            print(f" 错误: {stage['error']}")
            return
        status = " 通过" if stage['success'] else " 失败"
        print(f" 验证结果: {status}")
        print(f" 样本数量: {stage['total_samples']}")
        print(f" 低EMD率: {stage['low_emd_rate']:.1%}")
        print(f" 平均EMD: {stage['mean_emd_distance']:.6f}")
        print(f" 抽样EMD: {stage['emd_distances'][:5]}")

    def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10, flop_samples: int = 5) -> Dict:
        """Run the river, turn and flop validations and print a combined report.

        Args:
            river_samples: Sample budget of the river stage.
            turn_samples: Sample budget of the turn stage.
            flop_samples: Sample budget of the flop stage.

        Returns:
            A dict with ``results`` (per-stage summaries keyed by ``'river'``,
            ``'turn'``, ``'flop'``), ``passed_stages``, ``total_stages`` and
            ``overall_success`` (true when at least 70% of the stages passed).
        """
        print(" 导出数据EHS验证")
        print("*"*60)
        # Run every stage.
        results = {}
        results['river'] = self.validate_river_samples(river_samples)
        results['turn'] = self.validate_turn_samples(turn_samples)
        results['flop'] = self.validate_flop_samples(flop_samples)
        print(f"\n{'='*60}")
        print(" 验证完毕")
        print(f"{'='*60}")

        # River report.
        river = results['river']
        print(f"\n RIVER阶段:")
        if 'error' not in river:
            status = " 通过" if river['success'] else " 失败"
            print(f" 验证结果: {status}")
            print(f" 样本数量: {river['total_samples']}")
            print(f" 匹配率: {river['match_rate']:.1%}")
            print(f" 平均差异: {river['mean_difference']:.6f}")
        else:
            print(f" 错误: {river['error']}")

        # Turn and flop reports share one layout.
        self._report_hist_stage("TURN", results['turn'])
        self._report_hist_stage("FLOP", results['flop'])

        # Overall outcome, counted once from the stage summaries.
        total_stages = 3
        passed_stages = sum(
            1 for stage in results.values() if stage.get('success', False)
        )
        overall_rate = passed_stages / total_stages
        return {
            'results': results,
            'passed_stages': passed_stages,
            'total_stages': total_stages,
            'overall_success': overall_rate >= 0.7
        }

View File

@@ -1,429 +0,0 @@
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit
@dataclass
class RiverEHSRecord:
    """One river-stage record: the encoded ids, the stored EHS value and the decoded cards."""
    board_id: int  # encoded board id as read from the data row
    player_id: int  # encoded hole-card id as read from the data row
    ehs: float  # EHS value stored in the data row
    board_cards: List[Card]  # 5 board cards
    player_cards: List[Card]  # 2 hole cards
@dataclass
class TurnHistRecord:
    """One turn-stage record: the encoded ids, the stored histogram bins and the decoded cards."""
    board_id: int  # encoded board id as read from the data row
    player_id: int  # encoded hole-card id as read from the data row
    bins: np.ndarray  # histogram bins copied from the data row
    board_cards: List[Card]  # 4 board cards
    player_cards: List[Card]  # 2 hole cards
@dataclass
class FlopHistRecord:
    """One flop-stage record: the encoded ids, the stored histogram bins and the decoded cards."""
    board_id: int  # encoded board id as read from the data row
    player_id: int  # encoded hole-card id as read from the data row
    bins: np.ndarray  # histogram bins copied from the data row
    board_cards: List[Card]  # 3 board cards
    player_cards: List[Card]  # 2 hole cards
class OpenPQLDecoder:
    """Decode open-pql card encodings into short-deck ``Card`` objects.

    Two encodings are handled: the 64-bit ``Card64`` bitmask (one 16-bit lane
    per suit, one bit per rank) and the 16-bit ``Hand<2>`` id (one byte per
    card: rank in the low nibble, suit in the high nibble). Only the ranks
    six through ace (open-pql rank codes 4..12) are mapped.
    """

    def __init__(self):
        # Card64 layout: each suit owns a 16-bit lane.
        self.OFFSET_SUIT = 16
        self.OFFSET_S = 0   # spades:   bits [15:0]
        self.OFFSET_H = 16  # hearts:   bits [31:16]
        self.OFFSET_D = 32  # diamonds: bits [47:32]
        self.OFFSET_C = 48  # clubs:    bits [63:48]
        # open-pql rank code -> short-deck rank; codes 0..3 (two..five) have no entry.
        short_deck_ranks = (
            ShortDeckRank.SIX,
            ShortDeckRank.SEVEN,
            ShortDeckRank.EIGHT,
            ShortDeckRank.NINE,
            ShortDeckRank.TEN,
            ShortDeckRank.JACK,
            ShortDeckRank.QUEEN,
            ShortDeckRank.KING,
            ShortDeckRank.ACE,
        )
        self.opql_to_poker_rank = dict(zip(range(4, 13), short_deck_ranks))
        # open-pql suit code -> suit, in S, H, D, C order.
        self.opql_to_poker_suit = dict(
            enumerate((Suit.SPADES, Suit.HEARTS, Suit.DIAMONDS, Suit.CLUBS))
        )
        print("========== OpenPQLDecoder (短牌型36张牌) ===============")

    def u64_from_ranksuit(self, rank, suit) -> int:
        """Return the Card64 bit for (*rank*, *suit*); mirrors ``Card64::u64_from_ranksuit_i8``."""
        rank_bit = 1 << rank
        return rank_bit << (suit * self.OFFSET_SUIT)

    def decode_card64(self, card64_u64) -> List[Card]:
        """Return every card whose bit is set in a Card64 value, sorted by rank then suit.

        Raises:
            KeyError: If a bit is set for a rank code that has no short-deck mapping.
        """
        decoded = []
        for suit_code in range(4):
            lane = (card64_u64 >> (suit_code * self.OFFSET_SUIT)) & 0xFFFF
            # Only rank codes 0..12 are inspected inside a lane.
            for rank_code in range(13):
                if not lane & (1 << rank_code):
                    continue
                rank = self.opql_to_poker_rank[rank_code]
                suit = self.opql_to_poker_suit[suit_code]
                decoded.append(Card(rank, suit))
        decoded.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
        return decoded

    def decode_hand_n_2(self, id_u16) -> List[Card]:
        """Decode a ``Hand<2>`` 16-bit id into two cards sorted by rank then suit.

        Raises:
            ValueError: If either byte does not decode to a known card.
        """
        card0_u8 = id_u16 & 0xFF          # low byte
        card1_u8 = (id_u16 >> 8) & 0xFF   # high byte
        try:
            pair = [self._card_from_u8(card0_u8), self._card_from_u8(card1_u8)]
            return sorted(pair, key=lambda c: (c.rank.numeric_value, c.suit.value))
        except (ValueError, TypeError) as e:
            raise ValueError(f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}")

    def _card_from_u8(self, v) -> Card:
        """Decode one card byte; mirrors open-pql ``Card::from_u8()``.

        Raises:
            ValueError: If the rank or suit nibble has no mapping.
        """
        SHIFT_SUIT = 4
        opql_rank = v & 0b1111        # low nibble
        opql_suit = v >> SHIFT_SUIT   # high nibble
        if opql_rank not in self.opql_to_poker_rank:
            raise ValueError(f"无效的rank: {opql_rank}")
        if opql_suit not in self.opql_to_poker_suit:
            raise ValueError(f"无效的suit: {opql_suit}")
        return Card(self.opql_to_poker_rank[opql_rank], self.opql_to_poker_suit[opql_suit])

    def decode_board_id(self, board_id, num_cards) -> List[Card]:
        """Decode a board id holding *num_cards* cards.

        Two-card boards use the ``Hand<2>`` encoding; every other size uses
        the Card64 bitmask.

        Raises:
            ValueError: If the number of decoded cards differs from *num_cards*.
        """
        if num_cards == 2:
            return self.decode_hand_n_2(board_id)
        cards = self.decode_card64(board_id)
        if len(cards) == num_cards:
            return cards
        raise ValueError(f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}")

    def decode_player_id(self, player_id) -> List[Card]:
        """Decode the hole-card id (two cards)."""
        return self.decode_hand_n_2(player_id)

    def decode_board_unique_card(self, all_cards) -> List[Card]:
        """Return *all_cards* without repeats (same rank and suit), keeping first occurrences."""
        distinct = []
        for card in all_cards:
            already_kept = any(
                card.rank == kept.rank and card.suit == kept.suit
                for kept in distinct
            )
            if not already_kept:
                distinct.append(card)
        return distinct
class XTaskDataParser:
    """Load the exported ``.npy`` tables, sample rows and decode their card ids.

    Rows are expected to expose the fields ``'board'`` and ``'player'`` plus
    ``'ehs'`` (river) or ``'bins'`` (turn/flop), as read by the parse methods.
    """

    def __init__(self, data_path: str = "ehs_data"):
        """Remember the data directory and create the id decoder."""
        self.data_path = data_path
        self.decoder = OpenPQLDecoder()

    @staticmethod
    def _sample_rows(raw_data, max_records):
        """Return up to *max_records* rows of *raw_data* drawn without replacement.

        Indices are sampled instead of the rows themselves so the whole table
        is never copied into a Python list.
        """
        total = len(raw_data)
        picked = random.sample(range(total), min(max_records, total))
        return [raw_data[idx] for idx in picked]

    def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> "List[RiverEHSRecord]":
        """Sample river rows and decode them into ``RiverEHSRecord`` objects.

        A row whose board and hole cards overlap is reported but still
        returned. Rows that fail to decode are skipped.

        Args:
            filename: File name inside the data directory.
            max_records: Maximum number of rows to sample.

        Raises:
            FileNotFoundError: If the file does not exist (re-raised after logging).
            Exception: Any other load/parse failure is logged and re-raised.
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
            # Random sample of the table.
            data_to_process = self._sample_rows(raw_data, max_records)
            records = []
            decode_errors = 0
            for i, row in enumerate(data_to_process):
                try:
                    board_id = int(row['board'])
                    player_id = int(row['player'])
                    ehs = float(row['ehs'])
                    # Board: 5 cards; hole cards: 2 cards.
                    board_cards = self.decoder.decode_board_id(board_id, 5)
                    player_cards = self.decoder.decode_player_id(player_id)
                    # Report rows where board and hole cards overlap.
                    unique_cards = self.decoder.decode_board_unique_card(board_cards + player_cards)
                    if len(unique_cards) != 7:
                        print(f"记录 {i}: 存在重复牌面")
                        print(f" Board: {[str(c) for c in board_cards]}")
                        print(f" Player: {[str(c) for c in player_cards]}")
                    records.append(RiverEHSRecord(
                        board_id=board_id,
                        player_id=player_id,
                        ehs=ehs,
                        board_cards=board_cards,
                        player_cards=player_cards
                    ))
                except Exception as e:
                    decode_errors += 1
                    # Only the first few failures are printed.
                    if decode_errors <= 3:
                        print(f"记录 {i} 解码失败: {e}")
                        print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
                    continue
                # Progress report, relative to the rows actually processed.
                if (i + 1) % 10000 == 0:
                    success_rate = len(records) / (i + 1) * 100
                    print(f" 已处理 {i+1:,}/{len(data_to_process):,} 条记录,成功率 {success_rate:.1f}%")
            # Rates are computed over the sampled rows, not the whole table,
            # and an empty sample yields 0% instead of a division by zero.
            total_records = len(data_to_process)
            success_records = len(records)
            success_rate = success_records / total_records * 100 if total_records else 0.0
            print(f"river_EHS解析完成:")
            print(f" 总记录数: {total_records:,}")
            print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
            print(f" 解码失败: {decode_errors:,}")
            return records
        except FileNotFoundError:
            print(f"文件不存在: {filepath}")
            raise
        except Exception as e:
            print(f"解析river_EHS数据失败: {e}")
            raise

    def _decode_hist_rows(self, data_to_process, board_size, record_cls, progress_every):
        """Decode sampled histogram rows into *record_cls* objects.

        Args:
            data_to_process: Sampled rows with ``'board'``, ``'player'`` and ``'bins'`` fields.
            board_size: Number of board cards encoded in ``'board'``.
            record_cls: Record class to instantiate (turn or flop).
            progress_every: Print a progress line every this many rows.

        Returns:
            The decoded records. Rows with overlapping cards are reported and
            still returned; rows that fail to decode are skipped.
        """
        records = []
        decode_errors = 0
        expected_cards = board_size + 2
        for i, row in enumerate(data_to_process):
            try:
                board_id = int(row['board'])
                player_id = int(row['player'])
                bins = row['bins'].copy()
                board_cards = self.decoder.decode_board_id(board_id, board_size)
                player_cards = self.decoder.decode_player_id(player_id)
                # Report rows where board and hole cards overlap.
                unique_cards = self.decoder.decode_board_unique_card(board_cards + player_cards)
                if len(unique_cards) != expected_cards:
                    decode_errors += 1
                    print(f"记录 {i}: 存在重复牌面")
                records.append(record_cls(
                    board_id=board_id,
                    player_id=player_id,
                    bins=bins,
                    board_cards=board_cards,
                    player_cards=player_cards
                ))
            except Exception as e:
                decode_errors += 1
                if decode_errors <= 3:
                    print(f"记录 {i} 解码失败: {e}")
                continue
            if (i + 1) % progress_every == 0:
                print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
        return records

    def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> "List[TurnHistRecord]":
        """Sample turn rows and decode them into ``TurnHistRecord`` objects.

        Args:
            filename: File name inside the data directory.
            max_records: Maximum number of rows to sample.

        Raises:
            Exception: Any load/parse failure is logged and re-raised.
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
            data_to_process = self._sample_rows(raw_data, max_records)
            records = self._decode_hist_rows(data_to_process, 4, TurnHistRecord, 100)
            print(f"turn_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
            return records
        except Exception as e:
            print(f" 解析turn_hist数据失败: {e}")
            raise

    def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> "List[FlopHistRecord]":
        """Sample flop rows and decode them into ``FlopHistRecord`` objects.

        Args:
            filename: File name inside the data directory.
            max_records: Maximum number of rows to sample.

        Raises:
            Exception: Any load/parse failure is logged and re-raised.
        """
        filepath = f"{self.data_path}/{filename}"
        try:
            raw_data = np.load(filepath)
            print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
            data_to_process = self._sample_rows(raw_data, max_records)
            records = self._decode_hist_rows(data_to_process, 3, FlopHistRecord, 50)
            print(f"flop_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
            return records
        except Exception as e:
            print(f"解析flop_hist数据失败: {e}")
            raise
def analyze_parsed_data():
    """Parse all three exported tables, print samples and summary statistics.

    Returns:
        A dict with ``river_records``, ``turn_records`` and ``flop_records``,
        or ``None`` when any step raised (the error is printed).
    """
    rule = "=" * 60

    def banner(title, leading_blank=True):
        # Section header framed by two rules.
        print(("\n" + rule) if leading_blank else rule)
        print(title)
        print(rule)

    def join_cards(card_list):
        return " ".join(str(c) for c in card_list)

    print("开始解析 xtask 数据并解码牌面...\n")
    try:
        parser = XTaskDataParser(data_path="ehs_data")

        # 1. River EHS table.
        banner(" 解析river_EHS", leading_blank=False)
        river_records = parser.parse_river_ehs_with_cards()
        print(f"\nriver_EHS数据样本 (前10条):")
        for n, rec in enumerate(river_records[:10], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            print(f" {n:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{rec.ehs:.4f}")

        # 2. Turn histogram table (limited sample).
        banner(" 解析turn_hist数据")
        turn_records = parser.parse_turn_hist_with_cards(max_records=100)
        print(f"\nturn_hist数据样本 (前5条):")
        for n, rec in enumerate(turn_records[:5], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            bins_stats = f"mean={rec.bins.mean():.3f}, std={rec.bins.std():.3f}"
            print(f" {n}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 3. Flop histogram table (limited sample).
        banner(" 解析flop_hist数据")
        flop_records = parser.parse_flop_hist_with_cards(max_records=50)
        print(f"\n flop_hist数据样本 (前5条):")
        for n, rec in enumerate(flop_records[:5], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            bins_stats = f"mean={rec.bins.mean():.3f}, std={rec.bins.std():.3f}"
            print(f" {n}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 4. Summary.
        banner("解析统计摘要")
        print(f"river_EHS记录: {len(river_records):,}")
        print(f"turn_hist记录: {len(turn_records):,}")
        print(f"flop_hist记录: {len(flop_records):,}")
        if river_records:
            ehs_values = [rec.ehs for rec in river_records]
            print(f"river_EHS统计:")
            print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]")
            print(f" 均值: {np.mean(ehs_values):.4f}")
            print(f" 标准差: {np.std(ehs_values):.4f}")
        print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")
        return {
            'river_records': river_records,
            'turn_records': turn_records,
            'flop_records': flop_records
        }
    except Exception as e:
        print(f" 数据解析失败: {e}")
        return None

View File

@@ -0,0 +1,216 @@
import itertools
import numpy as np
import random
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from poker import Suit, Card
from shortdeck import ShortDeckHandEvaluator as HE
from shortdeck import ShortDeckRank as SDR
# Directory of the exported short-deck tables, relative to the current working directory.
data_path = Path(".") / "ehs-data"
# Reference tables, loaded eagerly when this module is imported.
# get_data() looks rows up through their "board" and "player" fields.
np_river = np.load(data_path / "river_ehs_sd.npy")
np_turn = np.load(data_path / "turn_hist_sd.npy")
np_flop = np.load(data_path / "flop_hist_sd.npy")
# Every rank/suit combination of the short deck.
cards = [Card(r, s) for r in SDR for s in Suit]
# Number of bits one card occupies when several cards are packed into an integer key.
CARD_BITS = 6
class SuitMapping:
    """Relabel suits in order of first appearance.

    The first suit passed to :meth:`map_suit` is mapped to the first member
    of ``Suit``, the second distinct suit to the second member, and so on.
    """

    def __init__(self):
        self.mapping = {}
        # Reversed so that pop() hands out the members in definition order.
        self.suits = list(reversed(Suit))

    def map_suit(self, s: Suit) -> Suit:
        """Return the canonical suit for *s*, assigning the next free one on first use."""
        try:
            return self.mapping[s]
        except KeyError:
            assigned = self.suits.pop()
            self.mapping[s] = assigned
            return assigned
class EhsCache:
    """Store computed river EHS values keyed by (player + flop, turn card, river card).

    The first key is the packed, suit-relabelled player + flop cards; the
    turn and river keys are the raw ``card_index`` of those cards.
    """

    def __init__(self):
        self.cache = defaultdict(lambda: defaultdict(dict))

    def _set_keys(self, flop, player):
        """Return the packed key of the suit-relabelled ``player + flop`` cards."""
        suit_map = SuitMapping()
        complex_cards = player + flop
        iso_complex = to_iso(complex_cards, suit_map)
        return cards_to_u32(iso_complex)

    # Storing everything makes the computation very expensive.
    # TODO: sample and compute directly instead of storing.
    def store_river_ehs(self, player, board, ehs):
        """Record *ehs* for *player* on a five-card *board* (flop, turn, river in that order)."""
        complex_key = self._set_keys(board[:3], player)
        turn_idx = card_index(board[3])
        river_idx = card_index(board[4])
        self.cache[complex_key][turn_idx][river_idx] = ehs

    def get_turn_hist(self, player, flop, turn):
        """Return the stored river EHS values for this player/flop/turn, or ``None`` if there are none."""
        complex_key = self._set_keys(flop, player)
        turn_idx = card_index(turn)
        # Plain .get() lookups: indexing the defaultdicts would insert an
        # empty entry for every miss and grow the cache on reads.
        by_turn = self.cache.get(complex_key)
        if not by_turn:
            return None
        turn_hist = by_turn.get(turn_idx)
        return list(turn_hist.values()) if turn_hist else None

    def get_flop_hist(self, player, flop):
        """Return all stored river EHS values for this player/flop.

        Returns ``None`` unless exactly 465 values are stored.
        """
        complex_key = self._set_keys(flop, player)
        player_data = self.cache.get(complex_key, {})
        all_ehs = [
            ehs
            for by_river in player_data.values()
            for ehs in by_river.values()
        ]
        return all_ehs if len(all_ehs) == 465 else None
def get_rank_idx(rank: SDR) -> int:
    """Return the zero-based position of *rank* in six-to-ace order."""
    return [
        SDR.SIX,
        SDR.SEVEN,
        SDR.EIGHT,
        SDR.NINE,
        SDR.TEN,
        SDR.JACK,
        SDR.QUEEN,
        SDR.KING,
        SDR.ACE,
    ].index(rank)
def get_suit_idx(suit: Suit) -> int:
    """Return the zero-based position of *suit* in spades, hearts, diamonds, clubs order."""
    return [
        Suit.SPADES,
        Suit.HEARTS,
        Suit.DIAMONDS,
        Suit.CLUBS,
    ].index(suit)
def card_index(card: Card) -> int:
    """Return the integer code of *card*: ``(rank index + 4) * 4 + suit index``."""
    rank_slot = get_rank_idx(card.rank) + 4
    return rank_slot * 4 + get_suit_idx(card.suit)
# Patch Card so two cards with the same rank and suit compare equal and hash
# alike; calc_river_ehs() relies on this when it puts cards into sets.
Card.__eq__ = lambda a, b: (a.rank == b.rank) and (a.suit == b.suit)
Card.__hash__ = lambda a: hash((get_rank_idx(a.rank), get_suit_idx(a.suit)))
def cards_to_u32(cards: list[Card]) -> int:
    """Pack *cards* into one integer, ``CARD_BITS`` bits per card, first card lowest."""
    packed = 0
    shift = 0
    for card in cards:
        packed |= (card_index(card) & 0x3F) << shift
        shift += CARD_BITS
    return packed
def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]:
    """Return *cards* with suits relabelled through *mapping*, sorted by rank then suit.

    Suits are handed to *mapping* in the order of the cards sorted by
    (number of cards sharing the suit, rank index, suit index).
    """
    # Count how many of the given cards share each suit.
    suit_sizes = {}
    for card in cards:
        suit_sizes[card.suit] = suit_sizes.get(card.suit, 0) + 1

    def relabel_order(c: Card):
        return (suit_sizes[c.suit], get_rank_idx(c.rank), get_suit_idx(c.suit))

    relabelled = [
        Card(c.rank, mapping.map_suit(c.suit))
        for c in sorted(cards, key=relabel_order)
    ]
    relabelled.sort(key=lambda c: (get_rank_idx(c.rank), get_suit_idx(c.suit)))
    return relabelled
def cards_to_u16(cards: list[Card]) -> int:
    """Pack *cards* into one integer, ``CARD_BITS`` bits per card, first card lowest."""
    packed = 0
    for position, card in enumerate(cards):
        code = card_index(card) & 0x3F
        packed = packed | (code << (position * CARD_BITS))
    return packed
def calc_river_ehs(board: list[Card], player: list[Card]) -> float:
    """Return the river hand strength of *player* on *board*.

    Every two-card opponent hand that shares no card with the board or the
    player is evaluated; a win scores 2, a tie 1, a loss 0, and the total is
    divided by the maximum attainable score.
    """
    player_hand = [*board, *player]
    player_ranking = HE.evaluate_hand(player_hand)
    # Built once: the board is part of player_hand, so one overlap test
    # covers both the board and the hole cards.
    used_cards = set(player_hand)
    score = 0
    max_score = 0
    for other in itertools.combinations(cards, 2):
        if not used_cards.isdisjoint(other):
            continue
        other_ranking = HE.evaluate_hand([*board, *other])
        if player_ranking == other_ranking:
            score += 1
        elif player_ranking > other_ranking:
            score += 2
        max_score += 2
    return score / max_score
def get_data(board: list[Card], player: list[Card]):
    """Return the stored value for *player* on *board* from the matching reference table.

    The table is chosen by board size (3: flop, 4: turn, 5: river). The value
    is the third field of the first row whose ``"board"`` and ``"player"``
    fields equal the packed, suit-relabelled cards.

    Raises:
        NotImplementedError: If the board does not have 3, 4 or 5 cards.
    """
    tables = {3: np_flop, 4: np_turn, 5: np_river}
    table = tables.get(len(board))
    if table is None:
        raise NotImplementedError
    # One mapping is shared so the hole cards follow the board's relabelling.
    suit_map = SuitMapping()
    board_key = cards_to_u32(to_iso(board, suit_map))
    player_key = cards_to_u16(to_iso(player, suit_map))
    selected = (table["board"] == board_key) & (table["player"] == player_key)
    return table[selected][0][2]
def euclidean_dist(left, right):
    """Return a distance between *left* and *right*.

    For iterables both sides are converted to sorted ``float32`` arrays and
    the Euclidean norm of their difference is returned. For scalars the
    squared absolute difference is returned.
    """
    if not isinstance(left, Iterable):
        return np.abs(left - right) ** 2
    lhs = np.sort(np.array(left, dtype=np.float32))
    rhs = np.sort(np.array(right, dtype=np.float32))
    return np.linalg.norm(rhs - lhs)
def compare_data(sampled, board, player):
    """Compare *sampled* with the stored value for this board/player.

    Returns:
        0 when the distance is close to zero, otherwise 1 (after printing
        the hand and the distance).
    """
    distance = euclidean_dist(get_data(board, player), sampled)
    if np.isclose(distance, 0.0):
        return 0
    print(f"[{''.join(map(str, board))} {''.join(map(str, player))}]: {distance}")
    return 1
# NOTE(review): not referenced by any function visible in this file — confirm whether it is still needed.
card_ehs = defaultdict(dict)
def validate_river():
    """Recompute the river EHS of every board/hand and compare with the stored table.

    Each computed value is also stored in ``ehs_stored`` for the turn and
    flop checks. A dot and the running count are printed after every board.
    """
    # NOTE(review): the nesting of the two progress prints was reconstructed
    # (once per board) — confirm against the original layout.
    validated_count = 0
    error_count = 0
    for river_combo in itertools.combinations(cards, 5):
        board = list(river_combo)
        unused_cards = [c for c in cards if c not in board]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            ehs = calc_river_ehs(board, player)
            ehs_stored.store_river_ehs(player, board, ehs)
            # Accumulate: a plain assignment would keep only the last hand's result.
            error_count += compare_data(ehs, board, player)
            validated_count += 1
        print(".", end="", flush=True)
        print(f"river validate count {validated_count}")
    print(f"Validated river hands: {validated_count}, Errors: {error_count}")
def validate_turn():
    """Compare the cached river EHS values of every turn board/hand with the stored histogram.

    Hands with no cached values are skipped and counted separately instead
    of being passed to ``compare_data``.
    """
    validated_count = 0
    error_count = 0
    missing_count = 0
    for turn_combo in itertools.combinations(cards, 4):
        turn = list(turn_combo)
        unused_cards = [c for c in cards if c not in turn]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3])
            if turn_hist is None:
                # Nothing cached for this hand; comparing None would raise.
                missing_count += 1
                continue
            error_count += compare_data(turn_hist, turn, player)
            validated_count += 1
        print(".", end="", flush=True)
    print(f"Validated turn hands: {validated_count}, Errors: {error_count}")
    if missing_count:
        print(f"Skipped turn hands without cached river EHS: {missing_count}")
def validate_flop():
    """Check one random flop/hand against the stored flop histogram.

    Five cards are drawn: the first two are the hole cards, the other three
    the flop. Nothing happens when the cache has no complete histogram.
    """
    drawn = random.sample(cards, 5)
    player, flop = drawn[:2], drawn[2:]
    flop_hist = ehs_stored.get_flop_hist(player, flop)
    if flop_hist is not None:
        compare_data(flop_hist, flop, player)
# Module-wide cache: filled by validate_river(), read by validate_turn() and validate_flop().
ehs_stored = EhsCache()
def cross_validate_main():
    """Run the river, turn and flop validations in that order."""
    for stage in (validate_river, validate_turn, validate_flop):
        stage()
# Run the whole cross-validation when the file is executed as a script.
if __name__ == "__main__":
    cross_validate_main()