This commit is contained in:
2025-12-01 17:57:44 +08:00
parent 547118ec6d
commit 8e4be3bda2
3 changed files with 82 additions and 4 deletions

73
src/cfr/calcu.py Normal file
View File

@@ -0,0 +1,73 @@
import random
from src.game.kuhn_poker import KuhnPoker
from src.cfr.info_set import InfoSet
info_map = {}
game = KuhnPoker()
def cfr(cards, his, p0, p1):
if game.is_terminal(his):
return game.get_profit(cards, his, 0)
player = game.get_cur_player(his)
# + player ?
info_key = game.get_Info_set(cards[player], his, player)
if info_key not in info_map:
info_map[info_key] = InfoSet(2) # 0-1
info = info_map[info_key]
wgt = p0 if player == 0 else p1
strat = info.get_strat(wgt)
act_profit = [0.0, 0.0]
info_profit = 0.0 # 期望
for action in [0, 1]:
next_his = his + str(action)
if player == 0:
profit = cfr(cards, next_his, p0*strat[action], p1)
else:
profit = cfr(cards, next_his, p0, p1*strat[action])
act_profit[action] = profit
info_profit += strat[action]*profit
# 更新
for action in [0, 1]:
if player == 0:
regret = act_profit[action] - info_profit
other_r = p1
else:
regret = -(act_profit[action] - info_profit)
other_r = p0
info.regret_sum[action] += other_r * regret
return info_profit
def test():
cards = ['J', 'Q', 'K']
p_sum = 0.0
for i in range(10):
card_r = random.sample(cards, 2)
pf = cfr(card_r, "", 1.0, 1.0)
p_sum += pf
if (i + 1) % 10 == 0:
avg_p = p_sum / (i + 1)
print(f"Range {i+1}/10, Avg : {avg_p:.3f}")
return p_sum / 10
def print_strat():

View File

@@ -21,7 +21,10 @@ class InfoSet:
self.strat[i] /= normal
else:
##
return
prob = 1.0/self.act_cnt
for i in range(self.act_cnt):
self.strat[i] = prob
# return
for i in range(self.act_cnt):
self.strat_sum[i] += wgt * self.strat[i]
@@ -37,8 +40,9 @@ class InfoSet:
avg_strat[i] = self.strat_sum[i] / normal
else:
##
return
prob = 1.0/self.act_cnt
for i in range(self.act_cnt):
avg_strat[i] = prob
# return
return avg_strat

View File

@@ -3,6 +3,7 @@ class KuhnPoker:
self.cards = ['J', 'Q', 'K']
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
# 结构是不是要修改下?
def is_terminal(self, history):
return history in ['00', '10', '010', '011', '11']