commit 547118ec6d5fdf9c714549ea68d2d7867d6d827d
Author: jianghaiying
Date:   Fri Nov 28 17:19:59 2025 +0800

    kubn cfr:0

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..505a3b1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..24ee5b1
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.13
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..65632bf
--- /dev/null
+++ b/main.py
@@ -0,0 +1,13 @@
+from src.game.kuhn_poker import KuhnPoker
+from src.cfr.info_set import InfoSet
+
+
+def main():
+    game = KuhnPoker()
+
+    info_set = InfoSet(act_cnt=2)
+    strategy = info_set.get_strat(.0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..5b85e3a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,7 @@
+[project]
+name = "kubn"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.13"
+dependencies = []
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/cfr/__init__.py b/src/cfr/__init__.py
new file mode 100644
index 0000000..c6479b2
--- /dev/null
+++ b/src/cfr/__init__.py
@@ -0,0 +1,3 @@
+from .info_set import InfoSet
+
+__all__ = ['InfoSet']
diff --git a/src/cfr/info_set.py b/src/cfr/info_set.py
new file mode 100644
--- /dev/null
+++ b/src/cfr/info_set.py
@@ -0,0 +1,53 @@
+"""Information-set node used by the CFR trainer."""
+
+
+class InfoSet:
+    """Regret and strategy accumulators for one information set.
+
+    ``regret_sum``, ``strat`` and ``strat_sum`` are lists with one entry
+    per action: the cumulative regret, the current strategy and the
+    weighted sum of the strategies returned so far.
+    """
+
+    def __init__(self, act_cnt=2):  # 0-1
+        self.act_cnt = act_cnt
+
+        self.regret_sum = [0.0] * act_cnt
+        self.strat = [0.0] * act_cnt
+        self.strat_sum = [0.0] * act_cnt
+
+    def _uniform(self):
+        """Return the uniform strategy over all actions."""
+        return [1.0 / self.act_cnt for _ in range(self.act_cnt)]
+
+    def get_strat(self, wgt):
+        """Return the current strategy computed by regret matching.
+
+        Each action is played in proportion to its positive cumulative
+        regret.  When no action has positive regret the strategy is
+        uniform, so callers always receive a list.  The result is added
+        to ``strat_sum`` weighted by ``wgt``.
+        """
+        positive = [max(self.regret_sum[i], 0.0) for i in range(self.act_cnt)]
+        normal = sum(positive)
+
+        if normal > 0:
+            self.strat[:] = [value / normal for value in positive]
+        else:
+            self.strat[:] = self._uniform()
+
+        for i in range(self.act_cnt):
+            self.strat_sum[i] += wgt * self.strat[i]
+
+        return self.strat
+
+    def get_avg_strat(self):
+        """Return the average strategy over all calls to ``get_strat``.
+
+        Falls back to the uniform strategy while nothing has been
+        accumulated, instead of returning ``None``.
+        """
+        normal = sum(self.strat_sum)
+        if normal > 0:
+            return [self.strat_sum[i] / normal for i in range(self.act_cnt)]
+        return self._uniform()
diff --git a/src/game/__init__.py b/src/game/__init__.py
new file mode 100644
index 0000000..e2a9c48
--- /dev/null
+++ b/src/game/__init__.py
@@ -0,0 +1,3 @@
+from .kuhn_poker import KuhnPoker
+
+__all__ = ['KuhnPoker']
diff --git a/src/game/kuhn_poker.py b/src/game/kuhn_poker.py
new file mode 100644
--- /dev/null
+++ b/src/game/kuhn_poker.py
@@ -0,0 +1,70 @@
+"""Rules of two-player Kuhn poker.
+
+A history is a string with one character per action taken so far:
+``'0'`` is check/fold and ``'1'`` is bet/call.
+"""
+
+
+class KuhnPoker:
+    """Game logic for Kuhn poker played with the cards J, Q and K."""
+
+    TERMINAL_HISTORIES = ('00', '10', '010', '011', '11')
+    CARD_VALUES = {'J': 0, 'Q': 1, 'K': 2}
+
+    def __init__(self):
+        self.cards = ['J', 'Q', 'K']
+        self.actions = [0, 1]  # Check/Fold=0, Bet/Call=1
+
+    def is_terminal(self, history):
+        """Return True when ``history`` ends the hand."""
+        return history in self.TERMINAL_HISTORIES
+
+    def get_profit(self, cards, history, player):
+        """Return the payoff of ``player`` at a terminal ``history``.
+
+        ``cards[0]`` is player 0's card and ``cards[1]`` is player 1's.
+        Non-terminal histories yield 0.0.
+        """
+        if not self.is_terminal(history):
+            return 0.0
+
+        p0_wins = self.CARD_VALUES[cards[0]] > self.CARD_VALUES[cards[1]]
+
+        if history == '10':
+            # Player 0 bet and player 1 folded.
+            p0_profit = 1.0
+        elif history == '010':
+            # Player 0 checked, player 1 bet and player 0 folded.
+            p0_profit = -1.0
+        else:
+            # Showdown: 1 chip after check-check, 2 after a called bet.
+            stake = 1.0 if history == '00' else 2.0
+            p0_profit = stake if p0_wins else -stake
+
+        return p0_profit if player == 0 else -p0_profit
+
+    def get_Info_set(self, card, history, player):
+        """Return the information-set key for ``card`` and ``history``."""
+        return f"{card}.{history}"
+
+    def get_cur_player(self, history):
+        """Return the player to act (0 or 1), or -1 when terminal."""
+        if self.is_terminal(history):
+            return -1
+        return len(history) % 2
+
+    def get_valid_act(self, history):
+        """Return the legal actions, or an empty list when terminal."""
+        if self.is_terminal(history):
+            return []
+        return [0, 1]
+
+    def get_act_name(self, history, action):
+        """Return the display name of ``action`` taken after ``history``.
+
+        Facing a bet (``history`` ends with ``'1'``) the choices are
+        Fold/Call; otherwise they are Check/Bet.
+        """
+        if history.endswith('1'):
+            return "Fold" if action == 0 else "Call"
+        return "Check" if action == 0 else "Bet"
diff --git a/tests/test_game.py b/tests/test_game.py
new file mode 100644
--- /dev/null
+++ b/tests/test_game.py
@@ -0,0 +1,51 @@
+import sys
+import os
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+from src.game.kuhn_poker import KuhnPoker
+
+def test_game():
+    game = KuhnPoker()
+    test_histories = ['', '0', '1', '00', '10', '01', '011', '010', '11']
+    for his in test_histories:
+        terminal = game.is_terminal(his)
+        player = game.get_cur_player(his)
+        print(f"'{his}' - terminal: {terminal}, player: {player}")
+
+    test_cases = [
+        (['K', 'J'], '00', 0),
+        (['J', 'K'], '00', 0),
+        (['Q', 'J'], '10', 0),
+        (['J', 'Q'], '010', 0),
+        (['K', 'J'], '11', 0),
+        (['J', 'Q'], '11', 0),
+        (['Q', 'K'], '011', 1),
+    ]
+    for cards, his, player in test_cases:
+        profit = game.get_profit(cards, his, player)
+        print(f"player{player}profit: {profit}")
+
+    info_cases = [
+        ('K', '0', 0),
+        ('J', '01', 1),
+        ('Q', '', 0),
+    ]
+
+    for card, his, player in info_cases:
+        info_set = game.get_Info_set(card, his, player)
+        print(f"card:{card}, his:'{his}', player{player} -> infoset: '{info_set}'")
+
+    print(f"valid act: {game.get_valid_act('00')}")
+
+    test_seq = ['00', '10', '01', '010', '011', '11']
+
+    for seq in test_seq:
+        terminal = game.is_terminal(seq)
+        if terminal:
+            print(f" '{seq}' terminal")
+        else:
+            current_player = game.get_cur_player(seq)
+            print(f" '{seq}' not terminal, cur player: {current_player}")
+
+if __name__ == "__main__":
+    test_game()
diff --git a/tests/test_strat.py b/tests/test_strat.py
new file mode 100644
--- /dev/null
+++ b/tests/test_strat.py
@@ -0,0 +1,30 @@
+import sys
+import os
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+from src.cfr.info_set import InfoSet
+
+def demo_strategy():
+
+    info_set = InfoSet(act_cnt=2)
+
+    turn = [
+        ([10.0, 5.0], 1.0, "BT 0"),
+        ([15.0, 20.0], 1.0, "CL 0"),
+        ([9.0, 25.0], 1.0, "CL 1"),
+        ([5.0, 30.0], 1.0, "CL 1"),
+    ]
+
+    for reg, wgt, label in turn:
+        info_set.regret_sum = reg[:]
+        cur_stra = info_set.get_strat(wgt)
+        avg_stra = info_set.get_avg_strat()
+
+        print(f"{label}")
+        print(f" sum_regret: {reg}")
+        print(f" cur_strategy: [CK/FD: {cur_stra[0]:.3f}, BT/CL: {cur_stra[1]:.3f}]")
+        print(f" avg_strategy: [CK/FD: {avg_stra[0]:.3f}, BT/CL: {avg_stra[1]:.3f}]")
+        print()
+
+if __name__ == "__main__":
+    demo_strategy()