cfr: montecar test

This commit is contained in:
2025-12-04 17:44:53 +08:00
parent 33ca6d59b0
commit fe4e025a8a

104
src/cfr/montecar.py Normal file
View File

@@ -0,0 +1,104 @@
import random
from src.game.kuhn_poker import KuhnPoker
from src.cfr.calcu import test, reset_players, all_avg_strat
# Module-level KuhnPoker instance; get_once_util queries it for terminal
# checks, the player to act, info-set keys and the final payoff.
game = KuhnPoker()
def get_once_util(strat_p0, strat_p1):
    """Simulate a single hand and return the payoff from player 0's view.

    Two distinct cards are dealt from J/Q/K (index 0 to player 0, index 1 to
    the other player). At every decision point the acting player's strategy
    is looked up by info-set key; info sets missing from a strategy fall back
    to a uniform [0.5, 0.5] choice.

    Args:
        strat_p0: Mapping from info-set key to [p(action 0), p(action 1)] for
            player 0.
        strat_p1: Same mapping for the other player.

    Returns:
        Whatever ``game.get_util(cards, history, player=0)`` yields for the
        finished hand.
    """
    uniform = [0.5, 0.5]
    dealt = random.sample(['J', 'Q', 'K'], 2)
    played = ""
    while not game.is_terminal(played):
        seat = game.get_cur_player(played)
        key = game.get_Info_set(dealt[seat], played)
        acting_strategy = strat_p0 if seat == 0 else strat_p1
        dist = acting_strategy.get(key, uniform)
        # Action 0 is chosen with probability dist[0], action 1 otherwise;
        # the history string records each action as its digit.
        if random.random() < dist[0]:
            played += "0"
        else:
            played += "1"
    return game.get_util(dealt, played, player=0)
def game_avg_util(strat_p0, strat_p1, game_cnt=1000):
    """Estimate player 0's mean payoff by averaging simulated hands.

    Args:
        strat_p0: Strategy mapping passed through to ``get_once_util``.
        strat_p1: Strategy mapping passed through to ``get_once_util``.
        game_cnt: Number of hands to simulate.

    Returns:
        The sum of the sampled payoffs divided by ``game_cnt``.

    Raises:
        ZeroDivisionError: If ``game_cnt`` is 0.
    """
    payoffs = (get_once_util(strat_p0, strat_p1) for _ in range(game_cnt))
    running = 0
    # Accumulate with plain += (not sum()) so the float addition sequence is
    # exactly one add per hand, in order.
    for payoff in payoffs:
        running += payoff
    return running / game_cnt
def perturb_infoset(strat, info_key, diff=0.2):
    """Return a copy of ``strat`` with one info set's probabilities jittered.

    The first action probability of ``info_key`` is shifted by a uniform
    random amount in ``[-diff, diff]``, clamped to ``[0.01, 0.99]`` and
    rounded to four decimals; the second probability becomes its rounded
    complement. Neither ``strat`` nor the probability list it holds is
    modified.

    Args:
        strat: Mapping from info-set key to a probability list whose first
            two entries are the probabilities of action 0 and action 1.
        info_key: Key of the info set to perturb.
        diff: Maximum absolute shift applied to the first probability.
            Defaults to 0.2, the value that used to be hard-coded.

    Returns:
        A shallow copy of ``strat``. If ``info_key`` is present, its entry is
        replaced by the perturbed probabilities; otherwise the copy is
        returned unchanged.
    """
    new_strat = strat.copy()
    if info_key not in new_strat:
        return new_strat

    # Copy the inner list so the caller's strategy is never mutated.
    probs = new_strat[info_key].copy()
    shifted = probs[0] + random.uniform(-diff, diff)
    # Clamp so that neither action's probability reaches 0 or 1.
    probs[0] = round(max(0.01, min(0.99, shifted)), 4)
    probs[1] = round(1 - probs[0], 4)
    new_strat[info_key] = probs
    return new_strat
def test_one_info(strat_p0, strat_p1, test_cnt, game_cnt, info='J.'):
    """Probe one of player 0's info sets with random perturbations and print a report.

    A baseline utility for ``strat_p0`` against ``strat_p1`` is estimated
    first. Then ``test_cnt`` perturbed copies of ``strat_p0`` (changed only at
    ``info``) are evaluated the same way, and the gain of each over the
    baseline is printed together with the largest positive gain found.

    Both the baseline and every perturbed utility are independent Monte Carlo
    estimates over ``game_cnt`` hands, so each reported gain includes
    sampling noise.

    Args:
        strat_p0: Player 0 strategy mapping (info-set key -> probabilities).
        strat_p1: Opponent strategy mapping, held fixed.
        test_cnt: Number of perturbed strategies to try.
        game_cnt: Hands simulated per utility estimate.
        info: Info-set key to perturb.

    Returns:
        None. All results are written to stdout.
    """
    print(f"\n{'*'*100}")
    print(f"info_p: {info}")
    print(f"cfr strategy: {strat_p0.get(info)}")
    util_p0 = game_avg_util(strat_p0, strat_p1, game_cnt)
    print(f"cfr utility: {util_p0:.6f}\n")
    print(f"{'='*60}\n")
    print("range\tnew_strat\t\tnew_util\tgain")

    max_gain = 0.0
    best_strat = None
    for i in range(test_cnt):
        new_strat = perturb_infoset(strat_p0, info)
        perturb_util = game_avg_util(new_strat, strat_p1, game_cnt)
        gain = perturb_util - util_p0
        if gain > max_gain:
            max_gain = gain
            best_strat = new_strat
        # .get() mirrors the header line above: an info set that is absent
        # from the strategy is reported as None instead of raising KeyError.
        print(f"{i+1}\t{new_strat.get(info)}\t\t{perturb_util:<12.6f}\t{gain:+12.6f}")

    print(f"{'='*60}\n")
    print(f"可剥削度: {max_gain:.6f}")
    # Compare against None explicitly: best_strat is a dict, and an empty
    # dict must not be mistaken for "no improvement found".
    if best_strat is not None:
        print(f"扰动后增益最大策略: {best_strat.get(info)}")
    print(f"{'='*60}\n")
if __name__ == "__main__":
    # Reset and run the routines imported from src.cfr.calcu, then fetch each
    # player's strategy (presumably 100000 CFR training iterations followed by
    # the averaged strategies -- confirm against src.cfr.calcu).
    reset_players()
    test(100000)
    strat_p0, strat_p1 = all_avg_strat(0), all_avg_strat(1)

    # Probe three of player 0's info sets: 10 perturbations each, with 5000
    # simulated hands per utility estimate.
    for info_set in ('J.', 'K.', 'K.01'):
        test_one_info(strat_p0, strat_p1, 10, 5000, info_set)