cross-validation
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,4 @@
|
|||||||
# Python-generated files
|
ehs-data
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[oc]
|
*.py[oc]
|
||||||
build/
|
build/
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
"""
|
from .validator import cross_validate_main
|
||||||
Cross Validation module for EHS winrate data validation
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .cross_validation import DataValidator
|
__all__ = [
|
||||||
from .parse_data import XTaskDataParser
|
'cross_validate_main'
|
||||||
|
]
|
||||||
__all__ = ['DataValidator', 'XTaskDataParser']
|
|
||||||
@@ -1,391 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import numpy as np
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
from scipy.stats import wasserstein_distance
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from .parse_data import XTaskDataParser
|
|
||||||
from shortdeck.gen_hist import ShortDeckHistGenerator
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
class DataValidator:
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, data_path: str = "ehs_data"):
    """Create the exported-data parser and the histogram generator.

    Args:
        data_path: Directory passed straight to ``XTaskDataParser``.
    """
    self.parser = XTaskDataParser(data_path)
    self.generator = ShortDeckHistGenerator()

    # Report readiness together with the size of the generator's deck.
    deck_size = len(self.generator.full_deck)
    print(" DataValidator初始化完成")
    print(f" 生成器短牌型大小: {deck_size}")
|
|
||||||
|
|
||||||
def validate_river_samples(self, max_samples: int = 20) :
    """Recompute river EHS for sampled records and compare with the export.

    Args:
        max_samples: Upper bound forwarded to the parser as ``max_records``.

    Returns:
        A dict with the counters ``total_samples``, ``matches``,
        ``match_rate``, ``mean_difference``, ``max_difference``, ``errors``
        and ``success``; or ``{'error': ..., 'success': False}`` when no
        record was parsed or an unexpected exception occurred.
    """
    try:
        print(" 解析导出的river数据...")
        print('='*60)
        sample_records = self.parser.parse_river_ehs_with_cards(max_records=max_samples)
        if not sample_records:
            return {'error': '没有解析到river记录', 'success': False}

        print(f" 选择 {len(sample_records)} 个样本进行验证")

        matches, errors = 0, 0
        differences = []  # only the gaps of samples that did NOT match

        for sample_no, record in enumerate(sample_records, start=1):
            try:
                hole_cards = record.player_cards
                community = record.board_cards
                exported_ehs = record.ehs

                recomputed_ehs = self.generator.generate_river_ehs(hole_cards, community)
                gap = abs(exported_ehs - recomputed_ehs)

                hole_text = " ".join(str(c) for c in hole_cards)
                community_text = " ".join(str(c) for c in community)
                print(f" 样本 {sample_no}: [{hole_text}] + [{community_text}]")
                print(f" 原始EHS: {exported_ehs:.6f}")
                print(f" 重算EHS: {recomputed_ehs:.6f}")
                print(f" 差异: {gap:.6f}")

                # A sample matches when the gap is below a small numeric tolerance.
                tolerance = 1e-6
                if gap < tolerance:
                    matches += 1
                else:
                    differences.append(gap)
            except Exception as e:
                errors += 1
                # Only the first three failures are reported to keep output short.
                if errors <= 3:
                    print(f" 样本 {sample_no} 计算失败: {e}")

        # Aggregate statistics.
        total_samples = len(sample_records)
        match_rate = matches / total_samples if total_samples > 0 else 0
        if differences:
            mean_diff = np.mean(differences)
            max_diff = np.max(differences)
        else:
            mean_diff = 0
            max_diff = 0

        result = {
            'total_samples': total_samples,
            'matches': matches,
            'match_rate': match_rate,
            'mean_difference': mean_diff,
            'max_difference': max_diff,
            'errors': errors,
            'success': match_rate > 0.8 and mean_diff < 0.05,
        }

        print(f" River验证完成:")
        print('='*60)
        print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
        print(f" 平均差异: {mean_diff:.6f}")
        print(f" 最大差异: {max_diff:.6f}")
        return result

    except Exception as e:
        print(f" River验证失败: {e}")
        return {'error': str(e), 'success': False}
|
|
||||||
|
|
||||||
def print_sample_record(self, i, src_hist, cur_hist, player_cards, board_cards):
    """Print one sample's cards and a side-by-side table of two histograms.

    Args:
        i: Zero-based sample index; printed as ``i + 1``.
        src_hist: Exported histogram (any array-like of numbers).
        cur_hist: Regenerated histogram (any array-like of numbers).
        player_cards: Iterable of hole cards, rendered with ``str``.
        board_cards: Iterable of board cards, rendered with ``str``.

    Only bins where at least one histogram is positive are printed, and at
    most the first 30 bins present in *both* histograms are inspected.
    """
    player_str = " ".join(str(c) for c in player_cards)
    board_str = " ".join(str(c) for c in board_cards)
    print("="*60)
    print(f"样本 {i+1}: [{player_str}] + [{board_str}]")
    print("bin src src_norm cur cur_norm")

    src_hist = np.array(src_hist)
    cur_hist = np.array(cur_hist)
    # Normalise to a distribution; an all-zero histogram is left untouched.
    src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
    cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist

    # Bound by both lengths so a shorter regenerated histogram cannot raise
    # IndexError, and use a dedicated name so the sample index is not shadowed.
    shown_bins = min(len(src_hist), len(cur_hist), 30)
    for bin_idx in range(shown_bins):
        if src_hist[bin_idx] > 0 or cur_hist[bin_idx] > 0:
            print(f"bin[{bin_idx}], {src_hist[bin_idx]:8.3f}, {src_hist_norm[bin_idx]:8.3f}, {cur_hist[bin_idx]:8.3f}, {cur_hist_norm[bin_idx]:8.3f}")
|
|
||||||
|
|
||||||
def validate_turn_samples(self, max_samples: int = 10):
    """Regenerate turn histograms for sampled records and compare by EMD.

    Args:
        max_samples: Upper bound forwarded to the parser as ``max_records``.

    Returns:
        A dict with ``total_samples``, ``low_emd_count``, ``low_emd_rate``,
        ``mean_emd_distance``, ``emd_distances``, ``errors`` and ``success``;
        or ``{'error': ..., 'success': False}`` when nothing was parsed or an
        unexpected exception occurred.
    """
    try:
        separator = '='*60
        print(" 解析导出的Turn数据...")
        print(separator)
        print(separator)
        turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
        if not turn_records:
            return {'error': '没有解析到Turn记录', 'success': False}

        print(f" 解析到 {len(turn_records)} 个Turn样本")

        low_emd_count, errors = 0, 0
        emd_distances = []

        for idx, rec in enumerate(turn_records):
            try:
                hole_cards = rec.player_cards
                community = rec.board_cards
                exported = np.array(rec.bins)
                regenerated = np.array(
                    self.generator.generate_turn_histogram(
                        hole_cards, community, num_bins=len(exported)
                    )
                )

                # Normalise both histograms; all-zero input is passed through.
                exported_norm = exported / exported.sum() if exported.sum() > 0 else exported
                regenerated_norm = regenerated / regenerated.sum() if regenerated.sum() > 0 else regenerated

                # Earth mover's distance over bin positions, weighted by mass.
                emd_dist = wasserstein_distance(
                    range(len(exported_norm)),
                    range(len(regenerated_norm)),
                    exported_norm,
                    regenerated_norm,
                )

                # A sample at or above the 0.2 threshold is also counted in `errors`.
                if emd_dist < 0.2:
                    low_emd_count += 1
                if emd_dist >= 0.2:
                    errors += 1
                emd_distances.append(emd_dist)

                self.print_sample_record(idx, exported, regenerated, hole_cards, community)

                # Plot both raw histograms for visual comparison.
                plt.plot(exported, label='src', marker='o')
                plt.plot(regenerated, label='cur', marker='x')
                plt.title(f"turn_hist_emd={emd_dist:.6f}")
                plt.xlabel("Bins")
                plt.ylabel("Frequency")
                plt.legend()
                plt.show()

            except Exception as e:
                errors += 1
                if errors <= 3:
                    print(f" 样本 {idx+1} 计算失败: {e}")

        # Aggregate statistics.
        total_samples = len(turn_records)
        low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
        mean_emd = np.mean(emd_distances) if emd_distances else float('inf')

        result = {
            'total_samples': total_samples,
            'low_emd_count': low_emd_count,
            'low_emd_rate': low_emd_rate,
            'mean_emd_distance': mean_emd,
            'emd_distances': emd_distances,
            'errors': errors,
            'success': low_emd_rate > 0.6 and mean_emd < 0.5,
        }

        print(f" Turn验证完成:")
        print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
        print(f" 平均EMD: {mean_emd:.6f}")
        return result

    except Exception as e:
        print(f" Turn验证失败: {e}")
        return {'error': str(e), 'success': False}
|
|
||||||
|
|
||||||
def validate_flop_samples(self, max_samples: int = 5):
    """Regenerate flop histograms for sampled records and compare by EMD.

    Args:
        max_samples: Upper bound forwarded to the parser as ``max_records``.

    Returns:
        A dict with ``total_samples``, ``low_emd_count``, ``low_emd_rate``,
        ``mean_emd_distance``, ``emd_distances``, ``errors`` and ``success``;
        or ``{'error': ..., 'success': False}`` when nothing was parsed or an
        unexpected exception occurred.
    """
    try:
        print(" 解析导出Flop数据...")
        print('='*60)
        print('='*60)
        print('='*60)
        flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)

        if not flop_records:
            return {'error': '没有解析到Flop记录', 'success': False}

        print(f" 解析到 {len(flop_records)} 个Flop样本")

        low_emd_count = 0
        emd_distances = []
        errors = 0

        for i, record in enumerate(flop_records):
            try:
                print(f" 处理样本 {i+1}/{len(flop_records)}...")

                player_cards = record.player_cards
                board_cards = record.board_cards
                src_hist = np.array(record.bins)

                cur_hist = self.generator.generate_flop_histogram(
                    player_cards, board_cards, num_bins=len(src_hist)
                )
                cur_hist = np.array(cur_hist)

                # Normalise both histograms; all-zero input is passed through.
                src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
                cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist

                # Earth mover's distance over bin positions, weighted by mass.
                emd_dist = wasserstein_distance(
                    range(len(src_hist_norm)),
                    range(len(cur_hist_norm)),
                    src_hist_norm,
                    cur_hist_norm
                )
                emd_distances.append(emd_dist)

                # A sample at or above the threshold is also counted in `errors`.
                low_emd_count += 1 if emd_dist < 10 else 0
                errors += 1 if emd_dist >= 10 else 0

                # The shared helper prints the sample header and the bin table.
                # The old inline copy of that table reused `i` as its loop
                # variable, so the helper and the failure message below saw a
                # bin index instead of the sample index.
                self.print_sample_record(i, src_hist, cur_hist, player_cards, board_cards)
                print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
                print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
                print(f" 归一化后EMD距离: {emd_dist:.6f}")

                # Plot both raw histograms for visual comparison.
                plt.plot(src_hist, label='src', marker='o')
                plt.plot(cur_hist, label='cur', marker='x')
                plt.title(f"flop_hist_emd={emd_dist:.6f}")
                plt.xlabel("Bins")
                plt.ylabel("Frequency")
                plt.legend()
                plt.show()

            except Exception as e:
                errors += 1
                # Only the first three failures are reported.
                if errors <= 3:
                    print(f" 样本 {i+1} 计算失败: {e}")

        # Aggregate statistics.
        total_samples = len(flop_records)
        low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
        mean_emd = np.mean(emd_distances) if emd_distances else float('inf')

        result = {
            'total_samples': total_samples,
            'low_emd_count': low_emd_count,
            'low_emd_rate': low_emd_rate,
            'mean_emd_distance': mean_emd,
            'emd_distances': emd_distances,
            'errors': errors,
            'success': low_emd_rate > 0.4 and mean_emd < 0.5
        }

        print(f" Flop验证完成:")
        print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
        print(f" 平均EMD: {mean_emd:.6f}")

        return result

    except Exception as e:
        print(f" Flop验证失败: {e}")
        return {'error': str(e), 'success': False}
|
|
||||||
|
|
||||||
def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10, flop_samples: int = 5) -> Dict:
    """Run the river, turn and flop validations and print a per-stage report.

    Args:
        river_samples: Sample cap passed to ``validate_river_samples``.
        turn_samples: Sample cap passed to ``validate_turn_samples``.
        flop_samples: Sample cap passed to ``validate_flop_samples``.

    Returns:
        A dict with ``results`` (the three per-stage dicts keyed by
        ``'river'``/``'turn'``/``'flop'``), ``passed_stages``,
        ``total_stages`` and ``overall_success`` (pass ratio >= 0.7).
    """
    print(" 导出数据EHS验证")
    print("*"*60)

    # Run every stage; dict displays evaluate in order, so river runs first.
    results = {
        'river': self.validate_river_samples(river_samples),
        'turn': self.validate_turn_samples(turn_samples),
        'flop': self.validate_flop_samples(flop_samples),
    }

    print(f"\n{'='*60}")
    print(" 验证完毕")
    print(f"{'='*60}")

    # One report section per stage; only the metric lines differ between
    # the river stage (match based) and the histogram stages (EMD based).
    for stage, stage_result in results.items():
        print(f"\n {stage.upper()}阶段:")
        if 'error' in stage_result:
            print(f" 错误: {stage_result['error']}")
            continue

        status = " 通过" if stage_result['success'] else " 失败"
        print(f" 验证结果: {status}")
        print(f" 样本数量: {stage_result['total_samples']}")
        if stage == 'river':
            print(f" 匹配率: {stage_result['match_rate']:.1%}")
            print(f" 平均差异: {stage_result['mean_difference']:.6f}")
        else:
            print(f" 低EMD率: {stage_result['low_emd_rate']:.1%}")
            print(f" 平均EMD: {stage_result['mean_emd_distance']:.6f}")
            print(f" 抽样EMD: {list(stage_result['emd_distances'][:5])}")

    # Count passed stages once (the counter used to be computed, reset to
    # zero and computed a second time).
    total_stages = 3
    passed_stages = sum(
        1 for stage_result in results.values()
        if stage_result and stage_result.get('success', False)
    )
    overall_rate = passed_stages / total_stages

    return {
        'results': results,
        'passed_stages': passed_stages,
        'total_stages': total_stages,
        'overall_success': overall_rate >= 0.7
    }
|
|
||||||
@@ -1,429 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
from typing import List, Dict, Tuple, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from poker.card import Card, ShortDeckRank, Suit
|
|
||||||
|
|
||||||
@dataclass
class RiverEHSRecord:
    """One river EHS entry together with its decoded cards."""
    board_id: int  # raw encoded board identifier
    player_id: int  # raw encoded hole-card identifier
    ehs: float  # exported EHS value for this board/hand pair
    board_cards: List[Card]  # 5 community cards
    player_cards: List[Card]  # 2 hole cards
|
|
||||||
|
|
||||||
@dataclass
class TurnHistRecord:
    """One turn histogram entry together with its decoded cards."""
    board_id: int  # raw encoded board identifier
    player_id: int  # raw encoded hole-card identifier
    bins: np.ndarray  # exported histogram bins for this board/hand pair
    board_cards: List[Card]  # 4 community cards
    player_cards: List[Card]  # 2 hole cards
|
|
||||||
|
|
||||||
@dataclass
class FlopHistRecord:
    """One flop histogram entry together with its decoded cards."""
    board_id: int  # raw encoded board identifier
    player_id: int  # raw encoded hole-card identifier
    bins: np.ndarray  # exported histogram bins for this board/hand pair
    board_cards: List[Card]  # 3 community cards
    player_cards: List[Card]  # 2 hole cards
|
|
||||||
|
|
||||||
class OpenPQLDecoder:
|
|
||||||
"""open-pql 完整解码"""
|
|
||||||
|
|
||||||
def __init__(self):
    """Set up the Card64 bit-layout constants and the lookup tables."""
    # Card64 layout: every suit owns one 16-bit lane of the 64-bit word.
    self.OFFSET_SUIT = 16
    self.OFFSET_S = 0   # S: [15:0]
    self.OFFSET_H = 16  # H: [31:16]
    self.OFFSET_D = 32  # D: [47:32]
    self.OFFSET_C = 48  # C: [63:48]

    # open-pql rank codes -> short-deck ranks (6 through A, 36 cards).
    # Codes 0-3 (deuce through five) are intentionally left unmapped.
    short_deck_ranks = (
        ShortDeckRank.SIX,
        ShortDeckRank.SEVEN,
        ShortDeckRank.EIGHT,
        ShortDeckRank.NINE,
        ShortDeckRank.TEN,
        ShortDeckRank.JACK,
        ShortDeckRank.QUEEN,
        ShortDeckRank.KING,
        ShortDeckRank.ACE,
    )
    self.opql_to_poker_rank = {
        code: rank for code, rank in enumerate(short_deck_ranks, start=4)
    }

    # open-pql suit codes in S, H, D, C order.
    self.opql_to_poker_suit = dict(
        enumerate((Suit.SPADES, Suit.HEARTS, Suit.DIAMONDS, Suit.CLUBS))
    )

    print("========== OpenPQLDecoder (短牌型36张牌) ===============")
|
|
||||||
|
|
||||||
def u64_from_ranksuit(self, rank, suit) -> int:
    """Return the single-bit Card64 mask for ``rank`` within ``suit``'s lane.

    Mirrors ``Card64::u64_from_ranksuit_i8``.
    """
    rank_bit = 1 << rank
    lane_shift = suit * self.OFFSET_SUIT
    return rank_bit << lane_shift
|
|
||||||
|
|
||||||
def decode_card64(self, card64_u64) -> List[Card]:
    """Decode every card set in a Card64 ``u64`` bitmask.

    Args:
        card64_u64: Integer whose four 16-bit lanes each carry one suit's
            rank bits.

    Returns:
        The decoded cards sorted by ``(rank.numeric_value, suit.value)``.

    Raises:
        ValueError: If a set bit refers to a rank code with no short-deck
            mapping. This used to surface as a bare ``KeyError``; the
            message matches the one raised by ``_card_from_u8``.
    """
    cards = []

    # Walk the four suit lanes, 16 bits each.
    for opql_suit in range(4):  # 0,1,2,3
        suit_offset = opql_suit * self.OFFSET_SUIT
        suit_bits = (card64_u64 >> suit_offset) & 0xFFFF

        # Test each of the 13 rank bits in this lane.
        for opql_rank in range(13):  # 0-12
            if not suit_bits & (1 << opql_rank):
                continue
            if opql_rank not in self.opql_to_poker_rank:
                raise ValueError(f"无效的rank: {opql_rank}")
            poker_rank = self.opql_to_poker_rank[opql_rank]
            poker_suit = self.opql_to_poker_suit[opql_suit]
            cards.append(Card(poker_rank, poker_suit))

    cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
    return cards
|
|
||||||
|
|
||||||
def decode_hand_n_2(self, id_u16) -> List[Card]:
    """Decode a ``Hand<2>`` u16 identifier into two sorted cards.

    The low byte encodes the first card and the high byte the second.

    Raises:
        ValueError: If either byte cannot be decoded.
    """
    low_byte = id_u16 & 0xFF          # low 8 bits
    high_byte = (id_u16 >> 8) & 0xFF  # high 8 bits

    try:
        pair = [self._card_from_u8(low_byte), self._card_from_u8(high_byte)]
        return sorted(pair, key=lambda c: (c.rank.numeric_value, c.suit.value))
    except (ValueError, TypeError) as e:
        raise ValueError(f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={low_byte:02x}, card1_u8={high_byte:02x}: {e}")
|
|
||||||
|
|
||||||
def _card_from_u8(self, v) -> Card:
|
|
||||||
"""
|
|
||||||
对应 open-pql Card::from_u8()
|
|
||||||
单张牌decode
|
|
||||||
"""
|
|
||||||
SHIFT_SUIT = 4
|
|
||||||
opql_rank = v & 0b1111 # 低4位
|
|
||||||
opql_suit = v >> SHIFT_SUIT # 高4位
|
|
||||||
|
|
||||||
if opql_rank not in self.opql_to_poker_rank:
|
|
||||||
raise ValueError(f"无效的rank: {opql_rank}")
|
|
||||||
if opql_suit not in self.opql_to_poker_suit:
|
|
||||||
raise ValueError(f"无效的suit: {opql_suit}")
|
|
||||||
|
|
||||||
poker_rank = self.opql_to_poker_rank[opql_rank]
|
|
||||||
poker_suit = self.opql_to_poker_suit[opql_suit]
|
|
||||||
|
|
||||||
return Card(poker_rank, poker_suit)
|
|
||||||
|
|
||||||
def decode_board_id(self, board_id, num_cards) -> List[Card]:
    """Decode a board identifier into ``num_cards`` cards.

    Two-card boards use the ``Hand<2>`` encoding; every other size is
    decoded as a Card64 bitmask and checked against the expected count.

    Raises:
        ValueError: If the Card64 decode yields a different number of cards.
    """
    if num_cards == 2:
        return self.decode_hand_n_2(board_id)

    decoded = self.decode_card64(board_id)
    if len(decoded) == num_cards:
        return decoded
    raise ValueError(f"解码出 {len(decoded)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}")
|
|
||||||
|
|
||||||
def decode_player_id(self, player_id) -> List[Card]:
    """Decode a player's hole-card identifier (two cards, ``Hand<2>`` form)."""
    hole_cards = self.decode_hand_n_2(player_id)
    return hole_cards
|
|
||||||
|
|
||||||
def decode_board_unique_card(self, all_cards) -> List[Card]:
    """Return ``all_cards`` with repeated (rank, suit) pairs removed.

    The first occurrence of each card is kept and input order is preserved.
    """
    unique_cards = []
    for card in all_cards:
        already_kept = any(
            card.rank == kept.rank and card.suit == kept.suit
            for kept in unique_cards
        )
        if not already_kept:
            unique_cards.append(card)
    return unique_cards
|
|
||||||
|
|
||||||
|
|
||||||
class XTaskDataParser:
|
|
||||||
def __init__(self, data_path: str = "ehs_data"):
    """Remember the data directory and create the ID decoder.

    Args:
        data_path: Directory that contains the exported ``.npy`` files.
    """
    self.data_path = data_path
    decoder = OpenPQLDecoder()
    self.decoder = decoder
|
|
||||||
|
|
||||||
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
    """Load river EHS data, sample it at random and decode the cards.

    Args:
        filename: File name inside ``self.data_path``.
        max_records: Maximum number of rows to sample and decode.

    Returns:
        One ``RiverEHSRecord`` per successfully decoded sampled row. Rows
        whose cards overlap are reported on stdout but still returned.

    Raises:
        FileNotFoundError: If the file does not exist (re-raised after
            logging).
        Exception: Any other load/parse failure is logged and re-raised.
    """
    filepath = f"{self.data_path}/{filename}"

    try:
        raw_data = np.load(filepath)
        print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")

        # Sample row indices instead of materialising every row as a list;
        # the rows selected for a given RNG state are unchanged.
        sample_size = min(max_records, len(raw_data))
        sampled_indices = random.sample(range(len(raw_data)), sample_size)

        records = []
        decode_errors = 0

        for i, row_index in enumerate(sampled_indices):
            row = raw_data[row_index]
            try:
                board_id = int(row['board'])
                player_id = int(row['player'])
                ehs = float(row['ehs'])

                # Decode the 5 community cards and the 2 hole cards.
                board_cards = self.decoder.decode_board_id(board_id, 5)
                player_cards = self.decoder.decode_player_id(player_id)

                # Report (but keep) rows whose cards are not all distinct.
                all_cards = board_cards + player_cards
                unique_cards = self.decoder.decode_board_unique_card(all_cards)
                if len(unique_cards) != 7:
                    print(f"记录 {i}: 存在重复牌面")
                    print(f" Board: {[str(c) for c in board_cards]}")
                    print(f" Player: {[str(c) for c in player_cards]}")

                records.append(RiverEHSRecord(
                    board_id=board_id,
                    player_id=player_id,
                    ehs=ehs,
                    board_cards=board_cards,
                    player_cards=player_cards
                ))

            except Exception as e:
                decode_errors += 1
                # Only the first three failures are printed.
                if decode_errors <= 3:
                    print(f"记录 {i} 解码失败: {e}")
                    print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
                continue

            # Progress is measured against the sampled rows, not the file.
            if (i + 1) % 10000 == 0:
                success_rate = len(records) / (i + 1) * 100
                print(f" 已处理 {i+1:,}/{sample_size:,} 条记录,成功率 {success_rate:.1f}%")

        total_records = len(raw_data)
        success_records = len(records)
        # The rate is relative to the rows actually processed. Dividing by
        # the file size under-reported it and crashed on an empty file.
        success_rate = success_records / sample_size * 100 if sample_size else 0.0

        print(f"river_EHS解析完成:")
        print(f" 总记录数: {total_records:,}")
        print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
        print(f" 解码失败: {decode_errors:,}")

        return records

    except FileNotFoundError:
        print(f"文件不存在: {filepath}")
        raise
    except Exception as e:
        print(f"解析river_EHS数据失败: {e}")
        raise
|
|
||||||
|
|
||||||
def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> List[TurnHistRecord]:
    """Load turn histograms, sample them at random and decode the cards.

    Args:
        filename: File name inside ``self.data_path``.
        max_records: Maximum number of rows to sample and decode.

    Returns:
        One ``TurnHistRecord`` per successfully decoded sampled row. Rows
        with overlapping cards are counted as errors but still returned.

    Raises:
        Exception: Any load/parse failure is logged and re-raised.
    """
    filepath = f"{self.data_path}/{filename}"

    try:
        raw_data = np.load(filepath)
        print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")

        parsed = []
        decode_errors = 0
        # Random sample of at most `max_records` rows.
        sample_size = min(max_records, len(raw_data))
        sampled_rows = random.sample(list(raw_data), sample_size)

        for position, row in enumerate(sampled_rows):
            try:
                board_id = int(row['board'])
                player_id = int(row['player'])
                bins = row['bins'].copy()

                # Decode the 4 community cards and the 2 hole cards.
                board_cards = self.decoder.decode_board_id(board_id, 4)
                player_cards = self.decoder.decode_player_id(player_id)

                # All six cards must be distinct.
                distinct = self.decoder.decode_board_unique_card(board_cards + player_cards)
                if len(distinct) != 6:
                    decode_errors += 1
                    print(f"记录 {position}: 存在重复牌面")

                parsed.append(TurnHistRecord(
                    board_id=board_id,
                    player_id=player_id,
                    bins=bins,
                    board_cards=board_cards,
                    player_cards=player_cards,
                ))

            except Exception as e:
                decode_errors += 1
                if decode_errors <= 3:
                    print(f"记录 {position} 解码失败: {e}")
                continue

            done = position + 1
            if done % 100 == 0:
                print(f" 已处理 {done}/{len(sampled_rows)} 条记录...")

        print(f"turn_hist解析完成: 成功 {len(parsed)}/{len(sampled_rows)} 条记录")
        return parsed

    except Exception as e:
        print(f" 解析turn_hist数据失败: {e}")
        raise
|
|
||||||
|
|
||||||
def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> List[FlopHistRecord]:
    """Load flop histograms, sample them at random and decode the cards.

    Args:
        filename: File name inside ``self.data_path``.
        max_records: Maximum number of rows to sample and decode.

    Returns:
        One ``FlopHistRecord`` per successfully decoded sampled row. Rows
        with overlapping cards are counted as errors but still returned.

    Raises:
        Exception: Any load/parse failure is logged and re-raised.
    """
    filepath = f"{self.data_path}/{filename}"

    try:
        raw_data = np.load(filepath)
        print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")

        parsed = []
        decode_errors = 0
        sample_size = min(max_records, len(raw_data))
        sampled_rows = random.sample(list(raw_data), sample_size)

        for position, row in enumerate(sampled_rows):
            try:
                board_id = int(row['board'])
                player_id = int(row['player'])
                bins = row['bins'].copy()

                # Decode the 3 community cards and the 2 hole cards.
                board_cards = self.decoder.decode_board_id(board_id, 3)
                player_cards = self.decoder.decode_player_id(player_id)

                # All five cards must be distinct.
                distinct = self.decoder.decode_board_unique_card(board_cards + player_cards)
                if len(distinct) != 5:
                    decode_errors += 1
                    print(f"记录 {position}: 存在重复牌面")

                parsed.append(FlopHistRecord(
                    board_id=board_id,
                    player_id=player_id,
                    bins=bins,
                    board_cards=board_cards,
                    player_cards=player_cards,
                ))

            except Exception as e:
                decode_errors += 1
                if decode_errors <= 3:
                    print(f"记录 {position} 解码失败: {e}")
                continue

            done = position + 1
            if done % 50 == 0:
                print(f" 已处理 {done}/{len(sampled_rows)} 条记录...")

        print(f"flop_hist解析完成: 成功 {len(parsed)}/{len(sampled_rows)} 条记录")
        return parsed

    except Exception as e:
        print(f"解析flop_hist数据失败: {e}")
        raise
|
|
||||||
|
|
||||||
def analyze_parsed_data():
    """Parse all three exported data sets and print samples and a summary.

    Returns:
        A dict with ``river_records``, ``turn_records`` and ``flop_records``,
        or ``None`` when any step raises.
    """
    rule = "=" * 60

    def join_cards(card_list):
        # Space-separated ``str`` rendering of a list of cards.
        return " ".join(str(c) for c in card_list)

    print("开始解析 xtask 数据并解码牌面...\n")

    try:
        parser = XTaskDataParser(data_path="ehs_data")

        # 1. River EHS data.
        print(rule)
        print(" 解析river_EHS")
        print(rule)
        river_records = parser.parse_river_ehs_with_cards()

        print(f"\nriver_EHS数据样本 (前10条):")
        for n, rec in enumerate(river_records[:10], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            print(f" {n:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{rec.ehs:.4f}")

        # 2. Turn histogram data (capped).
        print("\n" + rule)
        print(" 解析turn_hist数据")
        print(rule)
        turn_records = parser.parse_turn_hist_with_cards(max_records=100)

        print(f"\nturn_hist数据样本 (前5条):")
        for n, rec in enumerate(turn_records[:5], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            bins_stats = f"mean={rec.bins.mean():.3f}, std={rec.bins.std():.3f}"
            print(f" {n}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 3. Flop histogram data (capped).
        print("\n" + rule)
        print(" 解析flop_hist数据")
        print(rule)
        flop_records = parser.parse_flop_hist_with_cards(max_records=50)

        print(f"\n flop_hist数据样本 (前5条):")
        for n, rec in enumerate(flop_records[:5], start=1):
            board_str = join_cards(rec.board_cards)
            player_str = join_cards(rec.player_cards)
            bins_stats = f"mean={rec.bins.mean():.3f}, std={rec.bins.std():.3f}"
            print(f" {n}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")

        # 4. Summary counts.
        print("\n" + rule)
        print("解析统计摘要")
        print(rule)
        print(f"river_EHS记录: {len(river_records):,} 条")
        print(f"turn_hist记录: {len(turn_records):,} 条")
        print(f"flop_hist记录: {len(flop_records):,} 条")

        # EHS statistics over the parsed river records.
        if river_records:
            ehs_values = [rec.ehs for rec in river_records]
            print(f"river_EHS统计:")
            print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]")
            print(f" 均值: {np.mean(ehs_values):.4f}")
            print(f" 标准差: {np.std(ehs_values):.4f}")

        print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")

        return {
            'river_records': river_records,
            'turn_records': turn_records,
            'flop_records': flop_records,
        }

    except Exception as e:
        print(f" 数据解析失败: {e}")
        return None
|
|
||||||
216
cross_validation/validator.py
Normal file
216
cross_validation/validator.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from poker import Suit, Card
|
||||||
|
from shortdeck import ShortDeckHandEvaluator as HE
|
||||||
|
from shortdeck import ShortDeckRank as SDR
|
||||||
|
|
||||||
|
# Directory holding the exported arrays, relative to the current working directory.
data_path = Path(".") / "ehs-data"
# The three arrays are loaded eagerly when this module is imported, so the
# import fails if any of the files is missing.
np_river = np.load(data_path / "river_ehs_sd.npy")
np_turn = np.load(data_path / "turn_hist_sd.npy")
np_flop = np.load(data_path / "flop_hist_sd.npy")

# One Card per (rank, suit) combination, ranks in the outer loop.
cards = [Card(r, s) for r in SDR for s in Suit]

# NOTE(review): presumably the number of bits per card in a packed key; its
# use is not visible in this part of the file. Confirm.
CARD_BITS = 6
|
||||||
|
class SuitMapping:
    """Assign canonical suits to concrete suits in first-seen order."""

    def __init__(self):
        """Start with no assignments and all suits available."""
        self.mapping = {}
        # Stored reversed so that ``pop()`` hands suits out in definition order.
        self.suits = list(Suit)[::-1]

    def map_suit(self, s: Suit) -> Suit:
        """Return the suit assigned to ``s``, assigning the next free one if new."""
        if s in self.mapping:
            return self.mapping[s]

        assigned = self.suits.pop()
        self.mapping[s] = assigned
        return assigned
|
||||||
|
|
||||||
|
class EhsCache:
    """In-memory store of river EHS values, grouped for turn/flop lookups.

    Layout: ``cache[key][turn_idx][river_idx] = ehs`` where ``key`` is the
    packed, suit-canonicalised encoding of the hole cards plus the flop.

    NOTE(review): the key is suit-canonicalised, but ``turn_idx`` and
    ``river_idx`` are computed from the raw, unmapped cards -- confirm that
    mixing the two is intended.
    """

    # Number of values a complete flop histogram must hold.
    # NOTE(review): 465 == C(31, 2), presumably the unordered turn/river
    # pairs left in a 36-card deck after hole cards and flop -- confirm.
    FULL_FLOP_HIST_SIZE = 465

    def __init__(self):
        self.cache = defaultdict(lambda: defaultdict(dict))

    def _set_keys(self, flop, player):
        """Return the packed canonical key for ``player + flop``."""
        suit_map = SuitMapping()
        iso_complex = to_iso(player + flop, suit_map)
        return cards_to_u32(iso_complex)

    # Storing everything makes the computation far too slow... hmm.
    # TODO: do not store; compute directly from samples instead.
    def store_river_ehs(self, player, board, ehs):
        """Record *ehs* for *player* on the five-card *board*.

        ``board[:3]`` is treated as the flop, ``board[3]`` as the turn and
        ``board[4]`` as the river.
        """
        complex_key = self._set_keys(board[:3], player)
        turn_idx = card_index(board[3])
        river_idx = card_index(board[4])
        self.cache[complex_key][turn_idx][river_idx] = ehs

    def get_turn_hist(self, player, flop, turn):
        """Return the stored EHS values for this turn, or ``None`` if absent."""
        complex_key = self._set_keys(flop, player)
        turn_idx = card_index(turn)
        # ``.get`` instead of indexing: indexing a defaultdict on a miss
        # would insert an empty entry and grow the cache on every failed read.
        by_turn = self.cache.get(complex_key)
        if not by_turn:
            return None
        turn_hist = by_turn.get(turn_idx)
        return list(turn_hist.values()) if turn_hist else None

    def get_flop_hist(self, player, flop):
        """Return all stored EHS values for this flop.

        Returns ``None`` unless exactly ``FULL_FLOP_HIST_SIZE`` values are
        present, i.e. partial histograms are never returned.
        """
        complex_key = self._set_keys(flop, player)
        # Read without inserting on a miss (see get_turn_hist).
        player_data = self.cache.get(complex_key)
        if not player_data:
            return None
        all_ehs = [
            ehs
            for by_river in player_data.values()
            for ehs in by_river.values()
        ]
        return all_ehs if len(all_ehs) == self.FULL_FLOP_HIST_SIZE else None
|
||||||
|
|
||||||
|
def get_rank_idx(rank: SDR) -> int:
    """Return the zero-based position of *rank*, counting from SIX up to ACE.

    Raises:
        ValueError: if *rank* is not one of the nine listed ranks.
    """
    low_to_high = [
        SDR.SIX,
        SDR.SEVEN,
        SDR.EIGHT,
        SDR.NINE,
        SDR.TEN,
        SDR.JACK,
        SDR.QUEEN,
        SDR.KING,
        SDR.ACE,
    ]
    position = low_to_high.index(rank)
    return position
|
||||||
|
|
||||||
|
def get_suit_idx(suit: Suit) -> int:
    """Return the position of *suit* in spades, hearts, diamonds, clubs order.

    Raises:
        ValueError: if *suit* is not one of those four suits.
    """
    ordering = [
        Suit.SPADES,
        Suit.HEARTS,
        Suit.DIAMONDS,
        Suit.CLUBS,
    ]
    position = ordering.index(suit)
    return position
|
||||||
|
|
||||||
|
def card_index(card: Card) -> int:
    """Return the integer code of *card*: ``(rank_idx + 4) * 4 + suit_idx``.

    NOTE(review): the ``+ 4`` rank offset presumably lines the codes up with
    a 52-card numbering in which the deuce is rank 0 -- confirm.
    """
    rank_slot = get_rank_idx(card.rank) + 4
    suit_slot = get_suit_idx(card.suit)
    return rank_slot * 4 + suit_slot
|
||||||
|
|
||||||
|
# Monkey-patch Card so two cards compare equal when rank and suit match; the
# `in` checks and set operations in this module rely on this.
# NOTE(review): comparing against an object without .rank/.suit raises
# AttributeError here instead of returning NotImplemented.
Card.__eq__ = lambda a, b: (a.rank == b.rank) and (a.suit == b.suit)
|
||||||
|
# Hash on the same (rank, suit) pair so the hash stays consistent with __eq__.
Card.__hash__ = lambda a: hash((get_rank_idx(a.rank), get_suit_idx(a.suit)))
|
||||||
|
|
||||||
|
def cards_to_u32(cards: list[Card]) -> int:
    """Pack *cards* into one integer, ``CARD_BITS`` bits per card.

    The first card occupies the lowest bits; each card code is masked to six
    bits before being placed in its slot. An empty list packs to 0.
    """
    packed = 0
    shift = 0
    for card in cards:
        packed |= (card_index(card) & 0x3F) << shift
        shift += CARD_BITS
    return packed
|
||||||
|
|
||||||
|
def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]:
    """Relabel the suits of *cards* through *mapping* and return them sorted.

    Suits are presented to *mapping* in the order obtained by sorting the
    cards by (number of cards sharing the suit, rank index, suit index). The
    relabelled cards are returned sorted by (rank index, suit index).
    *mapping* is mutated, so later calls that reuse it keep the same suit
    assignments.
    """

    def suit_frequency(card: Card) -> int:
        # How many of the input cards share this card's suit.
        return len([other for other in cards if other.suit == card.suit])

    def by_rank_then_suit(card: Card) -> tuple[int, int]:
        return get_rank_idx(card.rank), get_suit_idx(card.suit)

    assignment_order = sorted(
        cards,
        key=lambda c: (suit_frequency(c), *by_rank_then_suit(c)),
    )
    relabelled = [
        Card(c.rank, mapping.map_suit(c.suit)) for c in assignment_order
    ]
    relabelled.sort(key=by_rank_then_suit)
    return relabelled
|
||||||
|
|
||||||
|
def cards_to_u16(cards: list[Card]) -> int:
    """Pack *cards* into one integer, ``CARD_BITS`` bits per card.

    Same layout as ``cards_to_u32``: first card in the lowest bits, each code
    masked to six bits. No width check is performed on the result.
    """
    fields = [card_index(card) & 0x3F for card in cards]
    packed = 0
    for slot, field in enumerate(fields):
        packed |= field << (slot * CARD_BITS)
    return packed
|
||||||
|
|
||||||
|
def calc_river_ehs(board: list[Card], player: list[Card]) -> float:
    """Return the river hand strength of *player* on *board*.

    Every two-card opponent holding that shares no card with the board or
    the player's hole cards is evaluated on the same board. A win scores 2,
    a tie scores 1, a loss scores 0; the result is the score divided by the
    maximum possible score, so it lies in [0, 1].

    Raises:
        ZeroDivisionError: if no opponent holding is available.
    """
    player_hand = [*board, *player]
    player_ranking = HE.evaluate_hand(player_hand)

    # Cards an opponent cannot hold. The board is already part of
    # player_hand, so one set covers both exclusions; it is built once
    # instead of on every iteration.
    dead_cards = set(player_hand)

    acc = 0
    total = 0  # named ``total`` so the builtin ``sum`` is not shadowed
    for other in itertools.combinations(cards, 2):
        if not dead_cards.isdisjoint(other):
            continue
        other_ranking = HE.evaluate_hand([*board, *other])
        if player_ranking == other_ranking:
            acc += 1
        elif player_ranking > other_ranking:
            acc += 2
        total += 2
    return acc / total
|
||||||
|
|
||||||
|
def get_data(board: list[Card], player: list[Card]):
    """Look up the stored reference value for *player* on *board*.

    The table is chosen by board size (3: flop, 4: turn, 5: river). Board
    and hole cards are suit-canonicalised with one shared ``SuitMapping``
    (board first), packed, and matched against the table's "board" and
    "player" fields. Field index 2 of the first matching row is returned.

    Raises:
        NotImplementedError: if the board does not hold 3, 4 or 5 cards.
        IndexError: if no row matches.
    """
    tables = {3: np_flop, 4: np_turn, 5: np_river}
    table = tables.get(len(board))
    if table is None:
        raise NotImplementedError

    suit_map = SuitMapping()
    board_key = cards_to_u32(to_iso(board, suit_map))
    player_key = cards_to_u16(to_iso(player, suit_map))

    hits = table[(table["board"] == board_key) & (table["player"] == player_key)]
    return hits[0][2]
|
||||||
|
|
||||||
|
def euclidean_dist(left, right):
    """Return a distance between two scalars or two collections of values.

    Scalars (anything where *left* is not iterable): the squared absolute
    difference -- note this is the square, not the plain distance.
    Iterables: both sides are converted to float32 and sorted, then the norm
    of their difference is returned, so element order does not matter.
    """
    if not isinstance(left, Iterable):
        return np.abs(left - right) ** 2

    ordered_left, ordered_right = (
        np.sort(np.array(values, dtype=np.float32)) for values in (left, right)
    )
    return np.linalg.norm(ordered_right - ordered_left)
|
||||||
|
|
||||||
|
def compare_data(sampled, board, player):
    """Compare *sampled* with the stored reference for this board and hand.

    Returns 0 when the distance to the stored value is close to zero.
    Otherwise prints the board, the hole cards and the distance, and
    returns 1.
    """
    d = euclidean_dist(get_data(board, player), sampled)
    if np.isclose(d, 0.0):
        return 0
    print(f"[{''.join(map(str, board))} {''.join(map(str, player))}]: {d}")
    return 1
|
||||||
|
|
||||||
|
# Module-level mapping of dicts.
# NOTE(review): nothing in the visible part of this file reads or writes it;
# confirm whether it is still needed.
card_ehs = defaultdict(dict)
|
||||||
|
|
||||||
|
def validate_river():
    """Recompute river EHS for every board/hand and compare with stored data.

    Enumerates every five-card board and every two-card hand from the
    remaining cards. Each computed EHS is stored in ``ehs_stored`` for the
    later turn/flop checks and compared against the reference table.
    Prints a progress dot per hand and a final summary.

    NOTE(review): the loop nesting of the progress and summary prints was
    reconstructed from a view that had lost its indentation -- confirm.
    """
    validated_count = 0
    error_count = 0
    for river_combo in itertools.combinations(cards, 5):
        board = list(river_combo)
        unused_cards = [c for c in cards if c not in board]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            ehs = calc_river_ehs(board, player)
            ehs_stored.store_river_ehs(player, board, ehs)
            # Accumulate mismatches; plain assignment here kept only the
            # result of the last comparison.
            error_count += compare_data(ehs, board, player)
            validated_count += 1
            print(".", end="", flush=True)
    print(f"river validate count {validated_count}")
    print(f"Validated river hands: {validated_count}, Errors: {error_count}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_turn():
    """Compare cached turn histograms with the stored turn table.

    Enumerates every four-card board (first three cards as flop, fourth as
    turn) and every two-card hand from the remaining cards. Histograms come
    from ``ehs_stored``, so ``validate_river`` must have filled it first.
    Hands with no cached histogram are skipped and reported separately,
    since passing ``None`` on to the comparison would raise.

    NOTE(review): the loop nesting of the progress print was reconstructed
    from a view that had lost its indentation -- confirm.
    """
    validated_count = 0
    error_count = 0
    skipped_count = 0
    for turn_combo in itertools.combinations(cards, 4):
        turn = list(turn_combo)
        unused_cards = [c for c in cards if c not in turn]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3])
            if turn_hist is None:
                # Same guard validate_flop applies to a missing histogram.
                skipped_count += 1
                continue
            error_count += compare_data(turn_hist, turn, player)
            validated_count += 1
            print(".", end="", flush=True)
    if skipped_count:
        print(f"Skipped turn hands without cached histogram: {skipped_count}")
    print(f"Validated turn hands: {validated_count}, Errors: {error_count}")
|
||||||
|
|
||||||
|
def validate_flop():
    """Spot-check one randomly drawn flop against the stored flop table.

    Draws five distinct cards: the first two are the hole cards, the other
    three the flop. Does nothing when ``ehs_stored`` has no complete
    histogram for that combination.
    """
    first, second, *flop = random.sample(cards, 5)
    player = [first, second]

    flop_hist = ehs_stored.get_flop_hist(player, flop)
    if flop_hist is not None:
        compare_data(flop_hist, flop, player)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared cache instance: validate_river() fills it, and validate_turn() /
# validate_flop() read their histograms back from it.
ehs_stored = EhsCache()
|
||||||
|
|
||||||
|
def cross_validate_main():
    """Run the river, turn and flop validations, in that order.

    The order matters: the river pass fills the cache that the turn and
    flop passes read from.
    """
    for stage in (validate_river, validate_turn, validate_flop):
        stage()
|
||||||
|
|
||||||
|
# Run the full cross-validation when this module is executed as a script.
if __name__ == "__main__":
|
||||||
|
cross_validate_main()
|
||||||
@@ -1,32 +1,8 @@
|
|||||||
from cross_validation import DataValidator
|
from cross_validation import cross_validate_main
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
return cross_validate_main()
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='从xtask导出数据中抽取牌面进行EHS验证',
|
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument('--river-samples', type=int, default=10, help='River样本数 (默认: 10)')
|
|
||||||
parser.add_argument('--turn-samples', type=int, default=5, help='Turn样本数 (默认: 5)')
|
|
||||||
parser.add_argument('--flop-samples', type=int, default=3, help='Flop样本数 (默认: 3)')
|
|
||||||
parser.add_argument('--data-path', type=str, default='ehs_data', help='数据路径 (默认: ehs_data)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
validator = DataValidator(data_path=args.data_path)
|
|
||||||
results = validator.run_full_validation(
|
|
||||||
river_samples=args.river_samples,
|
|
||||||
turn_samples=args.turn_samples,
|
|
||||||
flop_samples=args.flop_samples
|
|
||||||
)
|
|
||||||
|
|
||||||
return 0 if results['overall_success'] else 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
Reference in New Issue
Block a user