Files
poker_task1/cross_validation/validator.py
2025-11-02 11:13:44 +08:00

216 lines
7.0 KiB
Python

import itertools
import numpy as np
import random
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from poker import Suit, Card
from shortdeck import ShortDeckHandEvaluator as HE
from shortdeck import ShortDeckRank as SDR
# Reference tables (river EHS, turn and flop histograms, going by the file names), loaded from
# ./ehs-data relative to the current working directory, not to this file's directory.
data_path = Path("ehs-data")
np_river, np_turn, np_flop = (
    np.load(data_path / file_name)
    for file_name in ("river_ehs_sd.npy", "turn_hist_sd.npy", "flop_hist_sd.npy")
)
# Every short-deck rank paired with every suit, rank-major order.
cards = [Card(rank, suit) for rank in SDR for suit in Suit]
# Width in bits of one packed card in cards_to_u32 / cards_to_u16.
CARD_BITS = 6
class SuitMapping:
    """Assign canonical suits to raw suits lazily, in the order the raw suits are first seen.

    The first distinct suit passed to map_suit gets the first member of Suit, the second gets
    the second, and so on; repeated lookups return the same assignment.
    """

    def __init__(self):
        self.mapping = {}
        # Stack of suits not yet handed out; popping from the end yields them in Suit order.
        self.suits = list(reversed(Suit))

    def map_suit(self, s: Suit) -> Suit:
        """Return the canonical suit for raw suit *s*, assigning the next free one if new."""
        try:
            return self.mapping[s]
        except KeyError:
            assigned = self.mapping[s] = self.suits.pop()
            return assigned
class EhsCache:
    """In-memory table of river EHS values: cache[spot_key][turn_idx][river_idx] -> ehs.

    spot_key comes from _set_keys (suit-relabelled player hand and flop). turn_idx and
    river_idx are raw card_index values. The table is filled by store_river_ehs and read back
    as per-turn (get_turn_hist) or per-flop (get_flop_hist) lists of EHS values.
    """

    # Entry count get_flop_hist requires before it returns a histogram. Presumably C(31, 2):
    # 36 short-deck cards minus 2 hole cards and 3 flop cards, taken two at a time. Confirm.
    FLOP_HIST_SIZE = 465

    def __init__(self):
        self.cache = defaultdict(lambda: defaultdict(dict))

    def _set_keys(self, flop, player):
        """Return the cache key for a (player, flop) spot.

        Flop and player are relabelled through one shared SuitMapping (flop first) and packed
        into separate bit fields: the player in the low len(player) * CARD_BITS bits, the flop
        above it.
        """
        suit_map = SuitMapping()
        iso_flop = to_iso(flop, suit_map)
        iso_player = to_iso(player, suit_map)
        # The old key packed the sorted union of all five cards. That made hand AsKs + flop
        # 6h7h8h collide with hand 6h7h + flop AsKs8h, even though the spots are different.
        return cards_to_u32(iso_player) | (cards_to_u32(iso_flop) << (len(player) * CARD_BITS))

    # Storing every value takes too long to compute...
    # TODO: skip the store entirely and compute directly by sampling.
    def store_river_ehs(self, player, board, ehs):
        """Cache *ehs* for *player* on a 5-card *board* (board[:3] flop, [3] turn, [4] river)."""
        complex_key = self._set_keys(board[:3], player)
        # NOTE(review): the key is suit-relabelled but turn/river indices are raw. Two spots that
        # share a key but differ in raw suits therefore store turn/river cards in different suit
        # frames and can overwrite each other's entries. Confirm this is intended.
        turn_idx = card_index(board[3])
        river_idx = card_index(board[4])
        self.cache[complex_key][turn_idx][river_idx] = ehs

    def get_turn_hist(self, player, flop, turn):
        """Return the river EHS values cached for (player, flop, turn), or None if there are none."""
        # .get keeps lookups read-only: indexing the defaultdicts would insert empty entries.
        per_turn = self.cache.get(self._set_keys(flop, player), {})
        turn_hist = per_turn.get(card_index(turn))
        return list(turn_hist.values()) if turn_hist else None

    def get_flop_hist(self, player, flop):
        """Return every cached EHS value for (player, flop).

        Returns None unless exactly FLOP_HIST_SIZE values are cached.
        """
        player_data = self.cache.get(self._set_keys(flop, player), {})
        all_ehs = [ehs for rivers in player_data.values() for ehs in rivers.values()]
        return all_ehs if len(all_ehs) == self.FLOP_HIST_SIZE else None
def get_rank_idx(rank: SDR) -> int:
    """Return the 0-based position of *rank* in short-deck order (SIX=0 ... ACE=8).

    Raises ValueError if *rank* is not one of the nine listed ranks.
    """
    short_deck_ranks = [
        SDR.SIX,
        SDR.SEVEN,
        SDR.EIGHT,
        SDR.NINE,
        SDR.TEN,
        SDR.JACK,
        SDR.QUEEN,
        SDR.KING,
        SDR.ACE,
    ]
    position = short_deck_ranks.index(rank)
    return position
def get_suit_idx(suit: Suit) -> int:
    """Return the position of *suit* in the fixed order spades, hearts, diamonds, clubs (0-3).

    Raises ValueError for any other value.
    """
    canonical_suits = [
        Suit.SPADES,
        Suit.HEARTS,
        Suit.DIAMONDS,
        Suit.CLUBS,
    ]
    position = canonical_suits.index(suit)
    return position
def card_index(card: Card) -> int:
    """Return a 0-63 integer id for *card*: 16 + 4 * rank index + suit index (range 16-51).

    The 16-slot offset equals four ranks of four suits below SIX, presumably so ids line up
    with a full 52-card deck layout. Confirm if other code depends on that.
    """
    rank_pos = get_rank_idx(card.rank)
    suit_pos = get_suit_idx(card.suit)
    return 16 + 4 * rank_pos + suit_pos
def _card_eq(a, b):
    """Treat two cards as equal when both rank and suit match.

    For non-Card operands, return NotImplemented so Python falls back to its default handling
    (e.g. `card == None` is False) instead of raising AttributeError on `b.rank`.
    """
    if not isinstance(b, Card):
        return NotImplemented
    return (a.rank == b.rank) and (a.suit == b.suit)


def _card_hash(a):
    """Hash consistent with _card_eq, built from the (rank index, suit index) pair."""
    return hash((get_rank_idx(a.rank), get_suit_idx(a.suit)))


# Patched onto Card so value equality and hashing hold for the set operations and `in` tests
# used in this module (presumably the library's own Card behaves differently; confirm).
Card.__eq__ = _card_eq
Card.__hash__ = _card_hash
def cards_to_u32(cards: list[Card]) -> int:
    """Pack cards into one integer, CARD_BITS bits per card, with the first card in the lowest bits.

    Each card contributes card_index(card) masked to 6 bits. Five cards occupy 30 bits.
    """
    packed = 0
    shift = 0
    for card in cards:
        packed |= (card_index(card) & 0x3F) << shift
        shift += CARD_BITS
    return packed
def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]:
    """Relabel suits through *mapping* and return the cards sorted by (rank, suit).

    Cards are visited in order of (number of cards in *cards* sharing their suit, rank index,
    suit index). Each suit is passed to mapping.map_suit in that order, so suits held by fewer
    cards are assigned first. *mapping* is mutated, which lets callers carry one relabelling
    across several calls.
    """
    def canonical_order(card: Card):
        return get_rank_idx(card.rank), get_suit_idx(card.suit)

    def assignment_order(card: Card):
        same_suit = len([other for other in cards if other.suit == card.suit])
        return (same_suit, *canonical_order(card))

    remapped = [
        Card(card.rank, mapping.map_suit(card.suit))
        for card in sorted(cards, key=assignment_order)
    ]
    remapped.sort(key=canonical_order)
    return remapped
def cards_to_u16(cards: list[Card]) -> int:
    """Pack cards exactly like cards_to_u32; used here for the two hole cards.

    Two cards at CARD_BITS (6) bits each occupy 12 bits, so the result fits in 16 bits.
    Longer inputs would not.
    """
    # This was a verbatim copy of cards_to_u32's loop. Delegating keeps the two encodings in
    # sync, which the "player" column lookup in get_data depends on.
    return cards_to_u32(cards)
def calc_river_ehs(board: list[Card], player: list[Card]) -> float:
    """Return the player's equity on a complete board against every possible opponent hand.

    Every 2-card combination from `cards` that shares no card with the board or the player's
    hand is evaluated with HE.evaluate_hand. A tie scores 1. A player ranking that compares
    greater scores 2 (assumed to mean the stronger hand). Anything else scores 0. The result
    is score / (2 * matchups), in [0, 1].
    """
    player_hand = [*board, *player]
    player_ranking = HE.evaluate_hand(player_hand)
    # Board + hole cards, built once instead of once per opponent combo. It already contains
    # the board, so no separate board-overlap check is needed.
    dead = set(player_hand)
    score = 0
    total = 0  # renamed from `sum`, which shadowed the builtin
    for other in itertools.combinations(cards, 2):
        if dead.intersection(other):
            continue
        other_ranking = HE.evaluate_hand([*board, *other])
        if player_ranking == other_ranking:
            score += 1
        elif player_ranking > other_ranking:
            score += 2
        total += 2
    return score / total
def get_data(board: list[Card], player: list[Card]):
def _get_data(data, board: list[Card], player: list[Card]):
suit_map = SuitMapping()
iso_board = to_iso(board, suit_map)
iso_player = to_iso(player, suit_map)
mask_board = data["board"] == cards_to_u32(iso_board)
mask_player = data["player"] == cards_to_u16(iso_player)
return data[mask_board & mask_player][0][2]
match len(board):
case 3:
return _get_data(np_flop, board, player)
case 4:
return _get_data(np_turn, board, player)
case 5:
return _get_data(np_river, board, player)
case _:
raise NotImplementedError
def euclidean_dist(left, right):
    """Return the distance between a reference value *left* and a sampled value *right*.

    If *left* is iterable, both sides are compared as multisets. Each is converted to float32
    and sorted, and the Euclidean norm of the difference is returned.

    Scalars return the absolute difference in float32, which is the one-element case of the
    same norm. The old scalar branch returned the squared difference, which quietly loosened
    np.isclose checks on scalars relative to vectors.

    Returns inf when either side is None or the two sides differ in length. Callers then see a
    mismatch instead of a crash or a silently broadcast subtraction.
    """
    if left is None or right is None:
        return float("inf")
    if isinstance(left, Iterable):
        v1 = np.sort(np.atleast_1d(np.asarray(left, dtype=np.float32)))
        v2 = np.sort(np.atleast_1d(np.asarray(right, dtype=np.float32)))
        if v1.shape != v2.shape:
            return float("inf")
        return np.linalg.norm(v2 - v1)
    return np.abs(np.float32(left) - np.float32(right))
def compare_data(sampled, board, player):
    """Check a freshly computed value against the reference table entry for (board, player).

    On mismatch (distance not close to 0), print the hand and the distance and return 1.
    Otherwise return 0.
    """
    d = euclidean_dist(get_data(board, player), sampled)
    if np.isclose(d, 0.0):
        return 0
    board_str = "".join(map(str, board))
    player_str = "".join(map(str, player))
    print(f"[{board_str} {player_str}]: {d}")
    return 1
# NOTE(review): not read or written anywhere in this file; possibly a leftover. Confirm before removing.
card_ehs = defaultdict(dict)


def validate_river():
    """Compute river EHS for every (board, player) pair, cache it, and check it against np_river.

    Every 5-card board from `cards` is paired with every 2-card hand from the remaining cards.
    Each value is stored in `ehs_stored`, which the turn and flop validators read later, and
    is compared against the reference river table. Mismatches are printed by compare_data.
    """
    validated_count = 0
    error_count = 0
    # NOTE(review): the nesting of the progress prints is a best-effort reconstruction.
    # Confirm whether "." is per hand and the count line is per board.
    for river_combo in itertools.combinations(cards, 5):
        board = list(river_combo)
        unused_cards = [c for c in cards if c not in board]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            ehs = calc_river_ehs(board, player)
            ehs_stored.store_river_ehs(player, board, ehs)
            # Accumulate: a plain assignment kept only the result of the last hand checked.
            error_count += compare_data(ehs, board, player)
            validated_count += 1
            print(".", end="", flush=True)
        print(f"river validate count {validated_count}")
    print(f"Validated river hands: {validated_count}, Errors: {error_count}")
def validate_turn():
    """Check cached turn histograms against the reference turn table for every (turn, player) pair.

    A spot with nothing cached is reported and counted as an error. Passing None on to
    compare_data would make the distance computation raise.
    """
    validated_count = 0
    error_count = 0
    # NOTE(review): validate_river takes boards from itertools.combinations, which preserves
    # `cards` order. For a given turn, only river cards after the turn card in that order are
    # cached, so these histograms are likely partial. Confirm against the reference table.
    for turn_combo in itertools.combinations(cards, 4):
        turn = list(turn_combo)
        unused_cards = [c for c in cards if c not in turn]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3])
            if turn_hist is None:
                print(
                    f"[{''.join(map(str, turn))} {''.join(map(str, player))}]: "
                    "no cached turn histogram"
                )
                error_count += 1
            else:
                error_count += compare_data(turn_hist, turn, player)
            validated_count += 1
            print(".", end="", flush=True)
    print(f"Validated turn hands: {validated_count}, Errors: {error_count}")
def validate_flop():
    """Spot-check one random (player, flop) pair against the reference flop table.

    Draws 5 distinct cards: the first 2 are the hand, the last 3 the flop.

    Returns:
        The mismatch count (0 or 1), or None when no complete flop histogram is cached for the
        drawn spot.
    """
    sample = random.sample(cards, 5)
    player, flop = sample[:2], sample[2:]
    flop_hist = ehs_stored.get_flop_hist(player, flop)
    if flop_hist is None:
        # This used to be a silent return, indistinguishable from a passing check.
        print(
            f"[{''.join(map(str, flop))} {''.join(map(str, player))}]: "
            "no complete flop histogram cached, skipped"
        )
        return None
    error_count = compare_data(flop_hist, flop, player)
    print(f"Validated flop hands: 1, Errors: {error_count}")
    return error_count
# Shared cache: validate_river writes to it; validate_turn and validate_flop read from it.
ehs_stored = EhsCache()


def cross_validate_main():
    """Run the validation stages in order: river first (it fills the cache), then turn, then flop."""
    stages = (validate_river, validate_turn, validate_flop)
    for stage in stages:
        stage()


if __name__ == "__main__":
    cross_validate_main()