kubn cfr:0

This commit is contained in:
2025-11-28 17:19:59 +08:00
commit 547118ec6d
12 changed files with 223 additions and 0 deletions

51
tests/test_game.py Normal file
View File

@@ -0,0 +1,51 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.game.kuhn_poker import KuhnPoker
def test_game():
game = KuhnPoker()
test_histories = ['', '0', '1', '00', '10', '01', '011', '010', '11']
for his in test_histories:
terminal = game.is_terminal(his)
player = game.get_cur_player(his)
print(f"'{his}' - teminal: {terminal}, player: {player}")
test_cases = [
(['K', 'J'], '00', 0),
(['J', 'K'], '00', 0),
(['Q', 'J'], '10', 0),
(['J', 'Q'], '010', 0),
(['K', 'J'], '11', 0),
(['J', 'Q'], '11', 0),
(['Q', 'K'], '011', 1),
]
for cards, his, player in test_cases:
profit = game.get_profit(cards, his, player)
print(f"player{player}profit: {profit}")
info_cases = [
('K', '0', 0),
('J', '01', 1),
('Q', '', 0),
]
for card, his, player in info_cases:
info_set = game.get_Info_set(card, his, player)
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
print(f"valid act: {game.get_valid_act('00')}")
test_seq = ['00', '10', '01', '010', '011', '11']
for seq in test_seq:
terminal = game.is_terminal(seq)
if terminal:
print(f" '{seq}' terminal")
else:
current_player = game.get_cur_player(seq)
print(f" '{seq}' not terminal, cur player: {current_player}")
if __name__ == "__main__":
test_game()

30
tests/test_strat.py Normal file
View File

@@ -0,0 +1,30 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.cfr.info_set import InfoSet
def demo_strategy():
info_set = InfoSet(act_cnt=2)
turn = [
([10.0, 5.0], 1.0, "BT 0"),
([15.0, 20.0], 1.0, "CL 0"),
([9.0, 25.0], 1.0, "CL 1"),
([5.0, 30.0], 1.0, "CL 1"),
]
for reg, wgt, str in turn:
info_set.regret_sum = reg[:]
cur_stra = info_set.get_strat(wgt)
avg_stra = info_set.get_avg_strat()
print(f"{str}")
print(f" sum_regret: {reg}")
print(f" cur_strategy: [CK/FD: {cur_stra[0]:.3f}, BT/CL: {cur_stra[1]:.3f}]")
print(f" avg_strategy: [CK/FD: {avg_stra[0]:.3f}, BT/CL: {avg_stra[1]:.3f}]")
print()
if __name__ == "__main__":
demo_strategy()