cfr fix
This commit is contained in:
73
src/cfr/calcu.py
Normal file
73
src/cfr/calcu.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import random
|
||||
from src.game.kuhn_poker import KuhnPoker
|
||||
from src.cfr.info_set import InfoSet
|
||||
|
||||
|
||||
# Module-level table of information sets. cfr() looks entries up by the key
# returned from game.get_Info_set(...) and creates an InfoSet on first visit.
info_map = {}
|
||||
# Shared rules object: cfr() queries it for terminal checks, payoffs, the
# player to act, and information-set keys.
game = KuhnPoker()
|
||||
|
||||
def cfr(cards, his, p0, p1):
    """Traverse the game tree below `his` once, updating regrets on the way.

    Args:
        cards: Dealt cards, indexed by player number.
        his: Action history string; each action appends its digit ('0'/'1').
        p0: Probability contribution of player 0's own choices along this path.
        p1: Probability contribution of player 1's own choices along this path.

    Returns:
        The expected profit of this node under the current strategies,
        always expressed from player 0's point of view.
    """
    # Leaf node: the payoff is requested for player 0 regardless of who acted.
    if game.is_terminal(his):
        return game.get_profit(cards, his, 0)

    actor = game.get_cur_player(his)
    first_to_act = actor == 0

    # NOTE(review): the original author left an open question here about
    # whether the key should additionally encode the player.
    key = game.get_Info_set(cards[actor], his, actor)
    try:
        node = info_map[key]
    except KeyError:
        node = info_map[key] = InfoSet(2)  # two actions: 0 and 1

    # The acting player's own reach weights the strategy lookup; the
    # opponent's reach weights the regret update further down.
    own_reach, opp_reach = (p0, p1) if first_to_act else (p1, p0)
    strat = node.get_strat(own_reach)

    action_values = []
    node_value = 0.0  # expected value under the current strategy
    for action in (0, 1):
        child = his + str(action)
        if first_to_act:
            value = cfr(cards, child, p0 * strat[action], p1)
        else:
            value = cfr(cards, child, p0, p1 * strat[action])
        action_values.append(value)
        node_value += strat[action] * value

    # Regret update.
    for action, value in enumerate(action_values):
        advantage = value - node_value
        # Values are from player 0's view, so player 1's regret is negated.
        regret = advantage if first_to_act else -advantage
        node.regret_sum[action] += opp_reach * regret

    return node_value
|
||||
|
||||
|
||||
def test(iterations=10, report_every=10):
    """Run repeated CFR traversals on random deals and return the mean value.

    Args:
        iterations: Number of deals to traverse. Defaults to 10, the count
            that was previously hard-coded.
        report_every: Print the running average each time this many deals
            have completed. Defaults to 10; values below 1 disable the
            progress output.

    Returns:
        The mean of the values returned by cfr() over all deals, or 0.0
        when `iterations` is not positive.
    """
    if iterations <= 0:
        # Nothing to average; also avoids dividing by zero below.
        return 0.0

    cards = ['J', 'Q', 'K']
    p_sum = 0.0

    for done in range(1, iterations + 1):
        # Two distinct cards: cfr() reads them as cards[player].
        card_r = random.sample(cards, 2)
        p_sum += cfr(card_r, "", 1.0, 1.0)

        if report_every > 0 and done % report_every == 0:
            print(f"Range {done}/{iterations}, Avg : {p_sum / done:.3f}")

    return p_sum / iterations
|
||||
|
||||
def print_strat():
|
||||
|
||||
|
||||
@@ -21,7 +21,10 @@ class InfoSet:
|
||||
self.strat[i] /= normal
|
||||
else:
|
||||
##
|
||||
return
|
||||
prob = 1.0/self.act_cnt
|
||||
for i in range(self.act_cnt):
|
||||
self.strat[i] = prob
|
||||
# return
|
||||
|
||||
for i in range(self.act_cnt):
|
||||
self.strat_sum[i] += wgt * self.strat[i]
|
||||
@@ -37,8 +40,9 @@ class InfoSet:
|
||||
avg_strat[i] = self.strat_sum[i] / normal
|
||||
else:
|
||||
##
|
||||
return
|
||||
|
||||
|
||||
prob = 1.0/self.act_cnt
|
||||
for i in range(self.act_cnt):
|
||||
avg_strat[i] = prob
|
||||
# return
|
||||
return avg_strat
|
||||
|
||||
@@ -3,6 +3,7 @@ class KuhnPoker:
|
||||
self.cards = ['J', 'Q', 'K']
|
||||
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
||||
|
||||
# TODO: should this structure be revised?
|
||||
def is_terminal(self, history):
    """Return True when `history` is one of the five finished action sequences."""
    finished = ('00', '10', '010', '011', '11')
    return history in finished
|
||||
|
||||
|
||||
Reference in New Issue
Block a user