This commit is contained in:
2025-09-28 16:44:00 +08:00
parent 0597239207
commit fc085eb77e
4 changed files with 69 additions and 180 deletions

View File

@@ -1,4 +1,5 @@
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from poker.card import Card, ShortDeckRank, Suit
@@ -164,17 +165,19 @@ class XTaskDataParser:
self.data_path = data_path
self.decoder = OpenPQLDecoder()
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy") -> List[RiverEHSRecord]:
def parse_river_ehs_with_cards(self, filename: str = "river_ehs.npy", max_records: int = 1000) -> List[RiverEHSRecord]:
filepath = f"{self.data_path}/{filename}"
try:
raw_data = np.load(filepath)
print(f"加载river_EHS: {raw_data.shape} 条记录,数据类型: {raw_data.dtype}")
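# Randomly sample (without replacement) at most max_records rows; note that list(raw_data) copies the whole array first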
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
records = []
decode_errors = 0
for i, row in enumerate(raw_data):
for i, row in enumerate(data_to_process):
try:
board_id = int(row['board'])
player_id = int(row['player'])
@@ -245,10 +248,8 @@ class XTaskDataParser:
records = []
decode_errors = 0
# 限制处理数量以避免内存问题
# todo:抽样优化
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
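# Randomly sample (without replacement) at most max_records rows; note that list(raw_data) copies the whole array first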
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
for i, row in enumerate(data_to_process):
try:
@@ -305,8 +306,7 @@ class XTaskDataParser:
print(f"加载flop_hist: {raw_data.shape} 条记录,限制处理: {max_records:,} 条记录")
records = []
decode_errors = 0
data_to_process = raw_data[:max_records] if len(raw_data) > max_records else raw_data
data_to_process = random.sample(list(raw_data), min(max_records, len(raw_data)))
for i, row in enumerate(data_to_process):
try: