Files
poker_task1/shortdeck/gen_hist.py
2025-09-26 17:08:27 +08:00

244 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
import numpy as np
from typing import List, Dict, Tuple
import itertools
from poker.card import Card, ShortDeckRank, Suit
from shortdeck.hand_evaluator import HandEvaluator
class ShortDeckHistGenerator:
    """Enumerate expected-hand-strength (EHS) values for short-deck hold'em.

    The generator holds the full deck and a hand evaluator, and offers three
    entry points:

    * ``generate_river_ehs``      -- exact EHS with all 5 board cards known.
    * ``generate_turn_histogram`` -- one river EHS per possible river card.
    * ``generate_flop_histogram`` -- one river EHS per possible turn+river pair.

    Despite their names, the two "histogram" methods return the raw list of
    per-runout EHS values; no binning or normalisation is performed.
    """

    def __init__(self):
        # One card per (rank, suit) pairing. The original comment states this
        # is the 36-card deck running from 6 up to Ace.
        self.full_deck = [Card(rank, suit) for rank in ShortDeckRank for suit in Suit]
        self.hand_evaluator = HandEvaluator()
        print(f"初始化短牌型EHS直方图生成器牌组大小: {len(self.full_deck)}")
        print(f"牌型范围: {ShortDeckRank.SIX.name}-{ShortDeckRank.ACE.name}")

    def _remaining_cards(self, used_cards):
        """Return the cards of ``self.full_deck`` not in *used_cards*, in deck order."""
        return [card for card in self.full_deck if card not in used_cards]

    def _runout_ehs(self, hole, board, cards_to_come) -> List[float]:
        """Return the river EHS for every way of dealing *cards_to_come* more board cards.

        Runouts are enumerated with ``itertools.combinations`` over the unseen
        cards, so the result order follows deck order.
        """
        unseen = self._remaining_cards(hole + board)
        return [
            self.generate_river_ehs(hole, board + list(runout))
            for runout in itertools.combinations(unseen, cards_to_come)
        ]

    @staticmethod
    def _require_sample_count(stage, samples, num_bins) -> None:
        """Raise ``RuntimeError`` unless *samples* holds exactly *num_bins* values."""
        if len(samples) != num_bins:
            # The original code printed this text and then executed a bare
            # ``raise`` with no active exception, which surfaces as an
            # uninformative RuntimeError. Keep the exception type but carry
            # the diagnostic in the exception itself.
            raise RuntimeError(
                f"{stage}计算: generate nums{len(samples)} ,num_bins={num_bins}"
            )

    def generate_river_ehs(self, player_cards, board_cards) -> float:
        """Compute the exact river win rate against every possible opponent hand.

        Every 2-card combination of the unseen cards is treated as an opponent
        hand. A win scores 2, a tie scores 1, and the total is divided by
        2 * (number of opponent hands), so a tie counts as half a win. A larger
        ``evaluate_hand`` result is treated as the stronger hand.

        Args:
            player_cards: The player's 2 hole cards (any sequence).
            board_cards: The 5 community cards (any sequence).

        Returns:
            The win rate, between 0 and 1.

        Raises:
            ValueError: If there are not exactly 2 hole cards or 5 board cards.
            ZeroDivisionError: If fewer than 2 unseen cards remain in the deck,
                so no opponent hand can be formed.
        """
        if len(player_cards) != 2:
            raise ValueError("玩家必须有2张底牌")
        if len(board_cards) != 5:
            raise ValueError("River阶段必须有5张公共牌")

        # Copy into lists so tuples (or mixed sequence types) can be concatenated.
        hole = list(player_cards)
        board = list(board_cards)
        evaluate = self.hand_evaluator.evaluate_hand

        player_strength = evaluate(hole + board)
        unseen = self._remaining_cards(hole + board)

        score = 0
        max_score = 0
        for opponent_cards in itertools.combinations(unseen, 2):
            opponent_strength = evaluate(list(opponent_cards) + board)
            if player_strength > opponent_strength:
                score += 2
            elif player_strength == opponent_strength:
                score += 1
            max_score += 2
        return score / max_score

    def generate_turn_histogram(self, player_cards, board_cards, num_bins: int = 30) -> List[float]:
        """Return the river EHS for every possible river card on the turn.

        Args:
            player_cards: The player's 2 hole cards.
            board_cards: The 4 community cards.
            num_bins: Expected number of values (one per unseen card). The
                default of 30 matches a 36-card deck with 6 cards seen.

        Returns:
            A list with one EHS value per unseen river card, in deck order.

        Raises:
            ValueError: If there are not exactly 2 hole cards or 4 board cards.
            RuntimeError: If the number of generated values differs from
                *num_bins*.
        """
        if len(player_cards) != 2:
            raise ValueError("玩家必须有2张底牌")
        if len(board_cards) != 4:
            raise ValueError("Turn阶段必须有4张公共牌")

        histogram = self._runout_ehs(list(player_cards), list(board_cards), 1)
        self._require_sample_count("Turn", histogram, num_bins)
        return histogram

    def generate_flop_histogram(self, player_cards, board_cards, num_bins: int = 465) -> List[float]:
        """Return the river EHS for every possible turn+river pair on the flop.

        Args:
            player_cards: The player's 2 hole cards.
            board_cards: The 3 community cards.
            num_bins: Expected number of values (one per unordered pair of
                unseen cards). The default of 465 is C(31, 2), matching a
                36-card deck with 5 cards seen.

        Returns:
            A list with one EHS value per turn+river combination, in the order
            produced by ``itertools.combinations`` over the unseen cards.

        Raises:
            ValueError: If there are not exactly 2 hole cards or 3 board cards.
            RuntimeError: If the number of generated values differs from
                *num_bins*.
        """
        if len(player_cards) != 2:
            raise ValueError("玩家必须有2张牌")
        if len(board_cards) != 3:
            raise ValueError("Flop阶段必须有3张公共牌")

        histogram = self._runout_ehs(list(player_cards), list(board_cards), 2)
        # The original message said "Turn" here, a copy-paste slip.
        self._require_sample_count("Flop", histogram, num_bins)
        return histogram
# # 从一副牌中抽样验证 或者 从解析的数据中抽样验证
# def generate_sample_data(self, num_samples: int = 3) -> Dict:
# print(f"\n 生成样本数据 (每阶段 {num_samples} 个样本)")
# results = {
# 'river': [],
# 'turn': [],
# 'flop': []
# }
# for i in range(num_samples):
# # River样本7张牌
# river_cards = np.random.choice(self.full_deck, size=7, replace=False).tolist()
# player_cards = river_cards[:2]
# board_cards = river_cards[2:7]
# river_ehs = self.generate_river_ehs(player_cards, board_cards)
# results['river'].append({
# 'player_cards': player_cards,
# 'board_cards': board_cards,
# 'ehs': river_ehs
# })
# # Turn样本6张牌
# turn_cards = np.random.choice(self.full_deck, size=6, replace=False).tolist()
# player_cards = turn_cards[:2]
# board_cards = turn_cards[2:6]
# turn_hist = self.generate_turn_histogram(player_cards, board_cards, num_bins=30)
# results['turn'].append({
# 'player_cards': player_cards,
# 'board_cards': board_cards,
# 'histogram': turn_hist,
# 'mean': float(np.mean(turn_hist)),
# 'std': float(np.std(turn_hist))
# })
# # Flop样本5张牌
# flop_cards = np.random.choice(self.full_deck, size=5, replace=False).tolist()
# player_cards = flop_cards[:2]
# board_cards = flop_cards[2:5]
# flop_hist = self.generate_flop_histogram(player_cards, board_cards, num_bins=465)
# results['flop'].append({
# 'player_cards': [str(c) for c in player_cards],
# 'board_cards': [str(c) for c in board_cards],
# 'histogram_stats': flop_hist,
# 'mean': float(np.mean(flop_hist)),
# 'std': float(np.std(flop_hist))
# })
# return results
# def main():
# """测试短牌型EHS直方图生成器"""
# print("短牌型EHS直方图生成器测试")
# print("="*60)
# generator = ShortDeckHistGenerator()
# # 测试River EHS
# print(f"\n测试River EHS计算...")
# test_cards = np.random.choice(generator.full_deck, size=7, replace=False).tolist()
# player_cards = test_cards[:2]
# board_cards = test_cards[2:7]
# player_str = " ".join(str(c) for c in player_cards)
# board_str = " ".join(str(c) for c in board_cards)
# river_ehs = generator.generate_river_ehs(player_cards, board_cards)
# print(f" 玩家底牌: [{player_str}]")
# print(f" 公共牌: [{board_str}]")
# print(f" River EHS: {river_ehs:.4f}")
# # 测试Turn直方图
# print(f"\n测试Turn直方图生成...")
# turn_cards = np.random.choice(generator.full_deck, size=6, replace=False).tolist()
# player_cards = turn_cards[:2]
# board_cards = turn_cards[2:6]
# turn_hist = generator.generate_turn_histogram(player_cards, board_cards, num_bins=30)
# # 测试Flop直方图
# print(f"\n测试Flop直方图生成...")
# flop_cards = np.random.choice(generator.full_deck, size=5, replace=False).tolist()
# player_cards = flop_cards[:2]
# board_cards = flop_cards[2:5]
# flop_hist = generator.generate_flop_histogram(player_cards, board_cards, num_bins=465)
# print(f" Flop直方图: mean={np.mean(flop_hist):.3f}, 非零bins={np.count_nonzero(flop_hist)}")
# # 生成样本数据
# print(f"\n生成样本数据...")
# sample_data = generator.generate_sample_data(num_samples=3)