216 lines
7.0 KiB
Python
216 lines
7.0 KiB
Python
import itertools
|
|
|
|
import numpy as np
|
|
import random
|
|
from collections import defaultdict
|
|
from collections.abc import Iterable
|
|
from pathlib import Path
|
|
|
|
from poker import Suit, Card
|
|
from shortdeck import ShortDeckHandEvaluator as HE
|
|
from shortdeck import ShortDeckRank as SDR
|
|
|
|
# Directory (relative to the working directory) that holds the precomputed tables.
data_path = Path(".").joinpath("ehs-data")

# The three tables are loaded eagerly at import time, in river/turn/flop order.
# get_data selects rows from them through their "board" and "player" fields.
np_river, np_turn, np_flop = (
    np.load(data_path / table_file)
    for table_file in ("river_ehs_sd.npy", "turn_hist_sd.npy", "flop_hist_sd.npy")
)

# One Card per (rank, suit) pair, iterating ranks in the outer position.
cards = [Card(r, s) for r, s in itertools.product(SDR, Suit)]

# Width in bits of one card slot inside the packed keys built by
# cards_to_u32 / cards_to_u16.
CARD_BITS = 6
|
|
class SuitMapping:
    """Hand out replacement suits to concrete suits in order of first use.

    The first suit passed to ``map_suit`` receives the first member of
    ``Suit``, the next previously unseen suit receives the second member, and
    so on. A suit that was already seen always gets the same answer.
    """

    def __init__(self):
        # Concrete suit -> replacement suit, filled in lazily by map_suit.
        self.mapping = {}
        # Kept reversed so that pop() yields suits in Suit's iteration order.
        self.suits = [*reversed(Suit)]

    def map_suit(self, s: Suit) -> Suit:
        """Return the replacement for *s*, assigning the next free suit if new.

        Raises IndexError (from ``list.pop``) once every suit has been
        assigned and yet another unseen suit is requested.
        """
        replacement = self.mapping.get(s)
        if replacement is None:
            replacement = self.suits.pop()
            self.mapping[s] = replacement
        return replacement
|
|
|
|
class EhsCache:
    """In-memory store of river EHS values, grouped by hand, turn and river.

    Layout: ``cache[complex_key][turn_idx][river_idx] = ehs`` where
    ``complex_key`` comes from ``_set_keys`` and both indices come from
    ``card_index``.
    """

    # Number of cached river values get_flop_hist requires before it returns
    # a histogram. Presumably C(31, 2): the turn/river pairs left once two
    # hole cards and a flop are removed from a 36-card deck -- TODO confirm.
    _FLOP_HIST_SIZE = 465

    def __init__(self):
        # Only the write path (store_river_ehs) relies on the auto-creating
        # behaviour; the read paths use .get() so a miss never inserts keys.
        self.cache = defaultdict(lambda: defaultdict(dict))

    def _set_keys(self, flop, player):
        """Return the packed key for the suit-remapped ``player + flop`` cards.

        NOTE(review): hole cards and flop are passed through ``to_iso`` as a
        single list, so the key does not record which of the cards belong to
        the player, and the turn/river indices used next to this key are not
        suit-remapped. Confirm that both are intended.
        """
        suit_map = SuitMapping()
        complex_cards = player + flop
        iso_complex = to_iso(complex_cards, suit_map)
        complex_key = cards_to_u32(iso_complex)
        return complex_key

    # Storing everything makes the computation far too slow... hmm.
    # TODO: sample and compute directly instead of storing.
    def store_river_ehs(self, player, board, ehs):
        """Record *ehs* for *player* on a five-card *board*.

        ``board[:3]`` is used as the flop, ``board[3]`` as the turn and
        ``board[4]`` as the river. An existing entry for the same
        key/turn/river is overwritten.
        """
        complex_key = self._set_keys(board[:3], player)
        turn_idx = card_index(board[3])
        river_idx = card_index(board[4])
        self.cache[complex_key][turn_idx][river_idx] = ehs

    def get_turn_hist(self, player, flop, turn):
        """Return the cached river EHS values for this hand and turn card.

        Returns a list in insertion order, or ``None`` when nothing has been
        stored for the combination. A miss leaves the cache untouched.
        """
        complex_key = self._set_keys(flop, player)
        turn_idx = card_index(turn)
        turn_hist = self.cache.get(complex_key, {}).get(turn_idx)
        return list(turn_hist.values()) if turn_hist else None

    def get_flop_hist(self, player, flop):
        """Return every cached river EHS value for this hand and flop.

        Returns the values as a list only when exactly ``_FLOP_HIST_SIZE`` of
        them are cached; otherwise returns ``None``. A miss leaves the cache
        untouched.
        """
        complex_key = self._set_keys(flop, player)
        player_data = self.cache.get(complex_key)
        if player_data is None:
            return None
        all_ehs = [
            ehs
            for river_values in player_data.values()
            for ehs in river_values.values()
        ]
        return all_ehs if len(all_ehs) == self._FLOP_HIST_SIZE else None
|
|
|
|
def get_rank_idx(rank: SDR) -> int:
    """Return the zero-based position of *rank*, counting from six up to ace.

    Raises ValueError (from ``list.index``) when *rank* is not one of the
    nine ranks listed below.
    """
    low_to_high = [
        SDR.SIX,
        SDR.SEVEN,
        SDR.EIGHT,
        SDR.NINE,
        SDR.TEN,
        SDR.JACK,
        SDR.QUEEN,
        SDR.KING,
        SDR.ACE,
    ]
    position = low_to_high.index(rank)
    return position
|
|
|
|
def get_suit_idx(suit: Suit) -> int:
    """Return the zero-based position of *suit* in spades/hearts/diamonds/clubs order.

    Raises ValueError (from ``list.index``) when *suit* is not one of those four.
    """
    fixed_order = [
        Suit.SPADES,
        Suit.HEARTS,
        Suit.DIAMONDS,
        Suit.CLUBS,
    ]
    position = fixed_order.index(suit)
    return position
|
|
|
|
def card_index(card: Card) -> int:
    """Return the integer code of *card*: ``(rank index + 4) * 4 + suit index``.

    With rank indices 0..8 and suit indices 0..3 the codes run from 16
    (six, first suit) to 51 (ace, last suit).
    """
    # The +4 offset presumably leaves room for the four ranks below six that
    # a full deck would have -- TODO confirm.
    rank_slot = get_rank_idx(card.rank) + 4
    suit_slot = get_suit_idx(card.suit)
    return 4 * rank_slot + suit_slot
|
|
|
|
def _card_eq(a, b):
    """Return True when both operands have equal ``rank`` and equal ``suit``.

    Returns ``NotImplemented`` when an operand has no ``rank``/``suit``
    attribute, so that comparing a card with an unrelated object (``None``,
    a string, ...) evaluates to False instead of raising AttributeError.
    """
    try:
        return (a.rank == b.rank) and (a.suit == b.suit)
    except AttributeError:
        return NotImplemented


def _card_hash(a):
    """Hash a card by its (rank index, suit index) pair, matching _card_eq."""
    return hash((get_rank_idx(a.rank), get_suit_idx(a.suit)))


# Patch value semantics onto Card so that cards can be compared by content
# and used in sets / as dict keys.
Card.__eq__ = _card_eq
Card.__hash__ = _card_hash
|
|
|
|
def cards_to_u32(cards: list[Card]) -> int:
    """Pack *cards* into one integer, CARD_BITS bits per card.

    The first card occupies the lowest bits. Each card code is masked to six
    bits before being shifted into place. An empty list packs to 0.
    """
    packed = 0
    shift = 0
    for card in cards:
        packed |= (card_index(card) & 0x3F) << shift
        shift += CARD_BITS
    return packed
|
|
|
|
def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]:
    """Return *cards* with suits replaced through *mapping*, sorted by rank then suit.

    Suits are offered to ``mapping.map_suit`` in a fixed order: cards are
    first sorted by (number of cards sharing the suit, rank index, suit
    index), so suits with fewer cards are mapped first. *mapping* is mutated
    and can be passed to a later call to keep the same assignments.
    """
    # Tally how many of the given cards carry each suit.
    suit_sizes = defaultdict(int)
    for card in cards:
        suit_sizes[card.suit] += 1

    by_suit_size = sorted(
        cards,
        key=lambda c: (
            suit_sizes[c.suit],
            get_rank_idx(c.rank),
            get_suit_idx(c.suit),
        ),
    )

    # map_suit is stateful, so it must be called in exactly this order.
    remapped = [Card(c.rank, mapping.map_suit(c.suit)) for c in by_suit_size]
    remapped.sort(key=lambda c: (get_rank_idx(c.rank), get_suit_idx(c.suit)))
    return remapped
|
|
|
|
def cards_to_u16(cards: list[Card]) -> int:
    """Pack *cards* into one integer, CARD_BITS bits per card, first card lowest.

    Each card code is masked to six bits. The packing scheme is the same one
    ``cards_to_u32`` uses; get_data applies this function to the hole cards.
    """
    fields = [
        (card_index(card) & 0x3F) << (slot * CARD_BITS)
        for slot, card in enumerate(cards)
    ]
    packed = 0
    for field in fields:
        packed |= field
    return packed
|
|
|
|
def calc_river_ehs(board: list[Card], player: list[Card]) -> float:
    """Return the player's river hand strength against every possible opponent.

    Every two-card opponent holding that shares no card with the board or
    with the player's hole cards is evaluated on the same board. The player
    scores 2 for a strictly greater evaluator result, 1 for an equal one and
    0 otherwise; the total is divided by the maximum reachable score.

    Raises ZeroDivisionError if no opponent holding is possible.
    """
    player_hand = [*board, *player]
    player_ranking = HE.evaluate_hand(player_hand)

    # Built once: the board is part of player_hand, so this one set covers
    # every card an opponent cannot hold.
    dead_cards = set(player_hand)

    score = 0
    max_score = 0
    for other in itertools.combinations(cards, 2):
        if not dead_cards.isdisjoint(other):
            continue
        other_ranking = HE.evaluate_hand([*board, *other])
        if player_ranking == other_ranking:
            score += 1
        elif player_ranking > other_ranking:
            score += 2
        max_score += 2
    return score / max_score
|
|
|
|
def get_data(board: list[Card], player: list[Card]):
    """Look up the precomputed entry for *player* on *board*.

    The table is chosen by board length (3 -> np_flop, 4 -> np_turn,
    5 -> np_river). Board and hole cards are suit-remapped with one shared
    SuitMapping (board first), packed into integer keys and matched against
    the table's "board" and "player" fields. The third field of the first
    matching row is returned.

    Raises NotImplementedError for any other board length, and IndexError
    when no row matches.
    """
    tables = {3: np_flop, 4: np_turn, 5: np_river}
    if len(board) not in tables:
        raise NotImplementedError
    data = tables[len(board)]

    suit_map = SuitMapping()
    # The board must be remapped before the hole cards: the mapping is stateful.
    iso_board = to_iso(board, suit_map)
    iso_player = to_iso(player, suit_map)

    board_hit = data["board"] == cards_to_u32(iso_board)
    player_hit = data["player"] == cards_to_u16(iso_player)
    matches = data[board_hit & player_hit]
    return matches[0][2]
|
|
|
|
def euclidean_dist(left, right):
    """Return a distance between two scalars or two collections of numbers.

    For an iterable *left*, both inputs are converted to float32 arrays,
    sorted ascending, and the Euclidean norm of their difference is returned
    (so element order does not matter). Otherwise the squared absolute
    difference ``|left - right| ** 2`` is returned.
    """
    if not isinstance(left, Iterable):
        # Scalar case: note this is the squared difference, not its root.
        return np.abs(left - right) ** 2

    sorted_left = np.sort(np.array(left, dtype=np.float32))
    sorted_right = np.sort(np.array(right, dtype=np.float32))
    delta = sorted_right - sorted_left
    return np.linalg.norm(delta)
|
|
|
|
def compare_data(sampled, board, player):
    """Compare *sampled* with the stored entry for this hand; return 0 or 1.

    Returns 0 when the distance between the stored data and *sampled* is
    close to zero (``np.isclose``). Otherwise prints the board, the hole
    cards and the distance, and returns 1.
    """
    distance = euclidean_dist(get_data(board, player), sampled)
    if np.isclose(distance, 0.0):
        return 0

    board_text = "".join(map(str, board))
    player_text = "".join(map(str, player))
    print(f"[{board_text} {player_text}]: {distance}")
    return 1
|
|
|
|
# NOTE(review): nothing in the visible part of this file reads or writes
# card_ehs -- confirm it is still needed.
card_ehs = defaultdict(dict)
|
|
|
|
def validate_river():
    """Recompute river EHS for every board/hand pair and check the stored table.

    For each five-card board and each two-card holding drawn from the
    remaining cards, the EHS is computed with ``calc_river_ehs``, stored in
    ``ehs_stored`` and compared with the precomputed table through
    ``compare_data``. Prints a progress dot per hand and a summary at the end.
    """
    validated_count = 0
    error_count = 0
    for river_combo in itertools.combinations(cards, 5):
        board = list(river_combo)
        unused_cards = [c for c in cards if c not in board]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            ehs = calc_river_ehs(board, player)
            ehs_stored.store_river_ehs(player, board, ehs)
            # Accumulate mismatches; a plain assignment here would keep only
            # the result of the very last comparison.
            error_count += compare_data(ehs, board, player)
            validated_count += 1
            # NOTE(review): the nesting level of this progress dot and of the
            # two summary prints below is assumed -- confirm against the
            # intended output.
            print(".", end="", flush=True)
    print(f"river validate count {validated_count}")
    print(f"Validated river hands: {validated_count}, Errors: {error_count}")
|
|
|
|
|
|
def validate_turn():
    """Check the stored turn histograms against the cached river EHS values.

    For each four-card board (first three cards used as the flop, the fourth
    as the turn) and each two-card holding drawn from the remaining cards,
    the cached river values are fetched from ``ehs_stored`` and compared with
    the precomputed table through ``compare_data``. Hands with nothing cached
    are skipped and not counted.
    """
    validated_count = 0
    error_count = 0
    for turn_combo in itertools.combinations(cards, 4):
        turn = list(turn_combo)
        unused_cards = [c for c in cards if c not in turn]
        for player_combo in itertools.combinations(unused_cards, 2):
            player = list(player_combo)
            turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3])
            if turn_hist is None:
                # get_turn_hist reports a cache miss as None; comparing None
                # would fail inside the array sort, so skip the hand the same
                # way validate_flop does.
                continue
            error_count += compare_data(turn_hist, turn, player)
            validated_count += 1
            # NOTE(review): the nesting level of this progress dot is
            # assumed -- confirm against the intended output.
            print(".", end="", flush=True)
    print(f"Validated turn hands: {validated_count}, Errors: {error_count}")
|
|
|
|
def validate_flop():
    """Spot-check one randomly drawn flop hand against the stored table.

    Five distinct cards are sampled: the first two become the hole cards and
    the last three the flop. If ``ehs_stored`` cannot supply a complete
    histogram for that hand the function returns without comparing.
    Mismatches are reported by ``compare_data``; its result is discarded.
    """
    first_hole, second_hole, *flop = random.sample(cards, 5)
    player = [first_hole, second_hole]

    flop_hist = ehs_stored.get_flop_hist(player, flop)
    if flop_hist is not None:
        compare_data(flop_hist, flop, player)
|
|
|
|
|
|
# Module-wide cache instance shared by validate_river (which fills it) and
# validate_turn / validate_flop (which read from it).
ehs_stored = EhsCache()
|
|
|
|
def cross_validate_main():
    """Run the river, turn and flop validations, in that order.

    The order matters: validate_river fills ``ehs_stored``, which the turn
    and flop validations read from.
    """
    for validation_step in (validate_river, validate_turn, validate_flop):
        validation_step()
|
|
|
|
# Script entry point: the cross-validation runs only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    cross_validate_main()