task5
This commit is contained in:
8
cross_validation/__init__.py
Normal file
8
cross_validation/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Cross Validation module for EHS winrate data validation
|
||||
"""
|
||||
|
||||
from .cross_validation import DataValidator
|
||||
from .parse_data import XTaskDataParser
|
||||
|
||||
__all__ = ['DataValidator', 'XTaskDataParser']
|
||||
357
cross_validation/cross_validation.py
Normal file
357
cross_validation/cross_validation.py
Normal file
@@ -0,0 +1,357 @@
|
||||
#!/usr/bin/env python3
|
||||
import numpy as np
|
||||
from typing import List, Dict, Tuple
|
||||
from scipy.stats import wasserstein_distance
|
||||
import sys
|
||||
|
||||
from .parse_data import XTaskDataParser
|
||||
from shortdeck.gen_hist import ShortDeckHistGenerator
|
||||
|
||||
|
||||
class DataValidator:
|
||||
|
||||
|
||||
def __init__(self, data_path: str = "ehs_data"):
|
||||
self.parser = XTaskDataParser(data_path)
|
||||
self.generator = ShortDeckHistGenerator()
|
||||
print(" DataValidator初始化完成")
|
||||
print(f" 生成器短牌型大小: {len(self.generator.full_deck)}")
|
||||
|
||||
def validate_river_samples(self, max_samples: int = 20) :
|
||||
# print(f"\n 验证river_EHS样本 (最大样本数: {max_samples})")
|
||||
|
||||
try:
|
||||
print(" 解析导出的river数据...")
|
||||
print('='*60)
|
||||
river_records = self.parser.parse_river_ehs_with_cards()
|
||||
|
||||
if not river_records:
|
||||
return {'error': '没有解析到river记录', 'success': False}
|
||||
|
||||
|
||||
sample_records = np.random.choice(river_records, size=max_samples, replace=False)
|
||||
print(f" 选择 {len(sample_records)} 个样本进行验证")
|
||||
|
||||
matches = 0
|
||||
errors = 0
|
||||
differences = []
|
||||
|
||||
for i, record in enumerate(sample_records):
|
||||
try:
|
||||
player_cards = record.player_cards
|
||||
board_cards = record.board_cards
|
||||
src_river_ehs = record.ehs
|
||||
|
||||
|
||||
cur_river_ehs = self.generator.generate_river_ehs(player_cards, board_cards)
|
||||
|
||||
|
||||
ehs_difference = abs(src_river_ehs - cur_river_ehs)
|
||||
|
||||
|
||||
player_str = " ".join(str(c) for c in player_cards)
|
||||
board_str = " ".join(str(c) for c in board_cards)
|
||||
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
|
||||
print(f" 原始EHS: {src_river_ehs:.6f}")
|
||||
print(f" 重算EHS: {cur_river_ehs:.6f}")
|
||||
print(f" 差异: {ehs_difference:.6f}")
|
||||
|
||||
# 判断匹配 (允许小的数值差异)
|
||||
tolerance = 1e-6
|
||||
if ehs_difference < tolerance:
|
||||
matches += 1
|
||||
else:
|
||||
differences.append(ehs_difference)
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors <= 3:
|
||||
print(f" 样本 {i+1} 计算失败: {e}")
|
||||
|
||||
# 统计结果
|
||||
total_samples = len(sample_records)
|
||||
match_rate = matches / total_samples if total_samples > 0 else 0
|
||||
mean_diff = np.mean(differences) if differences else 0
|
||||
max_diff = np.max(differences) if differences else 0
|
||||
|
||||
result = {
|
||||
'total_samples': total_samples,
|
||||
'matches': matches,
|
||||
'match_rate': match_rate,
|
||||
'mean_difference': mean_diff,
|
||||
'max_difference': max_diff,
|
||||
'errors': errors,
|
||||
'success': match_rate > 0.8 and mean_diff < 0.05
|
||||
}
|
||||
|
||||
print(f" River验证完成:")
|
||||
print('='*60)
|
||||
print(f" 匹配数: {matches}/{total_samples} ({match_rate:.1%})")
|
||||
print(f" 平均差异: {mean_diff:.6f}")
|
||||
print(f" 最大差异: {max_diff:.6f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f" River验证失败: {e}")
|
||||
return {'error': str(e), 'success': False}
|
||||
|
||||
def validate_turn_samples(self, max_samples: int = 10):
|
||||
|
||||
# print(f"\n 验证turn_HIST样本 (最大样本数: {max_samples})")
|
||||
|
||||
try:
|
||||
|
||||
print(" 解析导出的Turn数据...")
|
||||
print('='*60)
|
||||
print('='*60)
|
||||
turn_records = self.parser.parse_turn_hist_with_cards(max_records=max_samples)
|
||||
|
||||
if not turn_records:
|
||||
return {'error': '没有解析到Turn记录', 'success': False}
|
||||
|
||||
print(f" 解析到 {len(turn_records)} 个Turn样本")
|
||||
|
||||
|
||||
low_emd_count = 0
|
||||
emd_distances = []
|
||||
errors = 0
|
||||
|
||||
for i, record in enumerate(turn_records):
|
||||
try:
|
||||
|
||||
player_cards = record.player_cards
|
||||
board_cards = record.board_cards
|
||||
src_hist = np.array(record.bins)
|
||||
|
||||
cur_hist = self.generator.generate_turn_histogram(
|
||||
player_cards, board_cards, num_bins=len(src_hist)
|
||||
)
|
||||
cur_hist = np.array(cur_hist)
|
||||
|
||||
# 归一化
|
||||
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
||||
cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist
|
||||
|
||||
# 计算EMD距离
|
||||
emd_dist = wasserstein_distance(
|
||||
range(len(src_hist_norm)),
|
||||
range(len(cur_hist_norm)),
|
||||
src_hist_norm,
|
||||
cur_hist_norm
|
||||
)
|
||||
|
||||
is_low_emd = emd_dist < 0.2
|
||||
if is_low_emd:
|
||||
low_emd_count += 1
|
||||
|
||||
# 显示详细信息(前3个样本)
|
||||
if i < 3:
|
||||
player_str = " ".join(str(c) for c in player_cards)
|
||||
board_str = " ".join(str(c) for c in board_cards)
|
||||
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
|
||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors <= 3:
|
||||
print(f" 样本 {i+1} 计算失败: {e}")
|
||||
|
||||
# 统计结果
|
||||
total_samples = len(turn_records)
|
||||
low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
|
||||
mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
|
||||
|
||||
result = {
|
||||
'total_samples': total_samples,
|
||||
'low_emd_count': low_emd_count,
|
||||
'low_emd_rate': low_emd_rate,
|
||||
'mean_emd_distance': mean_emd,
|
||||
'emd_distances': emd_distances,
|
||||
'errors': errors,
|
||||
'success': low_emd_rate > 0.6 and mean_emd < 0.5
|
||||
}
|
||||
|
||||
print(f" Turn验证完成:")
|
||||
print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
|
||||
print(f" 平均EMD: {mean_emd:.6f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f" Turn验证失败: {e}")
|
||||
return {'error': str(e), 'success': False}
|
||||
|
||||
def validate_flop_samples(self, max_samples: int = 5):
|
||||
# print(f"\n 验证Flop直方图样本 (最大样本数: {max_samples})")
|
||||
|
||||
try:
|
||||
|
||||
print(" 解析导出Flop数据...")
|
||||
print('='*60)
|
||||
print('='*60)
|
||||
print('='*60)
|
||||
flop_records = self.parser.parse_flop_hist_with_cards(max_records=max_samples)
|
||||
|
||||
if not flop_records:
|
||||
return {'error': '没有解析到Flop记录', 'success': False}
|
||||
|
||||
print(f" 解析到 {len(flop_records)} 个Flop样本")
|
||||
|
||||
low_emd_count = 0
|
||||
emd_distances = []
|
||||
errors = 0
|
||||
|
||||
for i, record in enumerate(flop_records):
|
||||
try:
|
||||
print(f" 处理样本 {i+1}/{len(flop_records)}...")
|
||||
|
||||
player_cards = record.player_cards
|
||||
board_cards = record.board_cards
|
||||
src_hist = np.array(record.bins)
|
||||
|
||||
cur_hist = self.generator.generate_flop_histogram(
|
||||
player_cards, board_cards, num_bins=len(src_hist)
|
||||
)
|
||||
cur_hist = np.array(cur_hist)
|
||||
|
||||
|
||||
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
||||
cur_hist_norm = cur_hist / cur_hist.sum() if cur_hist.sum() > 0 else cur_hist
|
||||
|
||||
# 计算EMD距离
|
||||
emd_dist = wasserstein_distance(
|
||||
range(len(src_hist_norm)),
|
||||
range(len(cur_hist_norm)),
|
||||
src_hist_norm,
|
||||
cur_hist_norm
|
||||
)
|
||||
# emd_distances.append(emd_dist)
|
||||
|
||||
# is_low_emd = emd_dist < 0.2 # EMD阈值
|
||||
# if is_low_emd:
|
||||
# low_emd_count += 1
|
||||
|
||||
# 显示详细信息
|
||||
player_str = " ".join(str(c) for c in player_cards)
|
||||
board_str = " ".join(str(c) for c in board_cards)
|
||||
print(f" 样本 {i+1}: [{player_str}] + [{board_str}]")
|
||||
print(f" 原始直方图: bins={len(src_hist)}, sum={src_hist.sum():.3f}, 非零bins={np.count_nonzero(src_hist)}")
|
||||
print(f" 生成直方图: bins={len(cur_hist)}, sum={cur_hist.sum():.3f}, 非零bins={np.count_nonzero(cur_hist)}")
|
||||
print(f" 归一化后EMD距离: {emd_dist:.6f}")
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors <= 3:
|
||||
print(f" 样本 {i+1} 计算失败: {e}")
|
||||
|
||||
# 统计结果
|
||||
total_samples = len(flop_records)
|
||||
low_emd_rate = low_emd_count / total_samples if total_samples > 0 else 0
|
||||
mean_emd = np.mean(emd_distances) if emd_distances else float('inf')
|
||||
|
||||
result = {
|
||||
'total_samples': total_samples,
|
||||
'low_emd_count': low_emd_count,
|
||||
'low_emd_rate': low_emd_rate,
|
||||
'mean_emd_distance': mean_emd,
|
||||
'emd_distances': emd_distances,
|
||||
'errors': errors,
|
||||
'success': low_emd_rate > 0.4 and mean_emd < 0.5
|
||||
}
|
||||
|
||||
print(f" Flop验证完成:")
|
||||
print(f" 低EMD数: {low_emd_count}/{total_samples} ({low_emd_rate:.1%})")
|
||||
print(f" 平均EMD: {mean_emd:.6f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f" Flop验证失败: {e}")
|
||||
return {'error': str(e), 'success': False}
|
||||
|
||||
def run_full_validation(self, river_samples: int = 20, turn_samples: int = 10, flop_samples: int = 5) -> Dict:
|
||||
|
||||
print(" 导出数据EHS验证")
|
||||
print("*"*60)
|
||||
print("验证策略: 从xtask导出数据中抽取牌面 → 短牌型生成器重计算 → 比较一致性")
|
||||
|
||||
# 执行各阶段验证
|
||||
results = {}
|
||||
results['river'] = self.validate_river_samples(river_samples)
|
||||
results['turn'] = self.validate_turn_samples(turn_samples)
|
||||
results['flop'] = self.validate_flop_samples(flop_samples)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(" 验证完毕")
|
||||
print(f"{'='*60}")
|
||||
|
||||
passed_stages = 0
|
||||
total_stages = 3
|
||||
|
||||
# River结果
|
||||
print(f"\n RIVER阶段:")
|
||||
if 'error' not in results['river']:
|
||||
status = " 通过" if results['river']['success'] else " 失败"
|
||||
print(f" 验证结果: {status}")
|
||||
print(f" 样本数量: {results['river']['total_samples']}")
|
||||
print(f" 匹配率: {results['river']['match_rate']:.1%}")
|
||||
print(f" 平均差异: {results['river']['mean_difference']:.6f}")
|
||||
if results['river']['success']:
|
||||
passed_stages += 1
|
||||
else:
|
||||
print(f" 错误: {results['river']['error']}")
|
||||
|
||||
# Turn结果
|
||||
print(f"\n TURN阶段:")
|
||||
if 'error' not in results['turn']:
|
||||
status = " 通过" if results['turn']['success'] else " 失败"
|
||||
print(f" 验证结果: {status}")
|
||||
print(f" 样本数量: {results['turn']['total_samples']}")
|
||||
print(f" 低EMD率: {results['turn']['low_emd_rate']:.1%}")
|
||||
print(f" 平均EMD: {results['turn']['mean_emd_distance']:.6f}")
|
||||
if results['turn']['success']:
|
||||
passed_stages += 1
|
||||
else:
|
||||
print(f" 错误: {results['turn']['error']}")
|
||||
|
||||
# Flop结果
|
||||
print(f"\n FLOP阶段:")
|
||||
if 'error' not in results['flop']:
|
||||
status = " 通过" if results['flop']['success'] else " 失败"
|
||||
print(f" 验证结果: {status}")
|
||||
print(f" 样本数量: {results['flop']['total_samples']}")
|
||||
print(f" 低EMD率: {results['flop']['low_emd_rate']:.1%}")
|
||||
print(f" 平均EMD: {results['flop']['mean_emd_distance']:.6f}")
|
||||
if results['flop']['success']:
|
||||
passed_stages += 1
|
||||
else:
|
||||
print(f" 错误: {results['flop']['error']}")
|
||||
|
||||
# 总体结果
|
||||
passed_stages = 0
|
||||
total_stages = 3
|
||||
|
||||
if results.get('river') and results['river'].get('success', False):
|
||||
passed_stages += 1
|
||||
if results.get('turn') and results['turn'].get('success', False):
|
||||
passed_stages += 1
|
||||
if results.get('flop') and results['flop'].get('success', False):
|
||||
passed_stages += 1
|
||||
|
||||
overall_rate = passed_stages / total_stages
|
||||
# print(f"\n 总体验证通过率: {passed_stages}/{total_stages} ({overall_rate:.1%})")
|
||||
|
||||
# if overall_rate >= 0.7:
|
||||
# print(" 数据验证成功!短牌型生成器与解析数据基本一致。")
|
||||
# else:
|
||||
# print(" 验证存在问题,生成器可能与实际数据不匹配,需要调试。")
|
||||
|
||||
return {
|
||||
'results': results,
|
||||
'passed_stages': passed_stages,
|
||||
'total_stages': total_stages,
|
||||
'overall_success': overall_rate >= 0.7
|
||||
}
|
||||
52
cross_validation/debug_emd.py
Normal file
52
cross_validation/debug_emd.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
调试
|
||||
turn/flop阶段EMD在导出的数据与生成的数据间的差异
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from cross_validation import DataValidator
|
||||
|
||||
validator = DataValidator()
|
||||
|
||||
print("解析Turn样本...")
|
||||
turn_records = validator.parser.parse_turn_hist_with_cards(max_records=1)
|
||||
|
||||
if turn_records:
|
||||
record = turn_records[0]
|
||||
player_cards = record.player_cards
|
||||
board_cards = record.board_cards
|
||||
src_hist = record.bins
|
||||
|
||||
cur_hist = validator.generator.generate_turn_histogram(
|
||||
player_cards, board_cards, num_bins=len(src_hist)
|
||||
)
|
||||
|
||||
|
||||
src_hist_norm = src_hist / src_hist.sum() if src_hist.sum() > 0 else src_hist
|
||||
|
||||
print(f"\n牌面: {[str(c) for c in player_cards]} + {[str(c) for c in board_cards]}")
|
||||
print(f"bin数量: {len(src_hist)} vs {len(cur_hist)}")
|
||||
print(f"原始直方图 - 和: {src_hist.sum():.3f}, 归一化后: {src_hist_norm.sum():.3f}")
|
||||
print(f"生成直方图 - 和: {sum(cur_hist):.3f}")
|
||||
|
||||
print("\n前10个bin对比:")
|
||||
print("Bin 原始值 归一化 生成值")
|
||||
for i in range(min(10, len(src_hist))):
|
||||
print(f"{i:3d} {src_hist[i]:8.3f} {src_hist_norm[i]:8.3f} {cur_hist[i]:8.3f}")
|
||||
|
||||
# 查看非零bin的分布
|
||||
src_nonzero = np.nonzero(src_hist_norm)[0]
|
||||
cur_nonzero = np.nonzero(cur_hist)[0]
|
||||
print(f"\n非零bins位置:")
|
||||
print(f"原始: {src_nonzero[:10]}...")
|
||||
print(f"生成: {cur_nonzero[:10]}...")
|
||||
|
||||
# 计算分布的统计特征
|
||||
src_mean = np.average(range(len(src_hist_norm)), weights=src_hist_norm)
|
||||
cur_mean = np.average(range(len(cur_hist)), weights=cur_hist)
|
||||
|
||||
print(f"\n分布特征:")
|
||||
print(f"原始分布重心: {src_mean:.2f}")
|
||||
print(f"生成分布重心: {cur_mean:.2f}")
|
||||
print(f"重心差异: {abs(src_mean - cur_mean):.2f}")
|
||||
429
cross_validation/parse_data.py
Normal file
429
cross_validation/parse_data.py
Normal file
@@ -0,0 +1,429 @@
|
||||
import numpy as np
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from poker.card import Card, ShortDeckRank, Suit
|
||||
|
||||
@dataclass
|
||||
class RiverEHSRecord:
|
||||
"""river_EHS """
|
||||
board_id: int
|
||||
player_id: int
|
||||
ehs: float
|
||||
board_cards: List[Card] # 5张公共牌
|
||||
player_cards: List[Card] # 2张手牌
|
||||
|
||||
@dataclass
|
||||
class TurnHistRecord:
|
||||
"""turn_hist"""
|
||||
board_id: int
|
||||
player_id: int
|
||||
bins: np.ndarray
|
||||
board_cards: List[Card] # 4张公共牌
|
||||
player_cards: List[Card] # 2张手牌
|
||||
|
||||
@dataclass
|
||||
class FlopHistRecord:
|
||||
"""flop_hist"""
|
||||
board_id: int
|
||||
player_id: int
|
||||
bins: np.ndarray
|
||||
board_cards: List[Card] # 3张公共牌
|
||||
player_cards: List[Card] # 2张手牌
|
||||
|
||||
class OpenPQLDecoder:
|
||||
"""open-pql 完整解码"""
|
||||
|
||||
def __init__(self):
|
||||
# Card64
|
||||
self.OFFSET_SUIT = 16 # 每个suit占16位
|
||||
self.OFFSET_S = 0 # S: [15:0]
|
||||
self.OFFSET_H = 16 # H: [31:16]
|
||||
self.OFFSET_D = 32 # D: [47:32]
|
||||
self.OFFSET_C = 48 # C: [63:48]
|
||||
|
||||
# open-pql 到 短牌型(6-A,36张牌)
|
||||
self.opql_to_poker_rank = {
|
||||
# 0: Rank.TWO,
|
||||
# 1: Rank.THREE,
|
||||
# 2: Rank.FOUR,
|
||||
# 3: Rank.FIVE,
|
||||
4: ShortDeckRank.SIX,
|
||||
5: ShortDeckRank.SEVEN,
|
||||
6: ShortDeckRank.EIGHT,
|
||||
7: ShortDeckRank.NINE,
|
||||
8: ShortDeckRank.TEN,
|
||||
9: ShortDeckRank.JACK,
|
||||
10: ShortDeckRank.QUEEN,
|
||||
11: ShortDeckRank.KING,
|
||||
12: ShortDeckRank.ACE,
|
||||
}
|
||||
|
||||
self.opql_to_poker_suit = {
|
||||
0: Suit.SPADES, # S -> SPADES
|
||||
1: Suit.HEARTS, # H -> HEARTS
|
||||
2: Suit.DIAMONDS, # D -> DIAMONDS
|
||||
3: Suit.CLUBS, # C -> CLUBS
|
||||
}
|
||||
|
||||
print("========== OpenPQLDecoder (短牌型36张牌) ===============")
|
||||
|
||||
def u64_from_ranksuit(self, rank, suit) -> int:
|
||||
"""对应 Card64::u64_from_ranksuit_i8"""
|
||||
return 1 << rank << (suit * self.OFFSET_SUIT)
|
||||
|
||||
def decode_card64(self, card64_u64) -> List[Card]:
|
||||
"""从 Card64 的 u64 表示解码出所有牌"""
|
||||
cards = []
|
||||
|
||||
# 按照 Card64 解析每个suit
|
||||
# 每个suit16位
|
||||
for opql_suit in range(4): # 0,1,2,3
|
||||
suit_offset = opql_suit * self.OFFSET_SUIT
|
||||
suit_bits = (card64_u64 >> suit_offset) & 0xFFFF
|
||||
|
||||
# 检查这个suit的每个rank
|
||||
for opql_rank in range(13): # 0-12
|
||||
if suit_bits & (1 << opql_rank):
|
||||
poker_rank = self.opql_to_poker_rank[opql_rank]
|
||||
poker_suit = self.opql_to_poker_suit[opql_suit]
|
||||
cards.append(Card(poker_rank, poker_suit))
|
||||
|
||||
cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
|
||||
return cards
|
||||
|
||||
def decode_hand_n_2(self, id_u16) -> List[Card]:
|
||||
"""解码 Hand<2> 的 u16 ID"""
|
||||
card0_u8 = id_u16 & 0xFF # 低8位
|
||||
card1_u8 = (id_u16 >> 8) & 0xFF # 高8位
|
||||
|
||||
try:
|
||||
card0 = self._card_from_u8(card0_u8)
|
||||
card1 = self._card_from_u8(card1_u8)
|
||||
|
||||
cards = [card0, card1]
|
||||
cards.sort(key=lambda c: (c.rank.numeric_value, c.suit.value))
|
||||
|
||||
return cards
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"解码 Hand<2> 失败: id_u16={id_u16:04x}, card0_u8={card0_u8:02x}, card1_u8={card1_u8:02x}: {e}")
|
||||
|
||||
def _card_from_u8(self, v) -> Card:
|
||||
"""
|
||||
对应 open-pql Card::from_u8()
|
||||
单张牌decode
|
||||
"""
|
||||
SHIFT_SUIT = 4
|
||||
opql_rank = v & 0b1111 # 低4位
|
||||
opql_suit = v >> SHIFT_SUIT # 高4位
|
||||
|
||||
if opql_rank not in self.opql_to_poker_rank:
|
||||
raise ValueError(f"无效的rank: {opql_rank}")
|
||||
if opql_suit not in self.opql_to_poker_suit:
|
||||
raise ValueError(f"无效的suit: {opql_suit}")
|
||||
|
||||
poker_rank = self.opql_to_poker_rank[opql_rank]
|
||||
poker_suit = self.opql_to_poker_suit[opql_suit]
|
||||
|
||||
return Card(poker_rank, poker_suit)
|
||||
|
||||
def decode_board_id(self, board_id, num_cards) -> List[Card]:
|
||||
"""解码公共牌 ID"""
|
||||
if num_cards == 2:
|
||||
# 2张牌使用 Hand<2> 的编码方式
|
||||
return self.decode_hand_n_2(board_id)
|
||||
else:
|
||||
# 3, 4, 5张公共牌使用 Card64 的编码方式
|
||||
cards = self.decode_card64(board_id)
|
||||
|
||||
if len(cards) != num_cards:
|
||||
raise ValueError(f"解码出 {len(cards)} 张牌,期望 {num_cards} 张,board_id={board_id:016x}")
|
||||
|
||||
return cards
|
||||
|
||||
def decode_player_id(self, player_id) -> List[Card]:
|
||||
"""解码玩家手牌 ID (2张牌)"""
|
||||
return self.decode_hand_n_2(player_id)
|
||||
|
||||
def decode_board_unique_card(self, all_cards) -> List[Card]:
|
||||
"""验证并返回不重复的牌面"""
|
||||
unique_cards = []
|
||||
for card in all_cards:
|
||||
is_duplicate = False
|
||||
for existing_card in unique_cards:
|
||||
if card.rank == existing_card.rank and card.suit == existing_card.suit:
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
unique_cards.append(card)
|
||||
return unique_cards
|
||||
|
||||
|
||||
class XTaskDataParser:
|
||||
def __init__(self, data_path: str = "ehs_data"):
|
||||
self.data_path = data_path
|
||||
self.decoder = OpenPQLDecoder()
|
||||
|
||||
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
|
||||
filepath = f"{self.data_path}/{filename}"
|
||||
|
||||
try:
|
||||
raw_data = np.load(filepath)
|
||||
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
|
||||
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
for i, row in enumerate(raw_data):
|
||||
try:
|
||||
board_id = int(row['board'])
|
||||
player_id = int(row['player'])
|
||||
ehs = float(row['ehs'])
|
||||
|
||||
# 解码公共牌 (5张)
|
||||
board_cards = self.decoder.decode_board_id(board_id, 5)
|
||||
|
||||
# 解码玩家手牌 (2张)
|
||||
player_cards = self.decoder.decode_player_id(player_id)
|
||||
|
||||
# 验证牌面不重复
|
||||
all_cards = board_cards + player_cards
|
||||
|
||||
unique_cards = self.decoder.decode_board_unique_card(all_cards)
|
||||
if len(unique_cards) != 7:
|
||||
print(f"记录 {i}: 存在重复牌面")
|
||||
print(f" Board: {[str(c) for c in board_cards]}")
|
||||
print(f" Player: {[str(c) for c in player_cards]}")
|
||||
|
||||
# 创建完整记录
|
||||
record = RiverEHSRecord(
|
||||
board_id=board_id,
|
||||
player_id=player_id,
|
||||
ehs=ehs,
|
||||
board_cards=board_cards,
|
||||
player_cards=player_cards
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
except Exception as e:
|
||||
decode_errors += 1
|
||||
if decode_errors <= 3:
|
||||
print(f"记录 {i} 解码失败: {e}")
|
||||
print(f" board_id={row['board']:016x}, player_id={row['player']:04x}")
|
||||
continue
|
||||
|
||||
# 显示进度
|
||||
if (i + 1) % 10000 == 0:
|
||||
success_rate = len(records) / (i + 1) * 100
|
||||
print(f" 已处理 {i+1:,}/{len(raw_data):,} 条记录,成功率 {success_rate:.1f}%")
|
||||
|
||||
total_records = len(raw_data)
|
||||
success_records = len(records)
|
||||
success_rate = success_records / total_records * 100
|
||||
|
||||
print(f"river_EHS解析完成:")
|
||||
print(f" 总记录数: {total_records:,}")
|
||||
print(f" 成功解码: {success_records:,} ({success_rate:.1f}%)")
|
||||
print(f" 解码失败: {decode_errors:,}")
|
||||
|
||||
return records
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"文件不存在: {filepath}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"解析river_EHS数据失败: {e}")
|
||||
raise
|
||||
|
||||
def parse_turn_hist_with_cards(self, filename: str = "turn_hist.npy", max_records: int = 1000) -> List[TurnHistRecord]:
|
||||
filepath = f"{self.data_path}/{filename}"
|
||||
|
||||
try:
|
||||
raw_data = np.load(filepath)
|
||||
print(f"加载turn_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
|
||||
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
# 限制处理数量以避免内存问题
|
||||
# todo:抽样优化
|
||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
||||
|
||||
for i, row in enumerate(data_to_process):
|
||||
try:
|
||||
board_id = int(row['board'])
|
||||
player_id = int(row['player'])
|
||||
bins = row['bins'].copy()
|
||||
|
||||
# 解码公共牌 (4张)
|
||||
board_cards = self.decoder.decode_board_id(board_id, 4)
|
||||
|
||||
# 解码玩家手牌 (2张)
|
||||
player_cards = self.decoder.decode_player_id(player_id)
|
||||
|
||||
# 验证牌面不重复
|
||||
all_cards = board_cards + player_cards
|
||||
|
||||
unique_cards = self.decoder.decode_board_unique_card(all_cards)
|
||||
if len(unique_cards) != 6:
|
||||
decode_errors += 1
|
||||
print(f"记录 {i}: 存在重复牌面")
|
||||
|
||||
record = TurnHistRecord(
|
||||
board_id=board_id,
|
||||
player_id=player_id,
|
||||
bins=bins, # 存储30个bins的numpy数组
|
||||
board_cards=board_cards,
|
||||
player_cards=player_cards
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
except Exception as e:
|
||||
decode_errors += 1
|
||||
if decode_errors <= 3:
|
||||
print(f"记录 {i} 解码失败: {e}")
|
||||
continue
|
||||
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
|
||||
|
||||
print(f"turn_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f" 解析turn_hist数据失败: {e}")
|
||||
raise
|
||||
|
||||
def parse_flop_hist_with_cards(self, filename: str = "flop_hist.npy", max_records: int = 500) -> List[FlopHistRecord]:
|
||||
"""解析flop_hist"""
|
||||
filepath = f"{self.data_path}/{filename}"
|
||||
|
||||
try:
|
||||
raw_data = np.load(filepath)
|
||||
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
||||
|
||||
for i, row in enumerate(data_to_process):
|
||||
try:
|
||||
board_id = int(row['board'])
|
||||
player_id = int(row['player'])
|
||||
bins = row['bins'].copy()
|
||||
|
||||
# 解码公共牌 (3张)
|
||||
board_cards = self.decoder.decode_board_id(board_id, 3)
|
||||
|
||||
# 解码玩家手牌 (2张)
|
||||
player_cards = self.decoder.decode_player_id(player_id)
|
||||
|
||||
# 验证牌面不重复
|
||||
all_cards = board_cards + player_cards
|
||||
unique_cards = self.decoder.decode_board_unique_card(all_cards)
|
||||
if len(unique_cards) != 5:
|
||||
decode_errors += 1
|
||||
print(f"记录 {i}: 存在重复牌面")
|
||||
|
||||
record = FlopHistRecord(
|
||||
board_id=board_id,
|
||||
player_id=player_id,
|
||||
bins=bins, # 存储465个bins的numpy数组
|
||||
board_cards=board_cards,
|
||||
player_cards=player_cards
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
except Exception as e:
|
||||
decode_errors += 1
|
||||
if decode_errors <= 3:
|
||||
print(f"记录 {i} 解码失败: {e}")
|
||||
continue
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f" 已处理 {i+1}/{len(data_to_process)} 条记录...")
|
||||
|
||||
print(f"flop_hist解析完成: 成功 {len(records)}/{len(data_to_process)} 条记录")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析flop_hist数据失败: {e}")
|
||||
raise
|
||||
|
||||
def analyze_parsed_data():
|
||||
|
||||
print("开始解析 xtask 数据并解码牌面...\n")
|
||||
|
||||
try:
|
||||
parser = XTaskDataParser(data_path="ehs_data")
|
||||
|
||||
# 1. 解析river_EHS数据
|
||||
print("=" * 60)
|
||||
print(" 解析river_EHS")
|
||||
print("=" * 60)
|
||||
river_records = parser.parse_river_ehs_with_cards()
|
||||
|
||||
# 显示river_EHS样本
|
||||
print(f"\nriver_EHS数据样本 (前10条):")
|
||||
for i, record in enumerate(river_records[:10]):
|
||||
board_str = " ".join(str(c) for c in record.board_cards)
|
||||
player_str = " ".join(str(c) for c in record.player_cards)
|
||||
print(f" {i+1:2d}. Board:[{board_str:14s}] Player:[{player_str:5s}] EHS:{record.ehs:.4f}")
|
||||
|
||||
# 2. 解析turn_hist数据
|
||||
print(f"\n" + "=" * 60)
|
||||
print(" 解析turn_hist数据")
|
||||
print("=" * 60)
|
||||
turn_records = parser.parse_turn_hist_with_cards(max_records=100) # 限制数量
|
||||
# 显示turn_hist数据样本
|
||||
print(f"\nturn_hist数据样本 (前5条):")
|
||||
for i, record in enumerate(turn_records[:5]):
|
||||
board_str = " ".join(str(c) for c in record.board_cards)
|
||||
player_str = " ".join(str(c) for c in record.player_cards)
|
||||
bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
|
||||
print(f" {i+1}. Board:[{board_str:11s}] Player:[{player_str:5s}] Bins:{bins_stats}")
|
||||
|
||||
# 3. 解析flop_hist数据
|
||||
print(f"\n" + "=" * 60)
|
||||
print(" 解析flop_hist数据")
|
||||
print("=" * 60)
|
||||
flop_records = parser.parse_flop_hist_with_cards(max_records=50) # 限制数量
|
||||
|
||||
# 显示flop_数据样本
|
||||
print(f"\n flop_hist数据样本 (前5条):")
|
||||
for i, record in enumerate(flop_records[:5]):
|
||||
board_str = " ".join(str(c) for c in record.board_cards)
|
||||
player_str = " ".join(str(c) for c in record.player_cards)
|
||||
bins_stats = f"mean={record.bins.mean():.3f}, std={record.bins.std():.3f}"
|
||||
print(f" {i+1}. Board:[{board_str:8s}] Player:[{player_str:5s}] Bins:{bins_stats}")
|
||||
|
||||
# 4. 统计摘要
|
||||
print(f"\n" + "=" * 60)
|
||||
print("解析统计摘要")
|
||||
print("=" * 60)
|
||||
print(f"river_EHS记录: {len(river_records):,} 条")
|
||||
print(f"turn_hist记录: {len(turn_records):,} 条")
|
||||
print(f"flop_hist记录: {len(flop_records):,} 条")
|
||||
|
||||
# EHS 统计
|
||||
if river_records:
|
||||
ehs_values = [r.ehs for r in river_records]
|
||||
print(f"river_EHS统计:")
|
||||
print(f" 范围: [{min(ehs_values):.4f}, {max(ehs_values):.4f}]")
|
||||
print(f" 均值: {np.mean(ehs_values):.4f}")
|
||||
print(f" 标准差: {np.std(ehs_values):.4f}")
|
||||
|
||||
print(f"\n数据解析完成!所有 Board ID 和 Player ID 已成功解码为具体牌面。")
|
||||
|
||||
return {
|
||||
'river_records': river_records,
|
||||
'turn_records': turn_records,
|
||||
'flop_records': flop_records
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f" 数据解析失败: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user