cfr:test
This commit is contained in:
16
README.md
16
README.md
@@ -0,0 +1,16 @@
|
|||||||
|
|
||||||
|
start:cfr([JQ],"",1,1)
|
||||||
|
|
||||||
|
# strat_0:J.[0.7,0.3]
|
||||||
|
0:cfr([JQ],"0",0.7,1)
|
||||||
|
# strat_1:Q.0[0.6,0.4]
|
||||||
|
0:cfr([JQ],"00",0.7,0.6) -> terminal,get_uitl
|
||||||
|
# strat_0:J.01[0.8,0.2]
|
||||||
|
1:cfr([JQ],"01",0.7,0.4)
|
||||||
|
0:cfr([JQ],"010",0.56,0.4) -> t
|
||||||
|
1:cfr([JQ],"011",0.14,0.4) -> t
|
||||||
|
|
||||||
|
1:cfr([JQ],"1",0.3,1)
|
||||||
|
# strat_1:Q.1[0.5,0.5]
|
||||||
|
0:cfr([JQ],"10",0.3,0.5) -> t
|
||||||
|
1:cfr([JQ],"11",0.3,0.5) -> t
|
||||||
|
|||||||
130
src/cfr/calcu.py
130
src/cfr/calcu.py
@@ -1,73 +1,143 @@
|
|||||||
import random
|
import random
|
||||||
from src.game.kuhn_poker import KuhnPoker
|
from src.game.kuhn_poker import KuhnPoker
|
||||||
from src.cfr.info_set import InfoSet
|
from src.cfr.info_set import InfoSet
|
||||||
|
from src.game.player import Player
|
||||||
|
|
||||||
|
|
||||||
info_map = {}
|
|
||||||
game = KuhnPoker()
|
game = KuhnPoker()
|
||||||
|
players = [Player(0), Player(1)]
|
||||||
|
|
||||||
def cfr(cards, his, p0, p1):
|
def cfr(cards, his, p0, p1):
|
||||||
|
|
||||||
if game.is_terminal(his):
|
if game.is_terminal(his):
|
||||||
return game.get_profit(cards, his, 0)
|
# return {
|
||||||
|
# 0:game.get_util(cards, his, 0),
|
||||||
|
# 1:game.get_util(cards, his, 1)
|
||||||
|
# }
|
||||||
|
return game.get_util(cards, his, 0)
|
||||||
|
|
||||||
player = game.get_cur_player(his)
|
player = game.get_cur_player(his)
|
||||||
|
if player == -1:
|
||||||
|
print(f"game over, invialid player")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
# + player ?
|
cur_p = players[player]
|
||||||
info_key = game.get_Info_set(cards[player], his, player)
|
cur_p.set_card(cards[player])
|
||||||
|
|
||||||
if info_key not in info_map:
|
info_key = game.get_Info_set(cards[player], his)
|
||||||
info_map[info_key] = InfoSet(2) # 0-1
|
info = cur_p.get_info_set(info_key)
|
||||||
info = info_map[info_key]
|
|
||||||
|
|
||||||
wgt = p0 if player == 0 else p1
|
strat = info.get_strat(p0 if player == 0 else p1)
|
||||||
strat = info.get_strat(wgt)
|
|
||||||
|
|
||||||
|
act_util = [0.0, 0.0]
|
||||||
act_profit = [0.0, 0.0]
|
info_util = 0.0
|
||||||
info_profit = 0.0 # 期望
|
|
||||||
|
|
||||||
for action in [0, 1]:
|
for action in [0, 1]:
|
||||||
next_his = his + str(action)
|
next_his = his + str(action)
|
||||||
|
|
||||||
if player == 0:
|
if player == 0:
|
||||||
profit = cfr(cards, next_his, p0*strat[action], p1)
|
util = cfr(cards, next_his, p0*strat[action], p1)
|
||||||
else:
|
else:
|
||||||
profit = cfr(cards, next_his, p0, p1*strat[action])
|
util = cfr(cards, next_his, p0, p1*strat[action])
|
||||||
|
|
||||||
act_profit[action] = profit
|
act_util[action] = util
|
||||||
info_profit += strat[action]*profit
|
info_util += strat[action]*util
|
||||||
|
|
||||||
# 更新
|
|
||||||
for action in [0, 1]:
|
for action in [0, 1]:
|
||||||
|
p0_util = act_util[action] - info_util
|
||||||
if player == 0:
|
if player == 0:
|
||||||
regret = act_profit[action] - info_profit
|
regret = p0_util
|
||||||
other_r = p1
|
other_r = p1
|
||||||
else:
|
else:
|
||||||
regret = -(act_profit[action] - info_profit)
|
regret = -p0_util
|
||||||
other_r = p0
|
other_r = p0
|
||||||
|
# 对手的reach probe加权累计regret
|
||||||
|
# 假设对手到达这的概率小,那么这个regret就不重要
|
||||||
info.regret_sum[action] += other_r * regret
|
info.regret_sum[action] += other_r * regret
|
||||||
|
|
||||||
return info_profit
|
return info_util
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test(cnt):
|
||||||
|
|
||||||
cards = ['J', 'Q', 'K']
|
cards = ['J', 'Q', 'K']
|
||||||
p_sum = 0.0
|
util_sum = 0.0
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(cnt):
|
||||||
card_r = random.sample(cards, 2)
|
card_r = random.sample(cards, 2)
|
||||||
|
util = cfr(card_r, "", 1.0, 1.0)
|
||||||
|
util_sum += util
|
||||||
|
players[0].util_sum += util
|
||||||
|
players[1].util_sum -= util
|
||||||
|
|
||||||
pf = cfr(card_r, "", 1.0, 1.0)
|
if (i + 1) % cnt == 0:
|
||||||
p_sum += pf
|
avg_util_p0 = players[0].util_sum / (i + 1)
|
||||||
|
avg_util_p1 = players[1].util_sum / (i + 1)
|
||||||
|
|
||||||
|
avg_regret_p0 = players[0].cal_avg_regret(i + 1)
|
||||||
|
avg_regret_p1 = players[1].cal_avg_regret(i + 1)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"cnt: {cnt}\n",
|
||||||
|
f"avg_util_p0: {avg_util_p0}\n",
|
||||||
|
f"avg_util_p1: {avg_util_p1}\n",
|
||||||
|
f"avg_regret_p0: {avg_regret_p0}\n",
|
||||||
|
f"avg_regret_p1: {avg_regret_p1}\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
def all_avg_strat(id):
|
||||||
|
return players[id].get_avg_strat()
|
||||||
|
|
||||||
|
def print_player_strat(id):
|
||||||
|
player = players[id]
|
||||||
|
strategy = player.get_avg_strat()
|
||||||
|
|
||||||
|
print(f"\n{'='*75}")
|
||||||
|
print(f"玩家 {id} 的平均策略 (信息集数: {len(strategy)})")
|
||||||
|
print(f"{'='*75}")
|
||||||
|
print(f"{'info_p':<10} {'card':<10} {'history':<10} {'act0':<20} {'act1':<20}")
|
||||||
|
print(f"{'-'*75}")
|
||||||
|
|
||||||
|
for info_key in sorted(strategy.keys()):
|
||||||
|
prob = strategy[info_key]
|
||||||
|
|
||||||
|
parts = info_key.split('.')
|
||||||
|
card = parts[0]
|
||||||
|
hist = parts[1] if len(parts) > 1 else ''
|
||||||
|
|
||||||
|
act0_name = game.get_act_name(hist, 0)
|
||||||
|
act1_name = game.get_act_name(hist, 1)
|
||||||
|
|
||||||
|
print(f"{info_key:<10} {card:<10} {hist:<10} "
|
||||||
|
f"{act0_name}:{prob[0]:<10.3f} {act1_name}:{prob[1]:<10.3f}")
|
||||||
|
|
||||||
|
print(f"{'='*75}\n")
|
||||||
|
|
||||||
|
def print_all_strategies():
|
||||||
|
print("\n" + "#"*75)
|
||||||
|
print("#" + " "*25 + "所有玩家的平均策略" + " "*24 + "#")
|
||||||
|
print("#"*75)
|
||||||
|
|
||||||
|
for player_id in [0, 1]:
|
||||||
|
print_player_strat(player_id)
|
||||||
|
|
||||||
|
def reset_players():
|
||||||
|
for player in players:
|
||||||
|
player.reset()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
reset_players()
|
||||||
|
test(1000000)
|
||||||
|
print("平均策略")
|
||||||
|
print_player_strat(0)
|
||||||
|
|
||||||
|
print("策略check")
|
||||||
|
strats = all_avg_strat(0)
|
||||||
|
print(f" K_Bet_Probe: {strats.get('K.', [0, 0])[1]:.3f} ")
|
||||||
|
print(f" Q_Check_Probe: {strats.get('Q.', [0, 0])[0]:.3f} ")
|
||||||
|
print(f" J_Check_Probe: {strats.get('J.', [0, 0])[0]:.3f} ")
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
|
||||||
avg_p = p_sum / (i + 1)
|
|
||||||
print(f"Range {i+1}/10, Avg : {avg_p:.3f}")
|
|
||||||
|
|
||||||
return p_sum / 10
|
|
||||||
|
|
||||||
def print_strat():
|
|
||||||
|
|
||||||
|
|
||||||
@@ -2,35 +2,35 @@ class InfoSet:
|
|||||||
def __init__(self, act_cnt=2): # 0-1
|
def __init__(self, act_cnt=2): # 0-1
|
||||||
self.act_cnt = act_cnt
|
self.act_cnt = act_cnt
|
||||||
|
|
||||||
self.regret_sum = [0.0] * act_cnt
|
self.regret_sum = [0.0] * act_cnt # 累积regret,用于更新策略
|
||||||
self.strat = [0.0] * act_cnt
|
self.strat = [0.0] * act_cnt
|
||||||
self.strat_sum = [0.0] * act_cnt
|
self.strat_sum = [0.0] * act_cnt # 累积策略, 用于平均策略
|
||||||
|
|
||||||
def get_strat(self, wgt):
|
# 更新策略 论文Part3(8)
|
||||||
|
def get_strat(self, reach_prob):
|
||||||
|
|
||||||
normal = 0.0
|
normal = 0.0
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
if self.regret_sum[i] > 0:
|
self.strat[i] = max(self.regret_sum[i], 0.0)
|
||||||
self.strat[i] = self.regret_sum[i]
|
|
||||||
else:
|
|
||||||
self.strat[i] = 0.0
|
|
||||||
normal += self.strat[i]
|
normal += self.strat[i]
|
||||||
|
|
||||||
if normal > 0:
|
if normal > 0:
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
self.strat[i] /= normal
|
self.strat[i] /= normal
|
||||||
else:
|
else:
|
||||||
##
|
# 混合策略
|
||||||
prob = 1.0/self.act_cnt
|
prob = 1.0/self.act_cnt
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
self.strat[i] = prob
|
self.strat[i] = prob
|
||||||
# return
|
|
||||||
|
|
||||||
|
# 更新累计策略(拆开)
|
||||||
|
# 加权累计
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
self.strat_sum[i] += wgt * self.strat[i]
|
self.strat_sum[i] += reach_prob * self.strat[i]
|
||||||
|
|
||||||
return self.strat
|
return self.strat
|
||||||
|
|
||||||
|
# 论文Part2(4)
|
||||||
def get_avg_strat(self):
|
def get_avg_strat(self):
|
||||||
avg_strat = [0.0] * self.act_cnt
|
avg_strat = [0.0] * self.act_cnt
|
||||||
normal = sum(self.strat_sum)
|
normal = sum(self.strat_sum)
|
||||||
@@ -39,10 +39,8 @@ class InfoSet:
|
|||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
avg_strat[i] = self.strat_sum[i] / normal
|
avg_strat[i] = self.strat_sum[i] / normal
|
||||||
else:
|
else:
|
||||||
##
|
|
||||||
prob = 1.0/self.act_cnt
|
prob = 1.0/self.act_cnt
|
||||||
for i in range(self.act_cnt):
|
for i in range(self.act_cnt):
|
||||||
avg_strat[i] = prob
|
avg_strat[i] = prob
|
||||||
# return
|
|
||||||
return avg_strat
|
return avg_strat
|
||||||
|
|
||||||
@@ -1,52 +1,49 @@
|
|||||||
class KuhnPoker:
|
class KuhnPoker:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cards = ['J', 'Q', 'K']
|
self.cards = ['J', 'Q', 'K']
|
||||||
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
# self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
|
||||||
|
self.player_cnt = 2
|
||||||
|
|
||||||
# 结构是不是要修改下?
|
|
||||||
def is_terminal(self, history):
|
def is_terminal(self, history):
|
||||||
return history in ['00', '10', '010', '011', '11']
|
return history in ['00', '10', '010', '011', '11']
|
||||||
|
|
||||||
def get_profit(self, cards, history, player):
|
def get_util(self, cards, history, player):
|
||||||
if not self.is_terminal(history):
|
if not self.is_terminal(history):
|
||||||
return 0.0
|
print(f"game not over, invalid util")
|
||||||
|
print("*"*30)
|
||||||
|
|
||||||
card_values = {'J': 0, 'Q': 1, 'K': 2}
|
card_values = {'J': 0, 'Q': 1, 'K': 2}
|
||||||
p0_cardv = card_values[cards[0]]
|
p0_cardv = card_values[cards[0]]
|
||||||
p1_cardv = card_values[cards[1]]
|
p1_cardv = card_values[cards[1]]
|
||||||
|
|
||||||
p0_wins = p0_cardv > p1_cardv
|
p0_wins = p0_cardv > p1_cardv
|
||||||
|
p0_util = 0
|
||||||
if history == '00':
|
if history == '00':
|
||||||
if p0_wins:
|
p0_util = 1 if p0_wins else -1
|
||||||
return 1.0 if player == 0 else -1.0
|
|
||||||
else:
|
|
||||||
return -1.0 if player == 0 else 1.0
|
|
||||||
|
|
||||||
elif history == '10':
|
elif history == '10':
|
||||||
return 1.0 if player == 0 else -1.0
|
p0_util = 1 if player == 0 else -1
|
||||||
|
|
||||||
elif history == '010':
|
|
||||||
return -1.0 if player == 0 else 1.0
|
|
||||||
|
|
||||||
elif history == '011':
|
|
||||||
if p0_wins:
|
|
||||||
return 2.0 if player == 0 else -2.0
|
|
||||||
else:
|
|
||||||
return -2.0 if player == 0 else 2.0
|
|
||||||
|
|
||||||
elif history == '11':
|
elif history == '11':
|
||||||
if p0_wins:
|
p0_util = 2 if p0_wins else -2
|
||||||
return 2.0 if player == 0 else -2.0
|
elif history == '010':
|
||||||
else:
|
p0_util = -1 if p0_wins else 1
|
||||||
return -2.0 if player == 0 else 2.0
|
elif history == '011':
|
||||||
|
p0_util = 2 if p0_wins else -2
|
||||||
|
|
||||||
return 0.0
|
if player == 0:
|
||||||
|
return p0_util
|
||||||
|
elif player == 1:
|
||||||
|
return -p0_util
|
||||||
|
else:
|
||||||
|
print(f"invalid player:{player}")
|
||||||
|
print("*"*30)
|
||||||
|
return 0
|
||||||
|
|
||||||
def get_Info_set(self, card, history, player):
|
def get_Info_set(self, card, history):
|
||||||
return f"{card}.{history}"
|
return f"{card}.{history}"
|
||||||
|
|
||||||
def get_cur_player(self, history):
|
def get_cur_player(self, history):
|
||||||
if self.is_terminal(history):
|
if self.is_terminal(history):
|
||||||
|
print("game over, no player act")
|
||||||
return -1
|
return -1
|
||||||
return len(history) % 2
|
return len(history) % 2
|
||||||
|
|
||||||
|
|||||||
44
src/game/player.py
Normal file
44
src/game/player.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from src.cfr.info_set import InfoSet
|
||||||
|
class Player:
|
||||||
|
def __init__(self, id):
|
||||||
|
self.id = id
|
||||||
|
self.card = None
|
||||||
|
self.info_map = {}
|
||||||
|
self.regret_sum = 0 # 累计regret
|
||||||
|
self.strat_sum = 0 # 累计strat
|
||||||
|
self.util_sum = 0
|
||||||
|
|
||||||
|
def set_card(self, card):
|
||||||
|
self.card = card
|
||||||
|
|
||||||
|
def get_info_set(self, info_key, act_cnt=2):
|
||||||
|
if info_key not in self.info_map:
|
||||||
|
self.info_map[info_key] = InfoSet(act_cnt)
|
||||||
|
return self.info_map[info_key]
|
||||||
|
|
||||||
|
def get_avg_strat(self):
|
||||||
|
return {
|
||||||
|
info_key: info.get_avg_strat()
|
||||||
|
for info_key, info in self.info_map.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def cal_avg_regret(self, cnt):
|
||||||
|
if cnt <= 0:
|
||||||
|
print("invalid test count")
|
||||||
|
total = sum(max(0, max(info.regret_sum)) for info in self.info_map.values())
|
||||||
|
return total / cnt
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.info_map = {}
|
||||||
|
self.card = None
|
||||||
|
self.regret_sum = 0
|
||||||
|
self.strat_sum = 0
|
||||||
|
self.util_sum = 0
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"player:{self.id} , card:{self.card}, info_sets:{len(self.info_map)}"
|
||||||
|
|
||||||
|
def print_info(self):
|
||||||
|
for info_key, info in self.info_map.items():
|
||||||
|
print(f"infoset: '{info_key}', info: {info}")
|
||||||
@@ -3,6 +3,8 @@ import os
|
|||||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
from src.game.kuhn_poker import KuhnPoker
|
from src.game.kuhn_poker import KuhnPoker
|
||||||
|
from src.cfr.info_set import InfoSet
|
||||||
|
from src.game.player import Player
|
||||||
|
|
||||||
def test_game():
|
def test_game():
|
||||||
game = KuhnPoker()
|
game = KuhnPoker()
|
||||||
@@ -22,30 +24,53 @@ def test_game():
|
|||||||
(['Q', 'K'], '011', 1),
|
(['Q', 'K'], '011', 1),
|
||||||
]
|
]
|
||||||
for cards, his, player in test_cases:
|
for cards, his, player in test_cases:
|
||||||
profit = game.get_profit(cards, his, player)
|
profit = game.get_util(cards, his, player)
|
||||||
print(f"player{player}profit: {profit}")
|
print(f"player{player}profit: {profit}")
|
||||||
|
|
||||||
info_cases = [
|
info_cases = [
|
||||||
('K', '0', 0),
|
('JQ', '' , 0), ('JK', '' , 0),
|
||||||
('J', '01', 1),
|
('JQ', '0', 0), ('JK', '0', 0),
|
||||||
('Q', '', 0),
|
('JQ', '1', 1), ('JK', '1', 1),
|
||||||
|
('JQ', '00', 0), ('JK', '00', 0),
|
||||||
|
('JQ', '10', 0), ('KQ', '10', 0),
|
||||||
|
('QK', '01', 1), ('JK', '01', 1),
|
||||||
|
('JQ', '010',0), ('JK', '010',0),
|
||||||
|
('JQ', '011',1), ('JK', '011',1),
|
||||||
]
|
]
|
||||||
|
|
||||||
for card, his, player in info_cases:
|
info_map = {}
|
||||||
info_set = game.get_Info_set(card, his, player)
|
players = [Player(0), Player(1)]
|
||||||
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
|
|
||||||
|
|
||||||
print(f"valid act: {game.get_valid_act('00')}")
|
for cards, his, id in info_cases:
|
||||||
|
player = players[id]
|
||||||
|
cards = [cards[0], cards[1]]
|
||||||
|
info_key = game.get_Info_set(cards[id], his)
|
||||||
|
player.get_info_set(info_key)
|
||||||
|
for p in players:
|
||||||
|
print(f"player {p.id} has {len(p.info_map)} info sets.")
|
||||||
|
for info_key, info in p.info_map.items():
|
||||||
|
info_map[info_key] = info
|
||||||
|
print("infokey:", info_key)
|
||||||
|
|
||||||
test_seq = ['00', '10', '01', '010', '011', '11']
|
|
||||||
|
|
||||||
for seq in test_seq:
|
# for info_key, info in info_map.items():
|
||||||
terminal = game.is_terminal(seq)
|
# strat = info.get_strat(1.0)
|
||||||
if terminal:
|
# avg_strat = info.get_avg_strat()
|
||||||
print(f" '{seq}' terminal")
|
# print(f"infoset: '{info_key}'")
|
||||||
else:
|
# print(f" current strategy: [CK/FD: {strat[0]:.3f}, BT/CL: {strat[1]:.3f}]")
|
||||||
current_player = game.get_cur_player(seq)
|
# print(f" average strategy: [CK/FD: {avg_strat[0]:.3f}, BT/CL: {avg_strat[1]:.3f}]")
|
||||||
print(f" '{seq}' not terminal, cur player: {current_player}")
|
|
||||||
|
# print('=='*20)
|
||||||
|
|
||||||
|
# test_seq = ['00', '10', '01', '010', '011', '11']
|
||||||
|
|
||||||
|
# for seq in test_seq:
|
||||||
|
# terminal = game.is_terminal(seq)
|
||||||
|
# if terminal:
|
||||||
|
# print(f" '{seq}' terminal")
|
||||||
|
# else:
|
||||||
|
# current_player = game.get_cur_player(seq)
|
||||||
|
# print(f" '{seq}' not terminal, cur player: {current_player}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_game()
|
test_game()
|
||||||
Reference in New Issue
Block a user