This commit is contained in:
2025-11-02 12:51:13 +08:00
parent ec50d2897f
commit cb4ce22775

View File

@@ -29,39 +29,6 @@ class SuitMapping:
return self.mapping[s]
class EhsCache:
def __init__(self):
self.cache = defaultdict(lambda: defaultdict(dict))
def _set_keys(self, flop, player):
suit_map = SuitMapping()
complex_cards = player+flop
iso_complex = to_iso(complex_cards, suit_map)
complex_key = cards_to_u32(iso_complex)
return complex_key
# 全部存下来计算耗时太大了。。嗯。。
# todo:不存储直接抽样计算
def store_river_ehs(self, player, board, ehs):
complex_key = self._set_keys(board[:3],player)
turn_idx = card_index(board[3])
river_idx = card_index(board[4])
self.cache[complex_key][turn_idx][river_idx] = ehs
def get_turn_hist(self, player, flop, turn):
complex_key = self._set_keys(flop, player)
turn_idx = card_index(turn)
turn_hist = self.cache[complex_key][turn_idx]
return list(turn_hist.values()) if turn_hist else None
def get_flop_hist(self, player, flop):
complex_key = self._set_keys(flop, player)
all_ehs = []
player_data = self.cache[complex_key]
for turn_idx in player_data:
for river_idx in player_data[turn_idx]:
all_ehs.append(player_data[turn_idx][river_idx])
return all_ehs if len(all_ehs) == 465 else None
def get_rank_idx(rank: SDR) -> int:
rank_order = [SDR.SIX, SDR.SEVEN, SDR.EIGHT, SDR.NINE, SDR.TEN,
SDR.JACK,SDR.QUEEN, SDR.KING, SDR.ACE]
@@ -161,55 +128,101 @@ def compare_data(sampled, board, player):
err_count += 1
return err_count
card_ehs = defaultdict(dict)
def calc_turn_hist(board, player) -> list[float]:
flop = board[:3]
turn = board[3]
used_cards = set(board + player)
def validate_river():
validated_count = 0
ehs_values = []
for river in cards:
if river in used_cards:
continue
board = [*flop, turn, river]
ehs = calc_river_ehs(board, player)
ehs_values.append(ehs)
return ehs_values
def analysis(flop, player, sampled):
print(f"sampled flop: {''.join(map(str, flop))}")
print(f"sampled player cards: {''.join(map(str, player))}")
compare_data(
[sampled[t][r] for t in sampled for r in sampled[t] if t > r],
flop,
player,
)
for turn in sampled:
compare_data(list(sampled[turn].values()), [*flop, turn], player)
for river in sampled[turn]:
if turn > river:
continue
compare_data(sampled[turn][river], [*flop, turn, river], player)
def validate_river(n = 100):
all_combos = list(itertools.combinations(cards, 5))
sampled_combos = random.sample(all_combos, min(n, len(all_combos)))
error_count = 0
for river_combo in itertools.combinations(cards, 5):
for i, river_combo in enumerate(sampled_combos):
board = list(river_combo)
unused_cards = [c for c in cards if c not in board]
for player_combo in itertools.combinations(unused_cards, 2):
player = list(player_combo)
player = list(random.sample(unused_cards, 2))
ehs = calc_river_ehs(board, player)
ehs_stored.store_river_ehs(player, board, ehs)
error_count = compare_data(ehs, board, player)
validated_count += 1
print(".", end="", flush=True)
print(f"river validate count {validated_count}")
print(f"Validated river hands: {validated_count}, Errors: {error_count}")
# print(f"river validate count: {i}")
print(f"Validated river : {n}, Errors: {error_count}")
def validate_turn():
validated_count = 0
def validate_turn(n = 50):
error_count = 0
for turn_combo in itertools.combinations(cards, 4):
turn = list(turn_combo)
unused_cards = [c for c in cards if c not in turn]
for player_combo in itertools.combinations(unused_cards, 2):
player = list(player_combo)
turn_hist = ehs_stored.get_turn_hist(player, turn[:3], turn[3])
error_count += compare_data(turn_hist, turn, player)
validated_count += 1
print(".", end="", flush=True)
print(f"Validated turn hands: {validated_count}, Errors: {error_count}")
for i in range(n):
turn_combo = random.sample(cards, 4)
board = list(turn_combo)
unused_cards = [c for c in cards if c not in board]
player = list(random.sample(unused_cards, 2))
turn_hist = calc_turn_hist(board, player)
error_count += compare_data(turn_hist, board, player)
# print(f"turn validate count: {i}")
print(f"Validated turn : {n}, Errors: {error_count}")
CACHE = {}
def calc_river_ehs_cached(board, player) -> float:
suit_map = SuitMapping()
iso_board = to_iso(board, suit_map)
iso_player = to_iso(player, suit_map)
hand = "".join(map(str, [*iso_board, *iso_player]))
if hand not in CACHE:
CACHE[hand] = calc_river_ehs(board, player)
return CACHE[hand]
def validate_flop():
sample = random.sample(cards, 5)
flop = sample[2:]
player = sample[:2]
sampled_ehs = defaultdict(dict)
flop_hist = ehs_stored.get_flop_hist(player, flop)
if flop_hist is None:
return
compare_data(flop_hist, flop, player)
for turn in cards:
if turn in flop or turn in player:
continue
for river in cards:
if river in flop or river in player or river == turn:
continue
board = [*flop, turn, river]
sampled_ehs[turn][river] = calc_river_ehs_cached(board, player)
print(".", end="", flush=True)
print("")
analysis(flop, player, sampled_ehs)
ehs_stored = EhsCache()
def cross_validate_main():
validate_river()
validate_turn()
validate_river(1000)
validate_turn(500)
validate_flop()
if __name__ == "__main__":