This commit is contained in:
2025-12-03 17:49:29 +08:00
parent 8e4be3bda2
commit 33ca6d59b0
6 changed files with 241 additions and 91 deletions

View File

@@ -0,0 +1,16 @@
start:cfr([JQ],"",1,1)
# strat_0:J.[0.7,0.3]
0:cfr([JQ],"0",0.7,1)
# strat_1:Q.0[0.6,0.4]
0:cfr([JQ],"00",0.7,0.6) -> terminal,get_uitl
# strat_0:J.01[0.8,0.2]
1:cfr([JQ],"01",0.7,0.4)
0:cfr([JQ],"010",0.56,0.4) -> t
1:cfr([JQ],"011",0.14,0.4) -> t
1:cfr([JQ],"1",0.3,1)
# strat_1:Q.1[0.5,0.5]
0:cfr([JQ],"10",0.3,0.5) -> t
1:cfr([JQ],"11",0.3,0.5) -> t

View File

@@ -1,73 +1,143 @@
import random
from src.game.kuhn_poker import KuhnPoker
from src.cfr.info_set import InfoSet
from src.game.player import Player
info_map = {}
game = KuhnPoker()
players = [Player(0), Player(1)]
def cfr(cards, his, p0, p1):
if game.is_terminal(his):
return game.get_profit(cards, his, 0)
# return {
# 0:game.get_util(cards, his, 0),
# 1:game.get_util(cards, his, 1)
# }
return game.get_util(cards, his, 0)
player = game.get_cur_player(his)
if player == -1:
print(f"game over, invialid player")
return 0.0
# + player ?
info_key = game.get_Info_set(cards[player], his, player)
cur_p = players[player]
cur_p.set_card(cards[player])
if info_key not in info_map:
info_map[info_key] = InfoSet(2) # 0-1
info = info_map[info_key]
info_key = game.get_Info_set(cards[player], his)
info = cur_p.get_info_set(info_key)
wgt = p0 if player == 0 else p1
strat = info.get_strat(wgt)
strat = info.get_strat(p0 if player == 0 else p1)
act_profit = [0.0, 0.0]
info_profit = 0.0 # 期望
act_util = [0.0, 0.0]
info_util = 0.0
for action in [0, 1]:
next_his = his + str(action)
if player == 0:
profit = cfr(cards, next_his, p0*strat[action], p1)
util = cfr(cards, next_his, p0*strat[action], p1)
else:
profit = cfr(cards, next_his, p0, p1*strat[action])
util = cfr(cards, next_his, p0, p1*strat[action])
act_profit[action] = profit
info_profit += strat[action]*profit
act_util[action] = util
info_util += strat[action]*util
# 更新
for action in [0, 1]:
p0_util = act_util[action] - info_util
if player == 0:
regret = act_profit[action] - info_profit
regret = p0_util
other_r = p1
else:
regret = -(act_profit[action] - info_profit)
regret = -p0_util
other_r = p0
# 对手的reach probe加权累计regret
# 假设对手到达这的概率小那么这个regret就不重要
info.regret_sum[action] += other_r * regret
return info_profit
return info_util
def test():
def test(cnt):
cards = ['J', 'Q', 'K']
p_sum = 0.0
util_sum = 0.0
for i in range(10):
for i in range(cnt):
card_r = random.sample(cards, 2)
util = cfr(card_r, "", 1.0, 1.0)
util_sum += util
players[0].util_sum += util
players[1].util_sum -= util
pf = cfr(card_r, "", 1.0, 1.0)
p_sum += pf
if (i + 1) % cnt == 0:
avg_util_p0 = players[0].util_sum / (i + 1)
avg_util_p1 = players[1].util_sum / (i + 1)
avg_regret_p0 = players[0].cal_avg_regret(i + 1)
avg_regret_p1 = players[1].cal_avg_regret(i + 1)
print(
f"cnt: {cnt}\n",
f"avg_util_p0: {avg_util_p0}\n",
f"avg_util_p1: {avg_util_p1}\n",
f"avg_regret_p0: {avg_regret_p0}\n",
f"avg_regret_p1: {avg_regret_p1}\n",
)
def all_avg_strat(id):
return players[id].get_avg_strat()
def print_player_strat(id):
player = players[id]
strategy = player.get_avg_strat()
print(f"\n{'='*75}")
print(f"玩家 {id} 的平均策略 (信息集数: {len(strategy)})")
print(f"{'='*75}")
print(f"{'info_p':<10} {'card':<10} {'history':<10} {'act0':<20} {'act1':<20}")
print(f"{'-'*75}")
for info_key in sorted(strategy.keys()):
prob = strategy[info_key]
parts = info_key.split('.')
card = parts[0]
hist = parts[1] if len(parts) > 1 else ''
act0_name = game.get_act_name(hist, 0)
act1_name = game.get_act_name(hist, 1)
print(f"{info_key:<10} {card:<10} {hist:<10} "
f"{act0_name}:{prob[0]:<10.3f} {act1_name}:{prob[1]:<10.3f}")
print(f"{'='*75}\n")
def print_all_strategies():
print("\n" + "#"*75)
print("#" + " "*25 + "所有玩家的平均策略" + " "*24 + "#")
print("#"*75)
for player_id in [0, 1]:
print_player_strat(player_id)
def reset_players():
for player in players:
player.reset()
if __name__ == "__main__":
reset_players()
test(1000000)
print("平均策略")
print_player_strat(0)
print("策略check")
strats = all_avg_strat(0)
print(f" K_Bet_Probe: {strats.get('K.', [0, 0])[1]:.3f} ")
print(f" Q_Check_Probe: {strats.get('Q.', [0, 0])[0]:.3f} ")
print(f" J_Check_Probe: {strats.get('J.', [0, 0])[0]:.3f} ")
if (i + 1) % 10 == 0:
avg_p = p_sum / (i + 1)
print(f"Range {i+1}/10, Avg : {avg_p:.3f}")
return p_sum / 10
def print_strat():

View File

@@ -2,35 +2,35 @@ class InfoSet:
def __init__(self, act_cnt=2): # 0-1
self.act_cnt = act_cnt
self.regret_sum = [0.0] * act_cnt
self.regret_sum = [0.0] * act_cnt # 累积regret用于更新策略
self.strat = [0.0] * act_cnt
self.strat_sum = [0.0] * act_cnt
self.strat_sum = [0.0] * act_cnt # 累积策略, 用于平均策略
def get_strat(self, wgt):
# 更新策略 论文Part3(8)
def get_strat(self, reach_prob):
normal = 0.0
for i in range(self.act_cnt):
if self.regret_sum[i] > 0:
self.strat[i] = self.regret_sum[i]
else:
self.strat[i] = 0.0
self.strat[i] = max(self.regret_sum[i], 0.0)
normal += self.strat[i]
if normal > 0:
for i in range(self.act_cnt):
self.strat[i] /= normal
else:
##
# 混合策略
prob = 1.0/self.act_cnt
for i in range(self.act_cnt):
self.strat[i] = prob
# return
# 更新累计策略(拆开)
# 加权累计
for i in range(self.act_cnt):
self.strat_sum[i] += wgt * self.strat[i]
self.strat_sum[i] += reach_prob * self.strat[i]
return self.strat
# 论文Part2(4)
def get_avg_strat(self):
avg_strat = [0.0] * self.act_cnt
normal = sum(self.strat_sum)
@@ -39,10 +39,8 @@ class InfoSet:
for i in range(self.act_cnt):
avg_strat[i] = self.strat_sum[i] / normal
else:
##
prob = 1.0/self.act_cnt
for i in range(self.act_cnt):
avg_strat[i] = prob
# return
return avg_strat

View File

@@ -1,52 +1,49 @@
class KuhnPoker:
def __init__(self):
self.cards = ['J', 'Q', 'K']
self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
# self.actions = [0, 1] # Check/Fold=0, Bet/Call=1
self.player_cnt = 2
# 结构是不是要修改下?
def is_terminal(self, history):
return history in ['00', '10', '010', '011', '11']
def get_profit(self, cards, history, player):
def get_util(self, cards, history, player):
if not self.is_terminal(history):
return 0.0
print(f"game not over, invalid util")
print("*"*30)
card_values = {'J': 0, 'Q': 1, 'K': 2}
p0_cardv = card_values[cards[0]]
p1_cardv = card_values[cards[1]]
p0_wins = p0_cardv > p1_cardv
p0_util = 0
if history == '00':
if p0_wins:
return 1.0 if player == 0 else -1.0
else:
return -1.0 if player == 0 else 1.0
p0_util = 1 if p0_wins else -1
elif history == '10':
return 1.0 if player == 0 else -1.0
elif history == '010':
return -1.0 if player == 0 else 1.0
elif history == '011':
if p0_wins:
return 2.0 if player == 0 else -2.0
else:
return -2.0 if player == 0 else 2.0
p0_util = 1 if player == 0 else -1
elif history == '11':
if p0_wins:
return 2.0 if player == 0 else -2.0
p0_util = 2 if p0_wins else -2
elif history == '010':
p0_util = -1 if p0_wins else 1
elif history == '011':
p0_util = 2 if p0_wins else -2
if player == 0:
return p0_util
elif player == 1:
return -p0_util
else:
return -2.0 if player == 0 else 2.0
print(f"invalid player:{player}")
print("*"*30)
return 0
return 0.0
def get_Info_set(self, card, history, player):
def get_Info_set(self, card, history):
return f"{card}.{history}"
def get_cur_player(self, history):
if self.is_terminal(history):
print("game over, no player act")
return -1
return len(history) % 2

44
src/game/player.py Normal file
View File

@@ -0,0 +1,44 @@
from src.cfr.info_set import InfoSet
class Player:
def __init__(self, id):
self.id = id
self.card = None
self.info_map = {}
self.regret_sum = 0 # 累计regret
self.strat_sum = 0 # 累计strat
self.util_sum = 0
def set_card(self, card):
self.card = card
def get_info_set(self, info_key, act_cnt=2):
if info_key not in self.info_map:
self.info_map[info_key] = InfoSet(act_cnt)
return self.info_map[info_key]
def get_avg_strat(self):
return {
info_key: info.get_avg_strat()
for info_key, info in self.info_map.items()
}
def cal_avg_regret(self, cnt):
if cnt <= 0:
print("invalid test count")
total = sum(max(0, max(info.regret_sum)) for info in self.info_map.values())
return total / cnt
def reset(self):
self.info_map = {}
self.card = None
self.regret_sum = 0
self.strat_sum = 0
self.util_sum = 0
def __str__(self):
return f"player:{self.id} , card:{self.card}, info_sets:{len(self.info_map)}"
def print_info(self):
for info_key, info in self.info_map.items():
print(f"infoset: '{info_key}', info: {info}")

View File

@@ -3,6 +3,8 @@ import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.game.kuhn_poker import KuhnPoker
from src.cfr.info_set import InfoSet
from src.game.player import Player
def test_game():
game = KuhnPoker()
@@ -22,30 +24,53 @@ def test_game():
(['Q', 'K'], '011', 1),
]
for cards, his, player in test_cases:
profit = game.get_profit(cards, his, player)
profit = game.get_util(cards, his, player)
print(f"player{player}profit: {profit}")
info_cases = [
('K', '0', 0),
('J', '01', 1),
('Q', '', 0),
('JQ', '' , 0), ('JK', '' , 0),
('JQ', '0', 0), ('JK', '0', 0),
('JQ', '1', 1), ('JK', '1', 1),
('JQ', '00', 0), ('JK', '00', 0),
('JQ', '10', 0), ('KQ', '10', 0),
('QK', '01', 1), ('JK', '01', 1),
('JQ', '010',0), ('JK', '010',0),
('JQ', '011',1), ('JK', '011',1),
]
for card, his, player in info_cases:
info_set = game.get_Info_set(card, his, player)
print(f"card:{card}, his:'{his}', palyer{player} -> infoset: '{info_set}'")
info_map = {}
players = [Player(0), Player(1)]
print(f"valid act: {game.get_valid_act('00')}")
for cards, his, id in info_cases:
player = players[id]
cards = [cards[0], cards[1]]
info_key = game.get_Info_set(cards[id], his)
player.get_info_set(info_key)
for p in players:
print(f"player {p.id} has {len(p.info_map)} info sets.")
for info_key, info in p.info_map.items():
info_map[info_key] = info
print("infokey:", info_key)
test_seq = ['00', '10', '01', '010', '011', '11']
for seq in test_seq:
terminal = game.is_terminal(seq)
if terminal:
print(f" '{seq}' terminal")
else:
current_player = game.get_cur_player(seq)
print(f" '{seq}' not terminal, cur player: {current_player}")
# for info_key, info in info_map.items():
# strat = info.get_strat(1.0)
# avg_strat = info.get_avg_strat()
# print(f"infoset: '{info_key}'")
# print(f" current strategy: [CK/FD: {strat[0]:.3f}, BT/CL: {strat[1]:.3f}]")
# print(f" average strategy: [CK/FD: {avg_strat[0]:.3f}, BT/CL: {avg_strat[1]:.3f}]")
# print('=='*20)
# test_seq = ['00', '10', '01', '010', '011', '11']
# for seq in test_seq:
# terminal = game.is_terminal(seq)
# if terminal:
# print(f" '{seq}' terminal")
# else:
# current_player = game.get_cur_player(seq)
# print(f" '{seq}' not terminal, cur player: {current_player}")
if __name__ == "__main__":
test_game()