diff --git a/README.md b/README.md index e69de29..24b160a 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,16 @@ + +start:cfr([JQ],"",1,1) + ++ strat_0:J.[0.7,0.3] +0:cfr([JQ],"0",0.7,1) ++ strat_1:Q.0[0.6,0.4] + 0:cfr([JQ],"00",0.7,0.6) -> terminal,get_uitl ++ strat_0:J.01[0.8,0.2] + 1:cfr([JQ],"01",0.7,0.4) + 0:cfr([JQ],"010",0.56,0.4) -> t + 1:cfr([JQ],"011",0.14,0.4) -> t + +1:cfr([JQ],"1",0.3,1) ++ strat_1:Q.1[0.5,0.5] + 0:cfr([JQ],"10",0.3,0.5) -> t + 1:cfr([JQ],"11",0.3,0.5) -> t diff --git a/src/cfr/calcu.py b/src/cfr/calcu.py index 4b2730d..f2cc1ae 100644 --- a/src/cfr/calcu.py +++ b/src/cfr/calcu.py @@ -9,65 +9,116 @@ game = KuhnPoker() def cfr(cards, his, p0, p1): if game.is_terminal(his): - return game.get_profit(cards, his, 0) + return game.get_util(cards, his, 0) player = game.get_cur_player(his) - # + player ? - info_key = game.get_Info_set(cards[player], his, player) + info_key = game.get_Info_set(cards[player], his) if info_key not in info_map: info_map[info_key] = InfoSet(2) # 0-1 info = info_map[info_key] - wgt = p0 if player == 0 else p1 - strat = info.get_strat(wgt) + reach_prob = p0 if player == 0 else p1 + strat = info.get_strat(reach_prob) - act_profit = [0.0, 0.0] - info_profit = 0.0 # 期望 + act_util = [0.0, 0.0] + info_util = 0.0 for action in [0, 1]: next_his = his + str(action) if player == 0: - profit = cfr(cards, next_his, p0*strat[action], p1) + util = cfr(cards, next_his, p0*strat[action], p1) else: - profit = cfr(cards, next_his, p0, p1*strat[action]) + util = cfr(cards, next_his, p0, p1*strat[action]) - act_profit[action] = profit - info_profit += strat[action]*profit + act_util[action] = util + info_util += strat[action]*util - # 更新 for action in [0, 1]: + p0_util = act_util[action] - info_util if player == 0: - regret = act_profit[action] - info_profit + regret = p0_util other_r = p1 else: - regret = -(act_profit[action] - info_profit) + regret = -p0_util other_r = p0 + # 对手的reach probe加权累计regret + # 假设对手到达这的概率小,那么这个regret就不重要 info.regret_sum[action] += other_r * regret - return info_profit + return info_util -def test(): +def test(cnt): cards = ['J', 'Q', 'K'] - p_sum = 0.0 + util_sum = 0.0 - for i in range(10): + for i in range(cnt): card_r = random.sample(cards, 2) + util = cfr(card_r, "", 1.0, 1.0) + util_sum += util - pf = cfr(card_r, "", 1.0, 1.0) - p_sum += pf + if (i + 1) % cnt == 0: + util_avg = util_sum / (i + 1) + sum_regret = 0 + for info in info_map.values(): + max_r = max(info.regret_sum) if info.regret_sum else 0 + sum_regret += max_r + avg_regret = sum_regret / (i + 1) + print(f"range {i+1}/{cnt}, rvg_regret: {avg_regret:.3f}, avg_util: {util_avg:.3f}") + final_util = util_sum / cnt + final_regret = sum(max(0, max(info.regret_sum)) for info in info_map.values()) / cnt + print(f"final_avg_regret: {final_regret:.3f}, final_avg_util: {final_util:.3f}") + return util_sum / cnt + +def all_avg_strat(): + avg_strat = {} + for info_key, info in info_map.items(): + avg_strat[info_key] = info.get_avg_strat() + return avg_strat + +def print_strat(strat): + if strat is None: + strat = all_avg_strat() + + print("="*60) + print(f"{'Info Set':<15} {'Check/Fold':<15} {'Bet/Call':<15}") + print("-"*60) + for info_key in sorted(strat.keys()): + probe = strat[info_key] + parts = info_key.split('.') + card = parts[0] + his = parts[1] if len(parts) > 1 else '' + + act0 = game.get_act_name(his, 0) + act1 = game.get_act_name(his, 1) + + print(f"{info_key:<15} {act0}:{probe[0]:<10.3f} {act1}:{probe[1]:<10.3f}") + print("="*60) + +def reset(): + global info_map + info_map = {} + + + +if __name__ == "__main__": + reset() + test(100000) + print("平均策略") + print_strat(None) + + print("策略check") + strats = all_avg_strat() + print(f" K_Bet_Probe: {strats.get('K.', [0, 0])[1]:.3f} ") + print(f" Q_Check_Probe: {strats.get('Q.', [0, 0])[0]:.3f} ") + print(f" J_Check_Probe: {strats.get('J.', [0, 0])[0]:.3f} ") + - if (i + 1) % 10 == 0: - avg_p = p_sum / (i + 1) - print(f"Range {i+1}/10, Avg : {avg_p:.3f}") - - return p_sum / 10 -def print_strat(): - + \ No newline at end of file diff --git a/src/cfr/info_set.py b/src/cfr/info_set.py index b28f2ff..20dd81e 100644 --- a/src/cfr/info_set.py +++ b/src/cfr/info_set.py @@ -2,35 +2,35 @@ class InfoSet: def __init__(self, act_cnt=2): # 0-1 self.act_cnt = act_cnt - self.regret_sum = [0.0] * act_cnt - self.strat = [0.0] * act_cnt - self.strat_sum = [0.0] * act_cnt + self.regret_sum = [0.0] * act_cnt # 累积regret,用于更新策略 + self.strat = [0.0] * act_cnt + self.strat_sum = [0.0] * act_cnt # 累积策略, 用于平均策略 - def get_strat(self, wgt): + # 更新策略 论文Part3(8) + def get_strat(self, reach_prob): normal = 0.0 for i in range(self.act_cnt): - if self.regret_sum[i] > 0: - self.strat[i] = self.regret_sum[i] - else: - self.strat[i] = 0.0 + self.strat[i] = max(self.regret_sum[i], 0.0) normal += self.strat[i] if normal > 0: for i in range(self.act_cnt): self.strat[i] /= normal else: - ## + # 混合策略 prob = 1.0/self.act_cnt for i in range(self.act_cnt): self.strat[i] = prob - # return - + + # 更新累计策略(拆开) + # 加权累计 for i in range(self.act_cnt): - self.strat_sum[i] += wgt * self.strat[i] + self.strat_sum[i] += reach_prob * self.strat[i] return self.strat + # 论文Part2(4) def get_avg_strat(self): avg_strat = [0.0] * self.act_cnt normal = sum(self.strat_sum) @@ -39,10 +39,8 @@ class InfoSet: for i in range(self.act_cnt): avg_strat[i] = self.strat_sum[i] / normal else: - ## prob = 1.0/self.act_cnt for i in range(self.act_cnt): - avg_strat[i] = prob - # return + avg_strat[i] = prob return avg_strat \ No newline at end of file diff --git a/src/game/kuhn_poker.py b/src/game/kuhn_poker.py index 51356a8..a5fed27 100644 --- a/src/game/kuhn_poker.py +++ b/src/game/kuhn_poker.py @@ -3,11 +3,10 @@ class KuhnPoker: self.cards = ['J', 'Q', 'K'] self.actions = [0, 1] # Check/Fold=0, Bet/Call=1 - # 结构是不是要修改下? def is_terminal(self, history): return history in ['00', '10', '010', '011', '11'] - def get_profit(self, cards, history, player): + def get_util(self, cards, history, player): if not self.is_terminal(history): return 0.0 @@ -42,7 +41,7 @@ class KuhnPoker: return 0.0 - def get_Info_set(self, card, history, player): + def get_Info_set(self, card, history): return f"{card}.{history}" def get_cur_player(self, history): diff --git a/tests/test_game.py b/tests/test_game.py index 73d56e2..a9c48ef 100644 --- a/tests/test_game.py +++ b/tests/test_game.py @@ -22,7 +22,7 @@ def test_game(): (['Q', 'K'], '011', 1), ] for cards, his, player in test_cases: - profit = game.get_profit(cards, his, player) + profit = game.get_util(cards, his, player) print(f"player{player}profit: {profit}") info_cases = [