diff --git a/src/cfr/calcu.py b/src/cfr/calcu.py
new file mode 100644
index 0000000..4b2730d
--- /dev/null
+++ b/src/cfr/calcu.py
@@ -0,0 +1,72 @@
+import random
+from src.game.kuhn_poker import KuhnPoker
+from src.cfr.info_set import InfoSet
+
+
+info_map = {}
+game = KuhnPoker()
+
+def cfr(cards, his, p0, p1):
+    """Recurse over history `his` with reach probs p0/p1; return expected profit."""
+    if game.is_terminal(his):
+        return game.get_profit(cards, his, 0)
+
+    player = game.get_cur_player(his)
+
+    # + player ?
+    info_key = game.get_Info_set(cards[player], his, player)
+
+    if info_key not in info_map:
+        info_map[info_key] = InfoSet(2)  # 0-1
+    info = info_map[info_key]
+
+    wgt = p0 if player == 0 else p1
+    strat = info.get_strat(wgt)
+
+
+    act_profit = [0.0, 0.0]
+    info_profit = 0.0  # expected value
+
+    for action in [0, 1]:
+        next_his = his + str(action)
+
+        if player == 0:
+            profit = cfr(cards, next_his, p0*strat[action], p1)
+        else:
+            profit = cfr(cards, next_his, p0, p1*strat[action])
+
+        act_profit[action] = profit
+        info_profit += strat[action]*profit
+
+    # update regrets; profits come from get_profit(cards, his, 0), so flip the sign for player 1
+    for action in [0, 1]:
+        if player == 0:
+            regret = act_profit[action] - info_profit
+            other_r = p1
+        else:
+            regret = -(act_profit[action] - info_profit)
+            other_r = p0
+        info.regret_sum[action] += other_r * regret
+
+    return info_profit
+
+
+def test():
+    """Run 10 CFR traversals on random two-card deals and return the mean profit."""
+    cards = ['J', 'Q', 'K']
+    p_sum = 0.0
+
+    for i in range(10):
+        card_r = random.sample(cards, 2)
+
+        pf = cfr(card_r, "", 1.0, 1.0)
+        p_sum += pf
+
+        if (i + 1) % 10 == 0:
+            avg_p = p_sum / (i + 1)
+            print(f"Range {i+1}/10, Avg : {avg_p:.3f}")
+
+    return p_sum / 10
+
+def print_strat():
+    """TODO: print the strategies stored in info_map (not implemented yet)."""
diff --git a/src/cfr/info_set.py b/src/cfr/info_set.py
index 4bb8548..b28f2ff 100644
--- a/src/cfr/info_set.py
+++ b/src/cfr/info_set.py
@@ -21,7 +21,10 @@ class InfoSet:
                 self.strat[i] /= normal
         else:
             ##
-            return
+            prob = 1.0/self.act_cnt
+            for i in range(self.act_cnt):
+                self.strat[i] = prob
+            # return
 
         for i in range(self.act_cnt):
             self.strat_sum[i] += wgt * self.strat[i]
@@ -37,8 +40,9 @@ class InfoSet:
                 avg_strat[i] = self.strat_sum[i] / normal
         else:
             ##
-            return
-
-
+            prob = 1.0/self.act_cnt
+            for i in range(self.act_cnt):
+                avg_strat[i] = prob
+            # return
 
         return avg_strat
\ No newline at end of file
diff --git a/src/game/kuhn_poker.py b/src/game/kuhn_poker.py
index f6c2074..51356a8 100644
--- a/src/game/kuhn_poker.py
+++ b/src/game/kuhn_poker.py
@@ -3,6 +3,7 @@ class KuhnPoker:
         self.cards = ['J', 'Q', 'K']
         self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
 
+    # Should this structure be changed?
     def is_terminal(self, history):
         return history in ['00', '10', '010', '011', '11']
 