76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
import sys
|
|
import os
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
|
|
|
from src.game.kuhn_poker import KuhnPoker
|
|
from src.cfr.info_set import InfoSet
|
|
from src.game.player import Player
|
|
|
|
def _print_histories(game):
    """Print terminal status and acting player for a fixed set of histories.

    A history is a string of action characters. Presumably '0' means
    check/fold and '1' means bet/call; confirm against KuhnPoker.
    """
    histories = ('', '0', '1', '00', '10', '01', '011', '010', '11')
    for history in histories:
        terminal = game.is_terminal(history)
        # Queried even for terminal histories, mirroring the original checks.
        player = game.get_cur_player(history)
        print(f"'{history}' - terminal: {terminal}, player: {player}")


def _print_utilities(game):
    """Print ``game.get_util`` for fixed (cards, history, player) cases."""
    # Each case is (dealt cards, history, player whose utility is requested).
    # Presumably cards[i] is player i's card; verify against KuhnPoker.get_util.
    cases = (
        (['K', 'J'], '00', 0),
        (['J', 'K'], '00', 0),
        (['Q', 'J'], '10', 0),
        (['J', 'Q'], '010', 0),
        (['K', 'J'], '11', 0),
        (['J', 'Q'], '11', 0),
        (['Q', 'K'], '011', 1),
    )
    for cards, history, player in cases:
        profit = game.get_util(cards, history, player)
        print(f"player {player} profit: {profit}")


def _print_info_sets(game):
    """Register info sets on two players and print the keys each one holds."""
    # Each case is (two-card deal as a string, history, acting player id).
    # The acting player's own card is cards[pid].
    cases = (
        ('JQ', '', 0), ('JK', '', 0),
        ('JQ', '0', 0), ('JK', '0', 0),
        ('JQ', '1', 1), ('JK', '1', 1),
        ('JQ', '00', 0), ('JK', '00', 0),
        ('JQ', '10', 0), ('KQ', '10', 0),
        ('QK', '01', 1), ('JK', '01', 1),
        ('JQ', '010', 0), ('JK', '010', 0),
        ('JQ', '011', 1), ('JK', '011', 1),
    )
    players = [Player(0), Player(1)]
    for cards, history, pid in cases:
        # The key is built only from the acting player's card and the history.
        info_key = game.get_Info_set(cards[pid], history)
        # Presumably creates the info set in player.info_map on first access;
        # confirm against Player.get_info_set.
        players[pid].get_info_set(info_key)

    for player in players:
        print(f"player {player.id} has {len(player.info_map)} info sets.")
        for info_key in player.info_map:
            print("infokey:", info_key)


def test_game():
    """Smoke-test KuhnPoker history handling, utilities and info-set keys.

    Results are printed, not asserted. Under pytest this passes as long as no
    call raises.
    """
    game = KuhnPoker()
    _print_histories(game)
    _print_utilities(game)
    _print_info_sets(game)
|
|
|
|
# Run the smoke checks when this file is executed directly as a script.
if __name__ == "__main__":
    test_game()