kubn cfr:0
This commit is contained in:
44
src/cfr/info_set.py
Normal file
44
src/cfr/info_set.py
Normal file
@@ -0,0 +1,44 @@
|
||||
class InfoSet:
|
||||
def __init__(self, act_cnt=2): # 0-1
|
||||
self.act_cnt = act_cnt
|
||||
|
||||
self.regret_sum = [0.0] * act_cnt
|
||||
self.strat = [0.0] * act_cnt
|
||||
self.strat_sum = [0.0] * act_cnt
|
||||
|
||||
def get_strat(self, wgt):
|
||||
|
||||
normal = 0.0
|
||||
for i in range(self.act_cnt):
|
||||
if self.regret_sum[i] > 0:
|
||||
self.strat[i] = self.regret_sum[i]
|
||||
else:
|
||||
self.strat[i] = 0.0
|
||||
normal += self.strat[i]
|
||||
|
||||
if normal > 0:
|
||||
for i in range(self.act_cnt):
|
||||
self.strat[i] /= normal
|
||||
else:
|
||||
##
|
||||
return
|
||||
|
||||
for i in range(self.act_cnt):
|
||||
self.strat_sum[i] += wgt * self.strat[i]
|
||||
|
||||
return self.strat
|
||||
|
||||
def get_avg_strat(self):
|
||||
avg_strat = [0.0] * self.act_cnt
|
||||
normal = sum(self.strat_sum)
|
||||
|
||||
if normal > 0:
|
||||
for i in range(self.act_cnt):
|
||||
avg_strat[i] = self.strat_sum[i] / normal
|
||||
else:
|
||||
##
|
||||
return
|
||||
|
||||
|
||||
return avg_strat
|
||||
|
||||
Reference in New Issue
Block a user