task3
This commit is contained in:
77
README.md
77
README.md
@@ -1,24 +1,7 @@
|
|||||||
# Poker Task 3: EMD Distance Calculation on TURN
|
# Poker Task3
|
||||||
|
|
||||||
这是扑克Task 3的实现,专注于计算TURN阶段两手牌力分布的EMD (Earth Mover's Distance) 距离。
|
这是Task3的实现,计算TURN阶段的EMD (Earth Mover's Distance) 距离。
|
||||||
|
|
||||||
## 功能特点
|
|
||||||
|
|
||||||
- **TURN阶段扑克分析**: 基于已知的公共牌和底牌,计算每手牌在所有可能RIVER牌下的牌力分布
|
|
||||||
- **EMD距离计算**: 使用scipy的wasserstein_distance函数计算两个牌力分布的EMD距离
|
|
||||||
- **专业扑克评估**: 集成poker_task1的专业扑克牌评估系统,支持所有标准扑克手牌类型
|
|
||||||
|
|
||||||
## 安装和使用
|
|
||||||
|
|
||||||
### 安装依赖
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 使用uv管理依赖(推荐)
|
|
||||||
uv install
|
|
||||||
|
|
||||||
# 或使用pip
|
|
||||||
pip install scipy
|
|
||||||
```
|
|
||||||
|
|
||||||
### 运行程序
|
### 运行程序
|
||||||
|
|
||||||
@@ -32,57 +15,9 @@ uv run python main.py
|
|||||||
|
|
||||||
### 输入格式
|
### 输入格式
|
||||||
|
|
||||||
程序接受以下格式的输入:
|
|
||||||
```
|
|
||||||
玩家1底牌(两张)
|
|
||||||
玩家2底牌(两张)
|
|
||||||
公共牌(TURN阶段,4张牌)
|
|
||||||
```
|
|
||||||
|
|
||||||
示例:
|
示例:
|
||||||
```
|
```
|
||||||
As Ks
|
AsKs
|
||||||
7s 6s
|
7s6s
|
||||||
8h 9d 9c Qh
|
8h9d9cQh
|
||||||
```
|
```
|
||||||
|
|
||||||
## 运行测试
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 运行所有测试
|
|
||||||
uv run python -m pytest tests/ -v
|
|
||||||
|
|
||||||
# 或
|
|
||||||
python -m pytest tests/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
## 项目结构
|
|
||||||
|
|
||||||
```
|
|
||||||
poker_task3/
|
|
||||||
├── poker_task3/ # 主要代码包
|
|
||||||
│ ├── card.py # 扑克牌类定义
|
|
||||||
│ ├── hand_evaluator.py # 手牌评估器
|
|
||||||
│ ├── hand_ranking.py # 手牌排名系统
|
|
||||||
│ ├── emd.py # EMD距离计算
|
|
||||||
│ ├── task3.py # Task 3主要实现
|
|
||||||
│ └── __init__.py
|
|
||||||
├── tests/ # 测试文件
|
|
||||||
│ └── test_task3.py # Task 3测试
|
|
||||||
├── main.py # 程序入口
|
|
||||||
└── README.md # 项目文档
|
|
||||||
```
|
|
||||||
|
|
||||||
## 算法原理
|
|
||||||
|
|
||||||
1. **牌力评估**: 对于每手底牌,遍历所有可能的RIVER牌(剩余47张牌)
|
|
||||||
2. **分布构建**: 统计每种手牌类型(High Card到Royal Flush)的出现次数
|
|
||||||
3. **EMD计算**: 使用scipy.stats.wasserstein_distance计算两个分布的EMD距离
|
|
||||||
|
|
||||||
## 测试覆盖
|
|
||||||
|
|
||||||
- 输入解析测试
|
|
||||||
- 牌力分布计算测试
|
|
||||||
- EMD距离计算测试
|
|
||||||
- 错误处理测试
|
|
||||||
- 不同牌力强度对比测试
|
|
||||||
12
main.py
12
main.py
@@ -4,10 +4,16 @@ from poker_task3.task3 import runTask3
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""task3的主入口点"""
|
|
||||||
print("=== task3:扑克牌TURN阶段EMD距离计算 ===")
|
print("=== task3:扑克牌TURN阶段EMD距离计算 ===")
|
||||||
|
print("请输入三行数据:")
|
||||||
distance = runTask3()
|
print("第一行:hand1(如 AsKs)")
|
||||||
|
print("第二行:hand2(如 7s6s)")
|
||||||
|
print("第三行:board cards(如 8h9d9cQh)")
|
||||||
|
lines = []
|
||||||
|
for i in range(3):
|
||||||
|
lines.append(input())
|
||||||
|
input_text = "\n".join(lines)
|
||||||
|
distance = runTask3(input_text)
|
||||||
if distance is not None:
|
if distance is not None:
|
||||||
print(f"EMD距离: {distance}")
|
print(f"EMD距离: {distance}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
"""
|
|
||||||
Card module for poker game
|
|
||||||
"""
|
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
@@ -87,10 +83,7 @@ class Card:
|
|||||||
return self.rank.numeric_value < other.rank.numeric_value
|
return self.rank.numeric_value < other.rank.numeric_value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_card(cls, card_str: str) -> 'Card':
|
def createCard(cls, card_str) -> 'Card':
|
||||||
"""
|
|
||||||
从字符串创建Card对象,例如 "As", "Kh", "2c"
|
|
||||||
"""
|
|
||||||
if len(card_str) != 2:
|
if len(card_str) != 2:
|
||||||
raise ValueError(f"Invalid card string: {card_str}")
|
raise ValueError(f"Invalid card string: {card_str}")
|
||||||
|
|
||||||
@@ -120,7 +113,7 @@ class Card:
|
|||||||
return cls(rank, suit)
|
return cls(rank, suit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_cards(cls, cards_str: str) -> List['Card']:
|
def parseCards(cls, cards_str) -> List['Card']:
|
||||||
"""
|
"""
|
||||||
从字符串解析多张牌,例如 "AsKs AhAdAc6s7s"
|
从字符串解析多张牌,例如 "AsKs AhAdAc6s7s"
|
||||||
"""
|
"""
|
||||||
@@ -142,8 +135,4 @@ class Card:
|
|||||||
i += 2
|
i += 2
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid card format at position {i}")
|
raise ValueError(f"Invalid card format at position {i}")
|
||||||
result = []
|
return [cls.createCard(card_str) for card_str in card_strings]
|
||||||
for card_str in card_strings:
|
|
||||||
card_tmp = cls.create_card(card_str)
|
|
||||||
result.append(card_tmp)
|
|
||||||
return result
|
|
||||||
@@ -1,7 +1,3 @@
|
|||||||
"""
|
|
||||||
Hand evaluator module for poker game
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
@@ -10,14 +6,10 @@ from .hand_ranking import HandRanking, HandType
|
|||||||
|
|
||||||
|
|
||||||
class HandEvaluator:
|
class HandEvaluator:
|
||||||
"""
|
|
||||||
手牌评估器类,用于评估扑克手牌
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def evaluate_hand(cards) -> HandRanking:
|
def evaluateHand(cards) -> HandRanking:
|
||||||
"""
|
"""
|
||||||
从7张牌中评估出最好的5张牌组合
|
从7张牌中找出最好的5张牌组合
|
||||||
"""
|
"""
|
||||||
if len(cards) != 7:
|
if len(cards) != 7:
|
||||||
raise ValueError(f"Expected 7 cards, got {len(cards)}")
|
raise ValueError(f"Expected 7 cards, got {len(cards)}")
|
||||||
@@ -25,23 +17,19 @@ class HandEvaluator:
|
|||||||
best_ranking = None
|
best_ranking = None
|
||||||
best_cards = None
|
best_cards = None
|
||||||
|
|
||||||
# 尝试所有可能的5张牌组合
|
# 所有可能的5张牌组合
|
||||||
for five_cards in combinations(cards, 5):
|
for five_cards in combinations(cards, 5):
|
||||||
ranking = HandEvaluator._evaluate_five_cards(list(five_cards))
|
ranking = HandEvaluator.evaluate5Cards(list(five_cards))
|
||||||
|
|
||||||
if best_ranking is None or ranking > best_ranking:
|
if best_ranking is None or ranking > best_ranking:
|
||||||
best_ranking = ranking
|
best_ranking = ranking
|
||||||
best_cards = list(five_cards)
|
best_cards = list(five_cards)
|
||||||
|
|
||||||
# 更新最佳ranking的cards
|
|
||||||
best_ranking.cards = best_cards
|
best_ranking.cards = best_cards
|
||||||
return best_ranking
|
return best_ranking
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _evaluate_five_cards(cards: List[Card]) -> HandRanking:
|
def evaluate5Cards(cards) -> HandRanking:
|
||||||
"""
|
|
||||||
评估5张牌的手牌类型
|
|
||||||
"""
|
|
||||||
if len(cards) != 5:
|
if len(cards) != 5:
|
||||||
raise ValueError(f"Expected 5 cards, got {len(cards)}")
|
raise ValueError(f"Expected 5 cards, got {len(cards)}")
|
||||||
|
|
||||||
@@ -54,11 +42,11 @@ class HandEvaluator:
|
|||||||
rank_counts = Counter(ranks)
|
rank_counts = Counter(ranks)
|
||||||
count_values = sorted(rank_counts.values(), reverse=True)
|
count_values = sorted(rank_counts.values(), reverse=True)
|
||||||
|
|
||||||
# 检查是否是同花
|
# 同花
|
||||||
is_flush = len(set(suits)) == 1
|
is_flush = len(set(suits)) == 1
|
||||||
|
|
||||||
# 检查是否是顺子
|
# 顺子
|
||||||
is_straight, straight_high = HandEvaluator._is_straight(ranks)
|
is_straight, straight_high = HandEvaluator._isStraight(ranks)
|
||||||
|
|
||||||
# 根据牌型返回相应的HandRanking
|
# 根据牌型返回相应的HandRanking
|
||||||
if is_straight and is_flush:
|
if is_straight and is_flush:
|
||||||
@@ -102,14 +90,10 @@ class HandEvaluator:
|
|||||||
return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards)
|
return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_straight(ranks: List[Rank]) -> Tuple[bool, Rank]:
|
def _isStraight(ranks: List[Rank]) -> Tuple[bool, Rank]:
|
||||||
"""
|
|
||||||
检查是否是顺子,返回(是否是顺子, 最高牌)
|
|
||||||
"""
|
|
||||||
# 排序点数值
|
|
||||||
values = sorted([rank.numeric_value for rank in ranks], reverse=True)
|
values = sorted([rank.numeric_value for rank in ranks], reverse=True)
|
||||||
|
|
||||||
# 检查常规顺子
|
|
||||||
is_regular_straight = True
|
is_regular_straight = True
|
||||||
for i in range(1, len(values)):
|
for i in range(1, len(values)):
|
||||||
if values[i-1] - values[i] != 1:
|
if values[i-1] - values[i] != 1:
|
||||||
@@ -125,8 +109,7 @@ class HandEvaluator:
|
|||||||
break
|
break
|
||||||
return True, highest_rank
|
return True, highest_rank
|
||||||
|
|
||||||
# 检查A-2-3-4-5的特殊顺子(轮子)
|
|
||||||
if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2
|
if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2
|
||||||
return True, Rank.FIVE # 在轮子中,5是最高牌
|
return True, Rank.FIVE
|
||||||
|
|
||||||
return False, None
|
return False, None
|
||||||
@@ -1,16 +1,9 @@
|
|||||||
"""
|
|
||||||
Hand ranking module for poker game
|
|
||||||
"""
|
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from .card import Card, Rank
|
from .card import Card, Rank
|
||||||
|
|
||||||
|
|
||||||
class HandType(Enum):
|
class HandType(Enum):
|
||||||
"""
|
|
||||||
手牌类型枚举,按强度排序
|
|
||||||
"""
|
|
||||||
HIGH_CARD = (1, "High Card")
|
HIGH_CARD = (1, "High Card")
|
||||||
ONE_PAIR = (2, "Pair")
|
ONE_PAIR = (2, "Pair")
|
||||||
TWO_PAIR = (3, "Two Pair")
|
TWO_PAIR = (3, "Two Pair")
|
||||||
@@ -29,44 +22,16 @@ class HandType(Enum):
|
|||||||
obj.type_name = name
|
obj.type_name = name
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.type_name
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
if not isinstance(other, HandType):
|
|
||||||
return NotImplemented
|
|
||||||
return self.strength < other.strength
|
|
||||||
|
|
||||||
def __le__(self, other):
|
|
||||||
if not isinstance(other, HandType):
|
|
||||||
return NotImplemented
|
|
||||||
return self.strength <= other.strength
|
|
||||||
|
|
||||||
def __gt__(self, other):
|
|
||||||
if not isinstance(other, HandType):
|
|
||||||
return NotImplemented
|
|
||||||
return self.strength > other.strength
|
|
||||||
|
|
||||||
def __ge__(self, other):
|
|
||||||
if not isinstance(other, HandType):
|
|
||||||
return NotImplemented
|
|
||||||
return self.strength >= other.strength
|
|
||||||
|
|
||||||
|
|
||||||
class HandRanking:
|
class HandRanking:
|
||||||
"""
|
|
||||||
手牌排名类,包含手牌类型和关键牌
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]):
|
def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]):
|
||||||
self.hand_type = hand_type
|
self.hand_type = hand_type
|
||||||
self.key_ranks = key_ranks # 用于比较的关键点数
|
self.key_ranks = key_ranks # 用于比较的关键点数
|
||||||
self.cards = cards # 组成这个排名的5张牌
|
self.cards = cards # 组成这个ranking的5张牌
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""
|
|
||||||
返回手牌排名的字符串表示
|
|
||||||
"""
|
|
||||||
if self.hand_type == HandType.FOUR_OF_A_KIND:
|
if self.hand_type == HandType.FOUR_OF_A_KIND:
|
||||||
return f"Quad({self.key_ranks[0].symbol})"
|
return f"Quad({self.key_ranks[0].symbol})"
|
||||||
elif self.hand_type == HandType.FULL_HOUSE:
|
elif self.hand_type == HandType.FULL_HOUSE:
|
||||||
@@ -88,38 +53,5 @@ class HandRanking:
|
|||||||
return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})"
|
return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})"
|
||||||
elif self.hand_type == HandType.ONE_PAIR:
|
elif self.hand_type == HandType.ONE_PAIR:
|
||||||
return f"Pair({self.key_ranks[0].symbol})"
|
return f"Pair({self.key_ranks[0].symbol})"
|
||||||
else: # HIGH_CARD
|
else:
|
||||||
return f"High Card({self.key_ranks[0].symbol})"
|
return f"High Card({self.key_ranks[0].symbol})"
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"HandRanking({self.hand_type}, {[r.symbol for r in self.key_ranks]})"
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if not isinstance(other, HandRanking):
|
|
||||||
return False
|
|
||||||
return (self.hand_type == other.hand_type and
|
|
||||||
self.key_ranks == other.key_ranks)
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
if not isinstance(other, HandRanking):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
# 首先比较手牌类型
|
|
||||||
if self.hand_type != other.hand_type:
|
|
||||||
return self.hand_type < other.hand_type
|
|
||||||
|
|
||||||
# 如果手牌类型相同,比较关键点数
|
|
||||||
for self_rank, other_rank in zip(self.key_ranks, other.key_ranks):
|
|
||||||
if self_rank != other_rank:
|
|
||||||
return self_rank < other_rank
|
|
||||||
|
|
||||||
return False # 完全相等
|
|
||||||
|
|
||||||
def __le__(self, other):
|
|
||||||
return self == other or self < other
|
|
||||||
|
|
||||||
def __gt__(self, other):
|
|
||||||
return not self <= other
|
|
||||||
|
|
||||||
def __ge__(self, other):
|
|
||||||
return not self < other
|
|
||||||
@@ -7,132 +7,106 @@ from .card import Card, Rank, Suit
|
|||||||
from .hand_evaluator import HandEvaluator
|
from .hand_evaluator import HandEvaluator
|
||||||
from .hand_ranking import HandRanking, HandType
|
from .hand_ranking import HandRanking, HandType
|
||||||
|
|
||||||
|
def parseTurnInput(input_text) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
|
||||||
def calEmd(Hist1: List[Union[int, float]],
|
|
||||||
Hist2: List[Union[int, float]]) -> float:
|
|
||||||
return wasserstein_distance(Hist1, Hist2)
|
|
||||||
|
|
||||||
|
|
||||||
def parseTurnInput(input_text: str) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
|
|
||||||
"""
|
|
||||||
解析TURN阶段输入
|
|
||||||
格式: "As Ks\n7s 6s\n8h 9d 9c Qh"
|
|
||||||
返回: (手牌1, 手牌2, 公共牌)
|
|
||||||
"""
|
|
||||||
lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()]
|
lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()]
|
||||||
|
|
||||||
if len(lines) != 3:
|
if len(lines) != 3:
|
||||||
print("输入必须包含恰好3行:手牌1,手牌2,公共牌")
|
print("输入必须包含恰好3行:hand1,hand2,board cards")
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
hand1_cards = Card.parse_cards(lines[0])
|
hand1_cards = Card.parseCards(lines[0])
|
||||||
hand2_cards = Card.parse_cards(lines[1])
|
hand2_cards = Card.parseCards(lines[1])
|
||||||
board_cards = Card.parse_cards(lines[2])
|
board_cards = Card.parseCards(lines[2])
|
||||||
|
|
||||||
if len(hand1_cards) != 2:
|
if len(hand1_cards) != 2:
|
||||||
print("手牌1必须包含恰好2张牌")
|
print(f"{hand1_cards} should contain exactly 2 cards")
|
||||||
return None, None, None
|
return None, None, None
|
||||||
if len(hand2_cards) != 2:
|
if len(hand2_cards) != 2:
|
||||||
print("手牌2必须包含恰好2张牌")
|
print(f"{hand2_cards} should contain exactly 2 cards")
|
||||||
return None, None, None
|
return None, None, None
|
||||||
if len(board_cards) != 4:
|
if len(board_cards) != 4:
|
||||||
print("公共牌必须包含恰好4张牌(TURN阶段)")
|
print(f"{board_cards} should contain exactly 4 cards")
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
return hand1_cards, hand2_cards, board_cards
|
return hand1_cards, hand2_cards, board_cards
|
||||||
|
|
||||||
|
|
||||||
def calHandEquityHist(hole_cards, board_cards) -> List[float]:
|
def calHandEquityHist(hole_cards, board_cards) -> List[int]:
|
||||||
"""
|
|
||||||
计算TURN阶段手牌的胜率分布
|
|
||||||
基于每个可能河牌下的胜率创建分布,而不是绝对强度排名
|
|
||||||
"""
|
|
||||||
all_cards = hole_cards + board_cards # 总共6张牌
|
all_cards = hole_cards + board_cards # 总共6张牌
|
||||||
|
|
||||||
# 创建剩余牌组
|
|
||||||
used_cards = set(all_cards)
|
used_cards = set(all_cards)
|
||||||
remaining_deck = []
|
remaining_cards = []
|
||||||
|
|
||||||
for rank in Rank:
|
for rank in Rank:
|
||||||
for suit in Suit:
|
for suit in Suit:
|
||||||
card = Card(rank, suit)
|
card = Card(rank, suit)
|
||||||
if card not in used_cards:
|
if card not in used_cards:
|
||||||
remaining_deck.append(card)
|
remaining_cards.append(card)
|
||||||
|
|
||||||
# 为每个可能的河牌计算该手牌的胜率
|
# 为每个可能的河牌计算该手牌的胜率
|
||||||
winrates = []
|
winrates = []
|
||||||
|
|
||||||
for river_card in remaining_deck:
|
for river_card in remaining_cards:
|
||||||
# 当前手牌在这个河牌下的最终七张牌
|
# 当前手牌在这个河牌下的最终七张牌
|
||||||
current_seven_cards = all_cards + [river_card]
|
player1_7cards = all_cards + [river_card]
|
||||||
current_hand_ranking = HandEvaluator.evaluate_hand(current_seven_cards)
|
player1_hand_ranking = HandEvaluator.evaluateHand(player1_7cards)
|
||||||
|
|
||||||
# 计算对抗所有可能对手牌的胜率
|
# 计算对抗所有可能对手牌的胜率
|
||||||
wins = 0
|
wins = 0
|
||||||
total_opponents = 0
|
total_wins = 0
|
||||||
|
|
||||||
# 生成所有可能的对手双牌组合
|
# 生成所有可能的对手牌组合
|
||||||
used_cards_with_river = used_cards | {river_card}
|
available_cards = [card for card in remaining_cards if card != river_card]
|
||||||
available_cards = [card for card in remaining_deck if card != river_card]
|
|
||||||
|
|
||||||
for i in range(len(available_cards)):
|
for i in range(len(available_cards)):
|
||||||
for j in range(i + 1, len(available_cards)):
|
for j in range(i + 1, len(available_cards)):
|
||||||
opponent_cards = [available_cards[i], available_cards[j]]
|
player2_cards = [available_cards[i], available_cards[j]]
|
||||||
opponent_seven_cards = opponent_cards + board_cards + [river_card]
|
player2_7cards = player2_cards + board_cards + [river_card]
|
||||||
opponent_hand_ranking = HandEvaluator.evaluate_hand(opponent_seven_cards)
|
player2_hand_ranking = HandEvaluator.evaluateHand(player2_7cards)
|
||||||
|
|
||||||
total_opponents += 1
|
total_wins += 1
|
||||||
if current_hand_ranking > opponent_hand_ranking:
|
if player1_hand_ranking > player2_hand_ranking:
|
||||||
wins += 1
|
wins += 1
|
||||||
elif current_hand_ranking == opponent_hand_ranking:
|
elif player1_hand_ranking == player2_hand_ranking:
|
||||||
wins += 0.5 # 平局
|
wins += 0.5 # 平局
|
||||||
|
|
||||||
winrate = wins / total_opponents if total_opponents > 0 else 0.0
|
winrate = wins / total_wins if total_wins > 0 else 0.0
|
||||||
winrates.append(winrate)
|
winrates.append(winrate)
|
||||||
|
|
||||||
num_bins = 30
|
num_bins = 30
|
||||||
hist = [0.0] * num_bins
|
hist = [0] * num_bins
|
||||||
|
|
||||||
for winrate in winrates:
|
for winrate in winrates:
|
||||||
bin_index = min(int(winrate * num_bins), num_bins - 1)
|
bin_index = min(int(winrate * num_bins), num_bins - 1)
|
||||||
hist[bin_index] += 1.0
|
hist[bin_index] += 1
|
||||||
|
|
||||||
total = sum(hist)
|
|
||||||
if total > 0:
|
|
||||||
hist = [x / total for x in hist]
|
|
||||||
|
|
||||||
return hist
|
return hist
|
||||||
|
|
||||||
|
|
||||||
def calPokerEmdTurn(input_text: str) -> Union[float, None]:
|
def calPokerEmdTurn(input_text) -> Union[float, None]:
|
||||||
"""
|
|
||||||
计算TURN阶段两手牌的EMD距离
|
|
||||||
"""
|
|
||||||
hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text)
|
hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text)
|
||||||
|
|
||||||
# 检查解析是否成功
|
|
||||||
if hand1_cards is None or hand2_cards is None or board_cards is None:
|
if hand1_cards is None or hand2_cards is None or board_cards is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 计算胜率分布
|
# 计算胜率分布
|
||||||
dist1 = calHandEquityHist(hand1_cards, board_cards)
|
hist1 = calHandEquityHist(hand1_cards, board_cards)
|
||||||
dist2 = calHandEquityHist(hand2_cards, board_cards)
|
hist2 = calHandEquityHist(hand2_cards, board_cards)
|
||||||
|
|
||||||
# 计算EMD距离
|
# 计算EMD距离
|
||||||
emd_distance = calEmd(dist1, dist2)
|
emd_distance = wasserstein_distance(hist1, hist2)
|
||||||
|
|
||||||
return emd_distance
|
return emd_distance
|
||||||
|
|
||||||
|
|
||||||
def runTask3():
|
def runTask3(input_str):
|
||||||
"""运行任务3的示例输入"""
|
if input_str is None:
|
||||||
example_input = """
|
input_str = """
|
||||||
As Ks
|
AsKs
|
||||||
7s 6s
|
7s6s
|
||||||
8h 9d 9c Qh"""
|
8h9d9cQh
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
distance = calPokerEmdTurn(example_input)
|
distance = calPokerEmdTurn(input_str)
|
||||||
return distance
|
return distance
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"错误: {e}")
|
print(f"错误: {e}")
|
||||||
|
|||||||
@@ -1,100 +0,0 @@
|
|||||||
"""测试任务3 - 使用poker_task1进行扑克牌EMD距离计算"""
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
||||||
|
|
||||||
# 从poker_task3.task3模块导入函数
|
|
||||||
from poker_task3.task3 import parseTurnInput, calHandEquityHist, calPokerEmdTurn
|
|
||||||
from poker_task3.card import Card
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_turn_input():
|
|
||||||
"""测试TURN阶段输入解析"""
|
|
||||||
input_text = """
|
|
||||||
As Ks
|
|
||||||
7s 6s
|
|
||||||
8h 9d 9c Qh
|
|
||||||
"""
|
|
||||||
|
|
||||||
hand1, hand2, board = parseTurnInput(input_text)
|
|
||||||
|
|
||||||
assert len(hand1) == 2
|
|
||||||
assert len(hand2) == 2
|
|
||||||
assert len(board) == 4
|
|
||||||
|
|
||||||
assert str(hand1[0]) == "As"
|
|
||||||
assert str(hand1[1]) == "Ks"
|
|
||||||
assert str(hand2[0]) == "7s"
|
|
||||||
assert str(hand2[1]) == "6s"
|
|
||||||
|
|
||||||
print("✓ TURN输入解析测试通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_hand_equity_Hist():
|
|
||||||
"""测试手牌胜率分布计算"""
|
|
||||||
hole_cards = Card.parse_cards("As Ks")
|
|
||||||
board_cards = Card.parse_cards("8h 9d 9c Qh")
|
|
||||||
|
|
||||||
Hist = calHandEquityHist(hole_cards, board_cards)
|
|
||||||
|
|
||||||
assert len(Hist) == 30 # 胜率分布的区间数量(改进为30个区间)
|
|
||||||
assert all(isinstance(x, (int, float)) for x in Hist)
|
|
||||||
assert all(x >= 0 for x in Hist) # 所有值都应该是非负的
|
|
||||||
assert abs(sum(Hist) - 1.0) < 1e-10 # 分布总和应该为1(标准化后)
|
|
||||||
|
|
||||||
print("✓ 手牌胜率分布测试通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_poker_emd_calculation():
|
|
||||||
"""测试主要的EMD计算函数"""
|
|
||||||
input_text = """As Ks
|
|
||||||
7s 6s
|
|
||||||
8h 9d 9c Qh"""
|
|
||||||
|
|
||||||
distance = calPokerEmdTurn(input_text)
|
|
||||||
|
|
||||||
assert isinstance(distance, (int, float))
|
|
||||||
assert distance >= 0 # EMD总是非负的
|
|
||||||
|
|
||||||
print(f"✓ 扑克牌EMD计算测试通过 (距离: {distance:.3f})")
|
|
||||||
|
|
||||||
|
|
||||||
def test_different_hand_strengths():
|
|
||||||
"""测试不同强度手牌的EMD"""
|
|
||||||
# 强牌 vs 弱牌
|
|
||||||
input_text = """As Ks
|
|
||||||
2c 3d
|
|
||||||
8h 9d 9c Qh"""
|
|
||||||
|
|
||||||
distance = calPokerEmdTurn(input_text)
|
|
||||||
|
|
||||||
assert isinstance(distance, (int, float))
|
|
||||||
assert distance >= 0 # 应该有非负距离
|
|
||||||
|
|
||||||
print(f"✓ 不同手牌强度测试通过 (距离: {distance:.3f})")
|
|
||||||
|
|
||||||
|
|
||||||
def test_error_handling():
|
|
||||||
"""测试无效输入的错误处理"""
|
|
||||||
# 错误的行数
|
|
||||||
result = parseTurnInput("As Ks\n7s 6s")
|
|
||||||
assert result == (None, None, None), "应该返回(None, None, None)"
|
|
||||||
|
|
||||||
# 手牌中牌数错误
|
|
||||||
result = parseTurnInput("As Ks Qs\n7s 6s\n8h 9d 9c Qh")
|
|
||||||
assert result == (None, None, None), "应该返回(None, None, None)"
|
|
||||||
|
|
||||||
# 公共牌数错误
|
|
||||||
result = parseTurnInput("As Ks\n7s 6s\n8h 9d 9c")
|
|
||||||
assert result == (None, None, None), "应该返回(None, None, None)"
|
|
||||||
|
|
||||||
print("✓ 错误处理测试通过")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_parse_turn_input()
|
|
||||||
test_hand_equity_Hist()
|
|
||||||
test_poker_emd_calculation()
|
|
||||||
test_different_hand_strengths()
|
|
||||||
test_error_handling()
|
|
||||||
print("所有任务3测试通过! ✓")
|
|
||||||
Reference in New Issue
Block a user