This commit is contained in:
2025-12-03 17:49:29 +08:00
parent 8e4be3bda2
commit 33ca6d59b0
6 changed files with 241 additions and 91 deletions

View File

@@ -3,6 +3,8 @@ import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.game.kuhn_poker import KuhnPoker
from src.cfr.info_set import InfoSet
from src.game.player import Player
def test_game():
game = KuhnPoker()
@@ -22,30 +24,53 @@ def test_game():
(['Q', 'K'], '011', 1),
]
for cards, his, player in test_cases:
profit = game.get_profit(cards, his, player)
profit = game.get_util(cards, his, player)
print(f"player{player}profit: {profit}")
info_cases = [
('K', '0', 0),
('J', '01', 1),
('Q', '', 0),
('JQ', '' , 0), ('JK', '' , 0),
('JQ', '0', 0), ('JK', '0', 0),
('JQ', '1', 1), ('JK', '1', 1),
('JQ', '00', 0), ('JK', '00', 0),
('JQ', '10', 0), ('KQ', '10', 0),
('QK', '01', 1), ('JK', '01', 1),
('JQ', '010',0), ('JK', '010',0),
('JQ', '011',1), ('JK', '011',1),
]
for card, his, player in info_cases:
info_set = game.get_Info_set(card, his, player)
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
print(f"valid act: {game.get_valid_act('00')}")
test_seq = ['00', '10', '01', '010', '011', '11']
for seq in test_seq:
terminal = game.is_terminal(seq)
if terminal:
print(f" '{seq}' terminal")
else:
current_player = game.get_cur_player(seq)
print(f" '{seq}' not terminal, cur player: {current_player}")
info_map = {}
players = [Player(0), Player(1)]
for cards, his, id in info_cases:
player = players[id]
cards = [cards[0], cards[1]]
info_key = game.get_Info_set(cards[id], his)
player.get_info_set(info_key)
for p in players:
print(f"player {p.id} has {len(p.info_map)} info sets.")
for info_key, info in p.info_map.items():
info_map[info_key] = info
print("infokey:", info_key)
# for info_key, info in info_map.items():
# strat = info.get_strat(1.0)
# avg_strat = info.get_avg_strat()
# print(f"infoset: '{info_key}'")
# print(f" current strategy: [CK/FD: {strat[0]:.3f}, BT/CL: {strat[1]:.3f}]")
# print(f" average strategy: [CK/FD: {avg_strat[0]:.3f}, BT/CL: {avg_strat[1]:.3f}]")
# print('=='*20)
# test_seq = ['00', '10', '01', '010', '011', '11']
# for seq in test_seq:
# terminal = game.is_terminal(seq)
# if terminal:
# print(f" '{seq}' terminal")
# else:
# current_player = game.get_cur_player(seq)
# print(f" '{seq}' not terminal, cur player: {current_player}")
if __name__ == "__main__":
test_game()