plt_hist
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import random
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from poker.card import Card, ShortDeckRank, Suit
|
||||
@@ -164,17 +165,19 @@ class XTaskDataParser:
|
||||
self.data_path = data_path
|
||||
self.decoder = OpenPQLDecoder()
|
||||
|
||||
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
|
||||
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
|
||||
filepath = f"{self.data_path}/{filename}"
|
||||
|
||||
try:
|
||||
raw_data = np.load(filepath)
|
||||
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
|
||||
|
||||
# 抽样
|
||||
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
for i, row in enumerate(raw_data):
|
||||
|
||||
for i, row in enumerate(data_to_process):
|
||||
try:
|
||||
board_id = int(row['board'])
|
||||
player_id = int(row['player'])
|
||||
@@ -245,10 +248,8 @@ class XTaskDataParser:
|
||||
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
# 限制处理数量以避免内存问题
|
||||
# todo:抽样优化
|
||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
||||
# 抽样
|
||||
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||
|
||||
for i, row in enumerate(data_to_process):
|
||||
try:
|
||||
@@ -305,8 +306,7 @@ class XTaskDataParser:
|
||||
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
|
||||
records = []
|
||||
decode_errors = 0
|
||||
|
||||
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
|
||||
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
|
||||
|
||||
for i, row in enumerate(data_to_process):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user