kubn cfr:0

This commit is contained in:
2025-11-28 17:19:59 +08:00
commit 547118ec6d
12 changed files with 223 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.13

0
README.md Normal file
View File

13
main.py Normal file
View File

@@ -0,0 +1,13 @@
from src.game.kuhn_poker import KuhnPoker
from src.cfr.info_set import InfoSet
def main():
game = KuhnPoker()
info_set = InfoSet(act_cnt=2)
strategy = info_set.get_strat(.0)
if __name__ == "__main__":
main()

7
pyproject.toml Normal file
View File

@@ -0,0 +1,7 @@
[project]
name = "kubn"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = []

0
src/__init__.py Normal file
View File

3
src/cfr/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .info_set import InfoSet
__all__ = ['InfoSet']

44
src/cfr/info_set.py Normal file
View File

@@ -0,0 +1,44 @@
class InfoSet:
def __init__(self, act_cnt=2): # 0-1
self.act_cnt = act_cnt
self.regret_sum = [0.0] * act_cnt
self.strat = [0.0] * act_cnt
self.strat_sum = [0.0] * act_cnt
def get_strat(self, wgt):
normal = 0.0
for i in range(self.act_cnt):
if self.regret_sum[i] > 0:
self.strat[i] = self.regret_sum[i]
else:
self.strat[i] = 0.0
normal += self.strat[i]
if normal > 0:
for i in range(self.act_cnt):
self.strat[i] /= normal
else:
##
return
for i in range(self.act_cnt):
self.strat_sum[i] += wgt * self.strat[i]
return self.strat
def get_avg_strat(self):
avg_strat = [0.0] * self.act_cnt
normal = sum(self.strat_sum)
if normal > 0:
for i in range(self.act_cnt):
avg_strat[i] = self.strat_sum[i] / normal
else:
##
return
return avg_strat

3
src/game/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .kuhn_poker import KuhnPoker
__all__ = ['KuhnPoker']

61
src/game/kuhn_poker.py Normal file
View File

@@ -0,0 +1,61 @@
class KuhnPoker:
def __init__(self):
self.cards = ['J', 'Q', 'K']
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
def is_terminal(self, history):
return history in ['00', '10', '010', '011', '11']
def get_profit(self, cards, history, player):
if not self.is_terminal(history):
return 0.0
card_values = {'J': 0, 'Q': 1, 'K': 2}
p0_cardv = card_values[cards[0]]
p1_cardv = card_values[cards[1]]
p0_wins = p0_cardv > p1_cardv
if history == '00':
if p0_wins:
return 1.0 if player == 0 else -1.0
else:
return -1.0 if player == 0 else 1.0
elif history == '10':
return 1.0 if player == 0 else -1.0
elif history == '010':
return -1.0 if player == 0 else 1.0
elif history == '011':
if p0_wins:
return 2.0 if player == 0 else -2.0
else:
return -2.0 if player == 0 else 2.0
elif history == '11':
if p0_wins:
return 2.0 if player == 0 else -2.0
else:
return -2.0 if player == 0 else 2.0
return 0.0
def get_Info_set(self, card, history, player):
return f"{card}.{history}"
def get_cur_player(self, history):
if self.is_terminal(history):
return -1
return len(history) % 2
def get_valid_act(self, history):
if self.is_terminal(history):
return []
return [0, 1]
def get_act_name(self, history, action):
if history == "01":
return "Fold" if action == 0 else "Call"
else:
return "Check" if action == 0 else "Bet"

51
tests/test_game.py Normal file
View File

@@ -0,0 +1,51 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.game.kuhn_poker import KuhnPoker
def test_game():
game = KuhnPoker()
test_histories = ['', '0', '1', '00', '10', '01', '011', '010', '11']
for his in test_histories:
terminal = game.is_terminal(his)
player = game.get_cur_player(his)
print(f"'{his}' - teminal: {terminal}, player: {player}")
test_cases = [
(['K', 'J'], '00', 0),
(['J', 'K'], '00', 0),
(['Q', 'J'], '10', 0),
(['J', 'Q'], '010', 0),
(['K', 'J'], '11', 0),
(['J', 'Q'], '11', 0),
(['Q', 'K'], '011', 1),
]
for cards, his, player in test_cases:
profit = game.get_profit(cards, his, player)
print(f"player{player}profit: {profit}")
info_cases = [
('K', '0', 0),
('J', '01', 1),
('Q', '', 0),
]
for card, his, player in info_cases:
info_set = game.get_Info_set(card, his, player)
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
print(f"valid act: {game.get_valid_act('00')}")
test_seq = ['00', '10', '01', '010', '011', '11']
for seq in test_seq:
terminal = game.is_terminal(seq)
if terminal:
print(f" '{seq}' terminal")
else:
current_player = game.get_cur_player(seq)
print(f" '{seq}' not terminal, cur player: {current_player}")
if __name__ == "__main__":
test_game()

30
tests/test_strat.py Normal file
View File

@@ -0,0 +1,30 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.cfr.info_set import InfoSet
def demo_strategy():
info_set = InfoSet(act_cnt=2)
turn = [
([10.0, 5.0], 1.0, "BT 0"),
([15.0, 20.0], 1.0, "CL 0"),
([9.0, 25.0], 1.0, "CL 1"),
([5.0, 30.0], 1.0, "CL 1"),
]
for reg, wgt, str in turn:
info_set.regret_sum = reg[:]
cur_stra = info_set.get_strat(wgt)
avg_stra = info_set.get_avg_strat()
print(f"{str}")
print(f" sum_regret: {reg}")
print(f" cur_strategy: [CK/FD: {cur_stra[0]:.3f}, BT/CL: {cur_stra[1]:.3f}]")
print(f" avg_strategy: [CK/FD: {avg_stra[0]:.3f}, BT/CL: {avg_stra[1]:.3f}]")
print()
if __name__ == "__main__":
demo_strategy()