cfr fix
This commit is contained in:
73
src/cfr/calcu.py
Normal file
73
src/cfr/calcu.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import random
|
||||||
|
from src.game.kuhn_poker import KuhnPoker
|
||||||
|
from src.cfr.info_set import InfoSet
|
||||||
|
|
||||||
|
|
||||||
|
info_map = {}
|
||||||
|
game = KuhnPoker()
|
||||||
|
|
||||||
|
def cfr(cards, his, p0, p1):
|
||||||
|
|
||||||
|
if game.is_terminal(his):
|
||||||
|
return game.get_profit(cards, his, 0)
|
||||||
|
|
||||||
|
player = game.get_cur_player(his)
|
||||||
|
|
||||||
|
# + player ?
|
||||||
|
info_key = game.get_Info_set(cards[player], his, player)
|
||||||
|
|
||||||
|
if info_key not in info_map:
|
||||||
|
info_map[info_key] = InfoSet(2) # 0-1
|
||||||
|
info = info_map[info_key]
|
||||||
|
|
||||||
|
wgt = p0 if player == 0 else p1
|
||||||
|
strat = info.get_strat(wgt)
|
||||||
|
|
||||||
|
|
||||||
|
act_profit = [0.0, 0.0]
|
||||||
|
info_profit = 0.0 # 期望
|
||||||
|
|
||||||
|
for action in [0, 1]:
|
||||||
|
next_his = his + str(action)
|
||||||
|
|
||||||
|
if player == 0:
|
||||||
|
profit = cfr(cards, next_his, p0*strat[action], p1)
|
||||||
|
else:
|
||||||
|
profit = cfr(cards, next_his, p0, p1*strat[action])
|
||||||
|
|
||||||
|
act_profit[action] = profit
|
||||||
|
info_profit += strat[action]*profit
|
||||||
|
|
||||||
|
# 更新
|
||||||
|
for action in [0, 1]:
|
||||||
|
if player == 0:
|
||||||
|
regret = act_profit[action] - info_profit
|
||||||
|
other_r = p1
|
||||||
|
else:
|
||||||
|
regret = -(act_profit[action] - info_profit)
|
||||||
|
other_r = p0
|
||||||
|
info.regret_sum[action] += other_r * regret
|
||||||
|
|
||||||
|
return info_profit
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
|
||||||
|
cards = ['J', 'Q', 'K']
|
||||||
|
p_sum = 0.0
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
card_r = random.sample(cards, 2)
|
||||||
|
|
||||||
|
pf = cfr(card_r, "", 1.0, 1.0)
|
||||||
|
p_sum += pf
|
||||||
|
|
||||||
|
if (i + 1) % 10 == 0:
|
||||||
|
avg_p = p_sum / (i + 1)
|
||||||
|
print(f"Range {i+1}/10, Avg : {avg_p:.3f}")
|
||||||
|
|
||||||
|
return p_sum / 10
|
||||||
|
|
||||||
|
def print_strat():
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +21,10 @@ class InfoSet:
|
|||||||
self.strat[i] /= normal
|
self.strat[i] /= normal
|
||||||
else:
|
else:
|
||||||
##
|
##
|
||||||
return
|
prob = 1.0/self.act_cnt
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
self.strat[i] = prob
|
||||||
|
# return
|
||||||
|
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
self.strat_sum[i] += wgt * self.strat[i]
|
self.strat_sum[i] += wgt * self.strat[i]
|
||||||
@@ -37,8 +40,9 @@ class InfoSet:
|
|||||||
avg_strat[i] = self.strat_sum[i] / normal
|
avg_strat[i] = self.strat_sum[i] / normal
|
||||||
else:
|
else:
|
||||||
##
|
##
|
||||||
return
|
prob = 1.0/self.act_cnt
|
||||||
|
for i in range(self.act_cnt):
|
||||||
|
avg_strat[i] = prob
|
||||||
|
# return
|
||||||
return avg_strat
|
return avg_strat
|
||||||
|
|
||||||
@@ -3,6 +3,7 @@ class KuhnPoker:
|
|||||||
self.cards = ['J', 'Q', 'K']
|
self.cards = ['J', 'Q', 'K']
|
||||||
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
||||||
|
|
||||||
|
# 结构是不是要修改下?
|
||||||
def is_terminal(self, history):
|
def is_terminal(self, history):
|
||||||
return history in ['00', '10', '010', '011', '11']
|
return history in ['00', '10', '010', '011', '11']
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user