kubn cfr:0
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.13
|
||||||
13
main.py
Normal file
13
main.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from src.game.kuhn_poker import KuhnPoker
|
||||||
|
from src.cfr.info_set import InfoSet
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
game = KuhnPoker()
|
||||||
|
|
||||||
|
info_set = InfoSet(act_cnt=2)
|
||||||
|
strategy = info_set.get_strat(.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
7
pyproject.toml
Normal file
7
pyproject.toml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
[project]
|
||||||
|
name = "kubn"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
dependencies = []
|
||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
3
src/cfr/__init__.py
Normal file
3
src/cfr/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .info_set import InfoSet
|
||||||
|
|
||||||
|
__all__ = ['InfoSet']
|
||||||
44
src/cfr/info_set.py
Normal file
44
src/cfr/info_set.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
class InfoSet:
|
||||||
|
def __init__(self, act_cnt=2): # 0-1
|
||||||
|
self.act_cnt = act_cnt
|
||||||
|
|
||||||
|
self.regret_sum = [0.0] * act_cnt
|
||||||
|
self.strat = [0.0] * act_cnt
|
||||||
|
self.strat_sum = [0.0] * act_cnt
|
||||||
|
|
||||||
|
def get_strat(self, wgt):
|
||||||
|
|
||||||
|
normal = 0.0
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
if self.regret_sum[i] > 0:
|
||||||
|
self.strat[i] = self.regret_sum[i]
|
||||||
|
else:
|
||||||
|
self.strat[i] = 0.0
|
||||||
|
normal += self.strat[i]
|
||||||
|
|
||||||
|
if normal > 0:
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
self.strat[i] /= normal
|
||||||
|
else:
|
||||||
|
##
|
||||||
|
return
|
||||||
|
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
self.strat_sum[i] += wgt * self.strat[i]
|
||||||
|
|
||||||
|
return self.strat
|
||||||
|
|
||||||
|
def get_avg_strat(self):
|
||||||
|
avg_strat = [0.0] * self.act_cnt
|
||||||
|
normal = sum(self.strat_sum)
|
||||||
|
|
||||||
|
if normal > 0:
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
avg_strat[i] = self.strat_sum[i] / normal
|
||||||
|
else:
|
||||||
|
##
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
return avg_strat
|
||||||
|
|
||||||
3
src/game/__init__.py
Normal file
3
src/game/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .kuhn_poker import KuhnPoker
|
||||||
|
|
||||||
|
__all__ = ['KuhnPoker']
|
||||||
61
src/game/kuhn_poker.py
Normal file
61
src/game/kuhn_poker.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
class KuhnPoker:
|
||||||
|
def __init__(self):
|
||||||
|
self.cards = ['J', 'Q', 'K']
|
||||||
|
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
||||||
|
|
||||||
|
def is_terminal(self, history):
|
||||||
|
return history in ['00', '10', '010', '011', '11']
|
||||||
|
|
||||||
|
def get_profit(self, cards, history, player):
|
||||||
|
if not self.is_terminal(history):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
card_values = {'J': 0, 'Q': 1, 'K': 2}
|
||||||
|
p0_cardv = card_values[cards[0]]
|
||||||
|
p1_cardv = card_values[cards[1]]
|
||||||
|
|
||||||
|
p0_wins = p0_cardv > p1_cardv
|
||||||
|
if history == '00':
|
||||||
|
if p0_wins:
|
||||||
|
return 1.0 if player == 0 else -1.0
|
||||||
|
else:
|
||||||
|
return -1.0 if player == 0 else 1.0
|
||||||
|
|
||||||
|
elif history == '10':
|
||||||
|
return 1.0 if player == 0 else -1.0
|
||||||
|
|
||||||
|
elif history == '010':
|
||||||
|
return -1.0 if player == 0 else 1.0
|
||||||
|
|
||||||
|
elif history == '011':
|
||||||
|
if p0_wins:
|
||||||
|
return 2.0 if player == 0 else -2.0
|
||||||
|
else:
|
||||||
|
return -2.0 if player == 0 else 2.0
|
||||||
|
|
||||||
|
elif history == '11':
|
||||||
|
if p0_wins:
|
||||||
|
return 2.0 if player == 0 else -2.0
|
||||||
|
else:
|
||||||
|
return -2.0 if player == 0 else 2.0
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def get_Info_set(self, card, history, player):
|
||||||
|
return f"{card}.{history}"
|
||||||
|
|
||||||
|
def get_cur_player(self, history):
|
||||||
|
if self.is_terminal(history):
|
||||||
|
return -1
|
||||||
|
return len(history) % 2
|
||||||
|
|
||||||
|
def get_valid_act(self, history):
|
||||||
|
if self.is_terminal(history):
|
||||||
|
return []
|
||||||
|
return [0, 1]
|
||||||
|
|
||||||
|
def get_act_name(self, history, action):
|
||||||
|
if history == "01":
|
||||||
|
return "Fold" if action == 0 else "Call"
|
||||||
|
else:
|
||||||
|
return "Check" if action == 0 else "Bet"
|
||||||
51
tests/test_game.py
Normal file
51
tests/test_game.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from src.game.kuhn_poker import KuhnPoker
|
||||||
|
|
||||||
|
def test_game():
|
||||||
|
game = KuhnPoker()
|
||||||
|
test_histories = ['', '0', '1', '00', '10', '01', '011', '010', '11']
|
||||||
|
for his in test_histories:
|
||||||
|
terminal = game.is_terminal(his)
|
||||||
|
player = game.get_cur_player(his)
|
||||||
|
print(f"'{his}' - teminal: {terminal}, player: {player}")
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
(['K', 'J'], '00', 0),
|
||||||
|
(['J', 'K'], '00', 0),
|
||||||
|
(['Q', 'J'], '10', 0),
|
||||||
|
(['J', 'Q'], '010', 0),
|
||||||
|
(['K', 'J'], '11', 0),
|
||||||
|
(['J', 'Q'], '11', 0),
|
||||||
|
(['Q', 'K'], '011', 1),
|
||||||
|
]
|
||||||
|
for cards, his, player in test_cases:
|
||||||
|
profit = game.get_profit(cards, his, player)
|
||||||
|
print(f"player{player}profit: {profit}")
|
||||||
|
|
||||||
|
info_cases = [
|
||||||
|
('K', '0', 0),
|
||||||
|
('J', '01', 1),
|
||||||
|
('Q', '', 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
for card, his, player in info_cases:
|
||||||
|
info_set = game.get_Info_set(card, his, player)
|
||||||
|
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
|
||||||
|
|
||||||
|
print(f"valid act: {game.get_valid_act('00')}")
|
||||||
|
|
||||||
|
test_seq = ['00', '10', '01', '010', '011', '11']
|
||||||
|
|
||||||
|
for seq in test_seq:
|
||||||
|
terminal = game.is_terminal(seq)
|
||||||
|
if terminal:
|
||||||
|
print(f" '{seq}' terminal")
|
||||||
|
else:
|
||||||
|
current_player = game.get_cur_player(seq)
|
||||||
|
print(f" '{seq}' not terminal, cur player: {current_player}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_game()
|
||||||
30
tests/test_strat.py
Normal file
30
tests/test_strat.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from src.cfr.info_set import InfoSet
|
||||||
|
|
||||||
|
def demo_strategy():
|
||||||
|
|
||||||
|
info_set = InfoSet(act_cnt=2)
|
||||||
|
|
||||||
|
turn = [
|
||||||
|
([10.0, 5.0], 1.0, "BT 0"),
|
||||||
|
([15.0, 20.0], 1.0, "CL 0"),
|
||||||
|
([9.0, 25.0], 1.0, "CL 1"),
|
||||||
|
([5.0, 30.0], 1.0, "CL 1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for reg, wgt, str in turn:
|
||||||
|
info_set.regret_sum = reg[:]
|
||||||
|
cur_stra = info_set.get_strat(wgt)
|
||||||
|
avg_stra = info_set.get_avg_strat()
|
||||||
|
|
||||||
|
print(f"{str}")
|
||||||
|
print(f" sum_regret: {reg}")
|
||||||
|
print(f" cur_strategy: [CK/FD: {cur_stra[0]:.3f}, BT/CL: {cur_stra[1]:.3f}]")
|
||||||
|
print(f" avg_strategy: [CK/FD: {avg_stra[0]:.3f}, BT/CL: {avg_stra[1]:.3f}]")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo_strategy()
|
||||||
Reference in New Issue
Block a user