From cb4ce22775d487f8e779dc77c3b006fa5c0a374a Mon Sep 17 00:00:00 2001 From: jianghaiying Date: Sun, 2 Nov 2025 12:51:13 +0800 Subject: [PATCH] sample --- cross_validation/validator.py | 145 ++++++++++++++++++---------------- 1 file changed, 79 insertions(+), 66 deletions(-) diff --git a/cross_validation/validator.py b/cross_validation/validator.py index 6106bd5..638cd76 100644 --- a/cross_validation/validator.py +++ b/cross_validation/validator.py @@ -29,39 +29,6 @@ class SuitMapping: return self.mapping[s] -class EhsCache: - def __init__(self): - self.cache = defaultdict(lambda: defaultdict(dict)) - - def _set_keys(self, flop, player): - suit_map = SuitMapping() - complex_cards = player+flop - iso_complex = to_iso(complex_cards, suit_map) - complex_key = cards_to_u32(iso_complex) - return complex_key - # 全部存下来计算耗时太大了。。嗯。。 - # todo:不存储直接抽样计算 - def store_river_ehs(self, player, board, ehs): - complex_key = self._set_keys(board[:3],player) - turn_idx = card_index(board[3]) - river_idx = card_index(board[4]) - self.cache[complex_key][turn_idx][river_idx] = ehs - - def get_turn_hist(self, player, flop, turn): - complex_key = self._set_keys(flop, player) - turn_idx = card_index(turn) - turn_hist = self.cache[complex_key][turn_idx] - return list(turn_hist.values()) if turn_hist else None - - def get_flop_hist(self, player, flop): - complex_key = self._set_keys(flop, player) - all_ehs = [] - player_data = self.cache[complex_key] - for turn_idx in player_data: - for river_idx in player_data[turn_idx]: - all_ehs.append(player_data[turn_idx][river_idx]) - return all_ehs if len(all_ehs) == 465 else None - def get_rank_idx(rank: SDR) -> int: rank_order = [SDR.SIX, SDR.SEVEN, SDR.EIGHT, SDR.NINE, SDR.TEN, SDR.JACK,SDR.QUEEN, SDR.KING, SDR.ACE] @@ -161,55 +128,101 @@ def compare_data(sampled, board, player): err_count += 1 return err_count -card_ehs = defaultdict(dict) +def calc_turn_hist(board, player) -> list[float]: + flop = board[:3] + turn = board[3] + used_cards = set(board + player) + + ehs_values = [] + for river in cards: + if river in used_cards: + continue + board = [*flop, turn, river] + ehs = calc_river_ehs(board, player) + ehs_values.append(ehs) + return ehs_values -def validate_river(): - validated_count = 0 +def analysis(flop, player, sampled): + print(f"sampled flop: {''.join(map(str, flop))}") + print(f"sampled player cards: {''.join(map(str, player))}") + + compare_data( + [sampled[t][r] for t in sampled for r in sampled[t] if t > r], + flop, + player, + ) + for turn in sampled: + compare_data(list(sampled[turn].values()), [*flop, turn], player) + + for river in sampled[turn]: + if turn > river: + continue + compare_data(sampled[turn][river], [*flop, turn, river], player) + +def validate_river(n = 100): + all_combos = list(itertools.combinations(cards, 5)) + sampled_combos = random.sample(all_combos, min(n, len(all_combos))) error_count = 0 - for river_combo in itertools.combinations(cards, 5): + for i, river_combo in enumerate(sampled_combos): board = list(river_combo) unused_cards = [c for c in cards if c not in board] - for player_combo in itertools.combinations(unused_cards, 2): - player = list(player_combo) - ehs = calc_river_ehs(board, player) - ehs_stored.store_river_ehs(player, board, ehs) - error_count = compare_data(ehs, board, player) - validated_count += 1 - print(".", end="", flush=True) - print(f"river validate count {validated_count}") - print(f"Validated river hands: {validated_count}, Errors: {error_count}") + + player = list(random.sample(unused_cards, 2)) + ehs = calc_river_ehs(board, player) + error_count = compare_data(ehs, board, player) + # print(f"river validate count: {i}") + print(f"Validated river : {n}, Errors: {error_count}") -def validate_turn(): - validated_count = 0 +def validate_turn(n = 50): error_count = 0 - for turn_combo in itertools.combinations(cards, 4): - turn = list(turn_combo) - unused_cards = [c for c in cards if c not in turn] - for player_combo in itertools.combinations(unused_cards, 2): - player = list(player_combo) - turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3]) - error_count += compare_data(turn_hist, turn, player) - validated_count += 1 - print(".", end="", flush=True) - print(f"Validated turn hands: {validated_count}, Errors: {error_count}") + for i in range(n): + turn_combo = random.sample(cards, 4) + board = list(turn_combo) + unused_cards = [c for c in cards if c not in board] + player = list(random.sample(unused_cards, 2)) + turn_hist = calc_turn_hist(board, player) + error_count += compare_data(turn_hist, board, player) + # print(f"turn validate count: {i}") + print(f"Validated turn : {n}, Errors: {error_count}") + + +CACHE = {} +def calc_river_ehs_cached(board, player) -> float: + suit_map = SuitMapping() + iso_board = to_iso(board, suit_map) + iso_player = to_iso(player, suit_map) + hand = "".join(map(str, [*iso_board, *iso_player])) + if hand not in CACHE: + CACHE[hand] = calc_river_ehs(board, player) + + return CACHE[hand] def validate_flop(): sample = random.sample(cards, 5) flop = sample[2:] player = sample[:2] - - flop_hist = ehs_stored.get_flop_hist(player, flop) - if flop_hist is None: - return - compare_data(flop_hist, flop, player) + sampled_ehs = defaultdict(dict) + for turn in cards: + if turn in flop or turn in player: + continue + + for river in cards: + if river in flop or river in player or river == turn: + continue + + board = [*flop, turn, river] + sampled_ehs[turn][river] = calc_river_ehs_cached(board, player) + print(".", end="", flush=True) + + print("") + analysis(flop, player, sampled_ehs) -ehs_stored = EhsCache() def cross_validate_main(): - validate_river() - validate_turn() + validate_river(1000) + validate_turn(500) validate_flop() if __name__ == "__main__":