diff --git a/README.md b/README.md index e69de29..b39e1d3 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,16 @@ + +start:cfr([JQ],"",1,1) + +# strat_0:J.[0.7,0.3] +0:cfr([JQ],"0",0.7,1) +# strat_1:Q.0[0.6,0.4] + 0:cfr([JQ],"00",0.7,0.6) -> terminal,get_uitl +# strat_0:J.01[0.8,0.2] + 1:cfr([JQ],"01",0.7,0.4) + 0:cfr([JQ],"010",0.56,0.4) -> t + 1:cfr([JQ],"011",0.14,0.4) -> t + +1:cfr([JQ],"1",0.3,1) +# strat_1:Q.1[0.5,0.5] + 0:cfr([JQ],"10",0.3,0.5) -> t + 1:cfr([JQ],"11",0.3,0.5) -> t diff --git a/src/cfr/calcu.py b/src/cfr/calcu.py index 4b2730d..be78e88 100644 --- a/src/cfr/calcu.py +++ b/src/cfr/calcu.py @@ -1,73 +1,143 @@ import random from src.game.kuhn_poker import KuhnPoker from src.cfr.info_set import InfoSet +from src.game.player import Player -info_map = {} game = KuhnPoker() +players = [Player(0), Player(1)] def cfr(cards, his, p0, p1): if game.is_terminal(his): - return game.get_profit(cards, his, 0) + # return { + # 0:game.get_util(cards, his, 0), + # 1:game.get_util(cards, his, 1) + # } + return game.get_util(cards, his, 0) player = game.get_cur_player(his) + if player == -1: + print(f"game over, invialid player") + return 0.0 - # + player ? - info_key = game.get_Info_set(cards[player], his, player) - - if info_key not in info_map: - info_map[info_key] = InfoSet(2) # 0-1 - info = info_map[info_key] - - wgt = p0 if player == 0 else p1 - strat = info.get_strat(wgt) + cur_p = players[player] + cur_p.set_card(cards[player]) + info_key = game.get_Info_set(cards[player], his) + info = cur_p.get_info_set(info_key) - act_profit = [0.0, 0.0] - info_profit = 0.0 # 期望 + strat = info.get_strat(p0 if player == 0 else p1) + + act_util = [0.0, 0.0] + info_util = 0.0 for action in [0, 1]: next_his = his + str(action) if player == 0: - profit = cfr(cards, next_his, p0*strat[action], p1) + util = cfr(cards, next_his, p0*strat[action], p1) else: - profit = cfr(cards, next_his, p0, p1*strat[action]) + util = cfr(cards, next_his, p0, p1*strat[action]) - act_profit[action] = profit - info_profit += strat[action]*profit + act_util[action] = util + info_util += strat[action]*util - # 更新 for action in [0, 1]: + p0_util = act_util[action] - info_util if player == 0: - regret = act_profit[action] - info_profit + regret = p0_util other_r = p1 else: - regret = -(act_profit[action] - info_profit) + regret = -p0_util other_r = p0 + # 对手的reach probe加权累计regret + # 假设对手到达这的概率小,那么这个regret就不重要 info.regret_sum[action] += other_r * regret - return info_profit + return info_util -def test(): +def test(cnt): cards = ['J', 'Q', 'K'] - p_sum = 0.0 + util_sum = 0.0 - for i in range(10): + for i in range(cnt): card_r = random.sample(cards, 2) + util = cfr(card_r, "", 1.0, 1.0) + util_sum += util + players[0].util_sum += util + players[1].util_sum -= util - pf = cfr(card_r, "", 1.0, 1.0) - p_sum += pf + if (i + 1) % cnt == 0: + avg_util_p0 = players[0].util_sum / (i + 1) + avg_util_p1 = players[1].util_sum / (i + 1) + + avg_regret_p0 = players[0].cal_avg_regret(i + 1) + avg_regret_p1 = players[1].cal_avg_regret(i + 1) + + print( + f"cnt: {cnt}\n", + f"avg_util_p0: {avg_util_p0}\n", + f"avg_util_p1: {avg_util_p1}\n", + f"avg_regret_p0: {avg_regret_p0}\n", + f"avg_regret_p1: {avg_regret_p1}\n", + ) + +def all_avg_strat(id): + return players[id].get_avg_strat() + +def print_player_strat(id): + player = players[id] + strategy = player.get_avg_strat() + + print(f"\n{'='*75}") + print(f"玩家 {id} 的平均策略 (信息集数: {len(strategy)})") + print(f"{'='*75}") + print(f"{'info_p':<10} {'card':<10} {'history':<10} {'act0':<20} {'act1':<20}") + print(f"{'-'*75}") + + for info_key in sorted(strategy.keys()): + prob = strategy[info_key] - if (i + 1) % 10 == 0: - avg_p = p_sum / (i + 1) - print(f"Range {i+1}/10, Avg : {avg_p:.3f}") + parts = info_key.split('.') + card = parts[0] + hist = parts[1] if len(parts) > 1 else '' + + act0_name = game.get_act_name(hist, 0) + act1_name = game.get_act_name(hist, 1) + + print(f"{info_key:<10} {card:<10} {hist:<10} " + f"{act0_name}:{prob[0]:<10.3f} {act1_name}:{prob[1]:<10.3f}") - return p_sum / 10 + print(f"{'='*75}\n") -def print_strat(): +def print_all_strategies(): + print("\n" + "#"*75) + print("#" + " "*25 + "所有玩家的平均策略" + " "*24 + "#") + print("#"*75) + for player_id in [0, 1]: + print_player_strat(player_id) +def reset_players(): + for player in players: + player.reset() + +if __name__ == "__main__": + reset_players() + test(1000000) + print("平均策略") + print_player_strat(0) + + print("策略check") + strats = all_avg_strat(0) + print(f" K_Bet_Probe: {strats.get('K.', [0, 0])[1]:.3f} ") + print(f" Q_Check_Probe: {strats.get('Q.', [0, 0])[0]:.3f} ") + print(f" J_Check_Probe: {strats.get('J.', [0, 0])[0]:.3f} ") + + + + + \ No newline at end of file diff --git a/src/cfr/info_set.py b/src/cfr/info_set.py index b28f2ff..20dd81e 100644 --- a/src/cfr/info_set.py +++ b/src/cfr/info_set.py @@ -2,35 +2,35 @@ class InfoSet: def __init__(self, act_cnt=2): # 0-1 self.act_cnt = act_cnt - self.regret_sum = [0.0] * act_cnt - self.strat = [0.0] * act_cnt - self.strat_sum = [0.0] * act_cnt + self.regret_sum = [0.0] * act_cnt # 累积regret,用于更新策略 + self.strat = [0.0] * act_cnt + self.strat_sum = [0.0] * act_cnt # 累积策略, 用于平均策略 - def get_strat(self, wgt): + # 更新策略 论文Part3(8) + def get_strat(self, reach_prob): normal = 0.0 for i in range(self.act_cnt): - if self.regret_sum[i] > 0: - self.strat[i] = self.regret_sum[i] - else: - self.strat[i] = 0.0 + self.strat[i] = max(self.regret_sum[i], 0.0) normal += self.strat[i] if normal > 0: for i in range(self.act_cnt): self.strat[i] /= normal else: - ## + # 混合策略 prob = 1.0/self.act_cnt for i in range(self.act_cnt): self.strat[i] = prob - # return - + + # 更新累计策略(拆开) + # 加权累计 for i in range(self.act_cnt): - self.strat_sum[i] += wgt * self.strat[i] + self.strat_sum[i] += reach_prob * self.strat[i] return self.strat + # 论文Part2(4) def get_avg_strat(self): avg_strat = [0.0] * self.act_cnt normal = sum(self.strat_sum) @@ -39,10 +39,8 @@ class InfoSet: for i in range(self.act_cnt): avg_strat[i] = self.strat_sum[i] / normal else: - ## prob = 1.0/self.act_cnt for i in range(self.act_cnt): - avg_strat[i] = prob - # return + avg_strat[i] = prob return avg_strat \ No newline at end of file diff --git a/src/game/kuhn_poker.py b/src/game/kuhn_poker.py index 51356a8..b327dbf 100644 --- a/src/game/kuhn_poker.py +++ b/src/game/kuhn_poker.py @@ -1,52 +1,49 @@ class KuhnPoker: def __init__(self): self.cards = ['J', 'Q', 'K'] - self.actions = [0, 1] # Check/Fold=0, Bet/Call=1 + # self.actions = [0, 1] # Check/Fold=0, Bet/Call=1 + self.player_cnt = 2 - # 结构是不是要修改下? def is_terminal(self, history): return history in ['00', '10', '010', '011', '11'] - def get_profit(self, cards, history, player): + def get_util(self, cards, history, player): if not self.is_terminal(history): - return 0.0 + print(f"game not over, invalid util") + print("*"*30) card_values = {'J': 0, 'Q': 1, 'K': 2} p0_cardv = card_values[cards[0]] p1_cardv = card_values[cards[1]] p0_wins = p0_cardv > p1_cardv + p0_util = 0 if history == '00': - if p0_wins: - return 1.0 if player == 0 else -1.0 - else: - return -1.0 if player == 0 else 1.0 - + p0_util = 1 if p0_wins else -1 elif history == '10': - return 1.0 if player == 0 else -1.0 - - elif history == '010': - return -1.0 if player == 0 else 1.0 - - elif history == '011': - if p0_wins: - return 2.0 if player == 0 else -2.0 - else: - return -2.0 if player == 0 else 2.0 - + p0_util = 1 if player == 0 else -1 elif history == '11': - if p0_wins: - return 2.0 if player == 0 else -2.0 - else: - return -2.0 if player == 0 else 2.0 + p0_util = 2 if p0_wins else -2 + elif history == '010': + p0_util = -1 if p0_wins else 1 + elif history == '011': + p0_util = 2 if p0_wins else -2 - return 0.0 + if player == 0: + return p0_util + elif player == 1: + return -p0_util + else: + print(f"invalid player:{player}") + print("*"*30) + return 0 - def get_Info_set(self, card, history, player): + def get_Info_set(self, card, history): return f"{card}.{history}" def get_cur_player(self, history): if self.is_terminal(history): + print("game over, no player act") return -1 return len(history) % 2 diff --git a/src/game/player.py b/src/game/player.py new file mode 100644 index 0000000..72a2841 --- /dev/null +++ b/src/game/player.py @@ -0,0 +1,44 @@ +from src.cfr.info_set import InfoSet +class Player: + def __init__(self, id): + self.id = id + self.card = None + self.info_map = {} + self.regret_sum = 0 # 累计regret + self.strat_sum = 0 # 累计strat + self.util_sum = 0 + + def set_card(self, card): + self.card = card + + def get_info_set(self, info_key, act_cnt=2): + if info_key not in self.info_map: + self.info_map[info_key] = InfoSet(act_cnt) + return self.info_map[info_key] + + def get_avg_strat(self): + return { + info_key: info.get_avg_strat() + for info_key, info in self.info_map.items() + } + + + def cal_avg_regret(self, cnt): + if cnt <= 0: + print("invalid test count") + total = sum(max(0, max(info.regret_sum)) for info in self.info_map.values()) + return total / cnt + + def reset(self): + self.info_map = {} + self.card = None + self.regret_sum = 0 + self.strat_sum = 0 + self.util_sum = 0 + + def __str__(self): + return f"player:{self.id} , card:{self.card}, info_sets:{len(self.info_map)}" + + def print_info(self): + for info_key, info in self.info_map.items(): + print(f"infoset: '{info_key}', info: {info}") \ No newline at end of file diff --git a/tests/test_game.py b/tests/test_game.py index 73d56e2..dfb8206 100644 --- a/tests/test_game.py +++ b/tests/test_game.py @@ -3,6 +3,8 @@ import os sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from src.game.kuhn_poker import KuhnPoker +from src.cfr.info_set import InfoSet +from src.game.player import Player def test_game(): game = KuhnPoker() @@ -22,30 +24,53 @@ def test_game(): (['Q', 'K'], '011', 1), ] for cards, his, player in test_cases: - profit = game.get_profit(cards, his, player) + profit = game.get_util(cards, his, player) print(f"player{player}profit: {profit}") info_cases = [ - ('K', '0', 0), - ('J', '01', 1), - ('Q', '', 0), + ('JQ', '' , 0), ('JK', '' , 0), + ('JQ', '0', 0), ('JK', '0', 0), + ('JQ', '1', 1), ('JK', '1', 1), + ('JQ', '00', 0), ('JK', '00', 0), + ('JQ', '10', 0), ('KQ', '10', 0), + ('QK', '01', 1), ('JK', '01', 1), + ('JQ', '010',0), ('JK', '010',0), + ('JQ', '011',1), ('JK', '011',1), ] - - for card, his, player in info_cases: - info_set = game.get_Info_set(card, his, player) - print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'") - - print(f"valid act: {game.get_valid_act('00')}") - - test_seq = ['00', '10', '01', '010', '011', '11'] - for seq in test_seq: - terminal = game.is_terminal(seq) - if terminal: - print(f" '{seq}' terminal") - else: - current_player = game.get_cur_player(seq) - print(f" '{seq}' not terminal, cur player: {current_player}") + info_map = {} + players = [Player(0), Player(1)] + + for cards, his, id in info_cases: + player = players[id] + cards = [cards[0], cards[1]] + info_key = game.get_Info_set(cards[id], his) + player.get_info_set(info_key) + for p in players: + print(f"player {p.id} has {len(p.info_map)} info sets.") + for info_key, info in p.info_map.items(): + info_map[info_key] = info + print("infokey:", info_key) + + + # for info_key, info in info_map.items(): + # strat = info.get_strat(1.0) + # avg_strat = info.get_avg_strat() + # print(f"infoset: '{info_key}'") + # print(f" current strategy: [CK/FD: {strat[0]:.3f}, BT/CL: {strat[1]:.3f}]") + # print(f" average strategy: [CK/FD: {avg_strat[0]:.3f}, BT/CL: {avg_strat[1]:.3f}]") + + # print('=='*20) + + # test_seq = ['00', '10', '01', '010', '011', '11'] + + # for seq in test_seq: + # terminal = game.is_terminal(seq) + # if terminal: + # print(f" '{seq}' terminal") + # else: + # current_player = game.get_cur_player(seq) + # print(f" '{seq}' not terminal, cur player: {current_player}") if __name__ == "__main__": test_game() \ No newline at end of file