import itertools import numpy as np import random from collections import defaultdict from collections.abc import Iterable from pathlib import Path from poker import Suit, Card from shortdeck import ShortDeckHandEvaluator as HE from shortdeck import ShortDeckRank as SDR data_path = Path(".") / "ehs-data" np_river = np.load(data_path / "river_ehs_sd.npy") np_turn = np.load(data_path / "turn_hist_sd.npy") np_flop = np.load(data_path / "flop_hist_sd.npy") cards = [Card(r, s) for r in SDR for s in Suit] CARD_BITS = 6 class SuitMapping: def __init__(self): self.mapping = {} self.suits = list(reversed(Suit)) def map_suit(self, s: Suit) -> Suit: if s not in self.mapping: self.mapping[s] = self.suits.pop() return self.mapping[s] class EhsCache: def __init__(self): self.cache = defaultdict(lambda: defaultdict(dict)) def _set_keys(self, flop, player): suit_map = SuitMapping() complex_cards = player+flop iso_complex = to_iso(complex_cards, suit_map) complex_key = cards_to_u32(iso_complex) return complex_key # 全部存下来计算耗时太大了。。嗯。。 # todo:不存储直接抽样计算 def store_river_ehs(self, player, board, ehs): complex_key = self._set_keys(board[:3],player) turn_idx = card_index(board[3]) river_idx = card_index(board[4]) self.cache[complex_key][turn_idx][river_idx] = ehs def get_turn_hist(self, player, flop, turn): complex_key = self._set_keys(flop, player) turn_idx = card_index(turn) turn_hist = self.cache[complex_key][turn_idx] return list(turn_hist.values()) if turn_hist else None def get_flop_hist(self, player, flop): complex_key = self._set_keys(flop, player) all_ehs = [] player_data = self.cache[complex_key] for turn_idx in player_data: for river_idx in player_data[turn_idx]: all_ehs.append(player_data[turn_idx][river_idx]) return all_ehs if len(all_ehs) == 465 else None def get_rank_idx(rank: SDR) -> int: rank_order = [SDR.SIX, SDR.SEVEN, SDR.EIGHT, SDR.NINE, SDR.TEN, SDR.JACK,SDR.QUEEN, SDR.KING, SDR.ACE] return rank_order.index(rank) def get_suit_idx(suit: Suit) -> int: suit_order = [Suit.SPADES, Suit.HEARTS, Suit.DIAMONDS, Suit.CLUBS] return suit_order.index(suit) def card_index(card: Card) -> int: return (get_rank_idx(card.rank) + 4) * 4 + get_suit_idx(card.suit) Card.__eq__ = lambda a, b: (a.rank == b.rank) and (a.suit == b.suit) Card.__hash__ = lambda a: hash((get_rank_idx(a.rank), get_suit_idx(a.suit))) def cards_to_u32(cards: list[Card]) -> int: res = 0 for i, card in enumerate(cards): bits = card_index(card) & 0x3F res |= bits << (i * CARD_BITS) return res def to_iso(cards: list[Card], mapping: SuitMapping) -> list[Card]: def count_suit(card: Card) -> int: return sum(1 for other in cards if other.suit == card.suit) sorted_cards = sorted( cards, key=lambda c: (count_suit(c), get_rank_idx(c.rank), get_suit_idx(c.suit)) ) res = [] for card in sorted_cards: mapped_suit = mapping.map_suit(card.suit) res.append(Card(card.rank, mapped_suit)) return sorted(res, key=lambda c: (get_rank_idx(c.rank), get_suit_idx(c.suit))) def cards_to_u16(cards: list[Card]) -> int: res = 0 for i, card in enumerate(cards): bits = card_index(card) & 0x3F res |= bits << (i * CARD_BITS) return res def calc_river_ehs(board: list[Card], player: list[Card]) -> float: player_hand = [*board, *player] player_ranking = HE.evaluate_hand(player_hand) acc = 0 sum = 0 for other in itertools.combinations(cards, 2): if set(other) & set(player_hand): continue if set(other) & set(board): continue other_ranking = HE.evaluate_hand([*board, *other]) if player_ranking == other_ranking: acc += 1 elif player_ranking > other_ranking: acc += 2 sum += 2 return acc / sum def get_data(board: list[Card], player: list[Card]): def _get_data(data, board: list[Card], player: list[Card]): suit_map = SuitMapping() iso_board = to_iso(board, suit_map) iso_player = to_iso(player, suit_map) mask_board = data["board"] == cards_to_u32(iso_board) mask_player = data["player"] == cards_to_u16(iso_player) return data[mask_board & mask_player][0][2] match len(board): case 3: return _get_data(np_flop, board, player) case 4: return _get_data(np_turn, board, player) case 5: return _get_data(np_river, board, player) case _: raise NotImplementedError def euclidean_dist(left, right): if isinstance(left, Iterable): v1 = np.sort(np.array(left, dtype=np.float32)) v2 = np.sort(np.array(right, dtype=np.float32)) return np.linalg.norm(v2 - v1) else: return np.abs(left - right) ** 2 def compare_data(sampled, board, player): err_count = 0 d = euclidean_dist(get_data(board, player), sampled) if not np.isclose(d, 0.0): print(f"[{''.join(map(str, board))} {''.join(map(str, player))}]: {d}") err_count += 1 return err_count card_ehs = defaultdict(dict) def validate_river(): validated_count = 0 error_count = 0 for river_combo in itertools.combinations(cards, 5): board = list(river_combo) unused_cards = [c for c in cards if c not in board] for player_combo in itertools.combinations(unused_cards, 2): player = list(player_combo) ehs = calc_river_ehs(board, player) ehs_stored.store_river_ehs(player, board, ehs) error_count = compare_data(ehs, board, player) validated_count += 1 print(".", end="", flush=True) print(f"river validate count {validated_count}") print(f"Validated river hands: {validated_count}, Errors: {error_count}") def validate_turn(): validated_count = 0 error_count = 0 for turn_combo in itertools.combinations(cards, 4): turn = list(turn_combo) unused_cards = [c for c in cards if c not in turn] for player_combo in itertools.combinations(unused_cards, 2): player = list(player_combo) turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3]) error_count += compare_data(turn_hist, turn, player) validated_count += 1 print(".", end="", flush=True) print(f"Validated turn hands: {validated_count}, Errors: {error_count}") def validate_flop(): sample = random.sample(cards, 5) flop = sample[2:] player = sample[:2] flop_hist = ehs_stored.get_flop_hist(player, flop) if flop_hist is None: return compare_data(flop_hist, flop, player) ehs_stored = EhsCache() def cross_validate_main(): validate_river() validate_turn() validate_flop() if __name__ == "__main__": cross_validate_main()