This commit is contained in:
2025-09-23 09:15:05 +08:00
parent 6e4973da7a
commit e8a705c467
7 changed files with 72 additions and 353 deletions

View File

@@ -1,24 +1,7 @@
# Poker Task 3: EMD Distance Calculation on TURN # Poker Task3
这是扑克Task 3的实现专注于计算TURN阶段两手牌力分布的EMD (Earth Mover's Distance) 距离。 这是Task3的实现计算TURN阶段的EMD (Earth Mover's Distance) 距离。
## 功能特点
- **TURN阶段扑克分析**: 基于已知的公共牌和底牌计算每手牌在所有可能RIVER牌下的牌力分布
- **EMD距离计算**: 使用scipy的wasserstein_distance函数计算两个牌力分布的EMD距离
- **专业扑克评估**: 集成poker_task1的专业扑克牌评估系统支持所有标准扑克手牌类型
## 安装和使用
### 安装依赖
```bash
# 使用uv管理依赖推荐
uv install
# 或使用pip
pip install scipy
```
### 运行程序 ### 运行程序
@@ -32,57 +15,9 @@ uv run python main.py
### 输入格式 ### 输入格式
程序接受以下格式的输入:
```
玩家1底牌两张
玩家2底牌两张
公共牌TURN阶段4张牌
```
示例: 示例:
``` ```
AsKs AsKs
7s6s 7s6s
8h9d9cQh 8h9d9cQh
``` ```
## 运行测试
```bash
# 运行所有测试
uv run python -m pytest tests/ -v
# 或
python -m pytest tests/ -v
```
## 项目结构
```
poker_task3/
├── poker_task3/ # 主要代码包
│ ├── card.py # 扑克牌类定义
│ ├── hand_evaluator.py # 手牌评估器
│ ├── hand_ranking.py # 手牌排名系统
│ ├── emd.py # EMD距离计算
│ ├── task3.py # Task 3主要实现
│ └── __init__.py
├── tests/ # 测试文件
│ └── test_task3.py # Task 3测试
├── main.py # 程序入口
└── README.md # 项目文档
```
## 算法原理
1. **牌力评估**: 对于每手底牌遍历所有可能的RIVER牌剩余47张牌
2. **分布构建**: 统计每种手牌类型High Card到Royal Flush的出现次数
3. **EMD计算**: 使用scipy.stats.wasserstein_distance计算两个分布的EMD距离
## 测试覆盖
- 输入解析测试
- 牌力分布计算测试
- EMD距离计算测试
- 错误处理测试
- 不同牌力强度对比测试

12
main.py
View File

@@ -4,10 +4,16 @@ from poker_task3.task3 import runTask3
def main(): def main():
"""task3的主入口点"""
print("=== task3扑克牌TURN阶段EMD距离计算 ===") print("=== task3扑克牌TURN阶段EMD距离计算 ===")
print("请输入三行数据:")
distance = runTask3() print("第一行hand1如 AsKs")
print("第二行hand2如 7s6s")
print("第三行board cards如 8h9d9cQh")
lines = []
for i in range(3):
lines.append(input())
input_text = "\n".join(lines)
distance = runTask3(input_text)
if distance is not None: if distance is not None:
print(f"EMD距离: {distance}") print(f"EMD距离: {distance}")

View File

@@ -1,7 +1,3 @@
"""
Card module for poker game
"""
from enum import Enum from enum import Enum
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
@@ -87,10 +83,7 @@ class Card:
return self.rank.numeric_value < other.rank.numeric_value return self.rank.numeric_value < other.rank.numeric_value
@classmethod @classmethod
def create_card(cls, card_str: str) -> 'Card': def createCard(cls, card_str) -> 'Card':
"""
从字符串创建Card对象例如 "As", "Kh", "2c"
"""
if len(card_str) != 2: if len(card_str) != 2:
raise ValueError(f"Invalid card string: {card_str}") raise ValueError(f"Invalid card string: {card_str}")
@@ -120,7 +113,7 @@ class Card:
return cls(rank, suit) return cls(rank, suit)
@classmethod @classmethod
def parse_cards(cls, cards_str: str) -> List['Card']: def parseCards(cls, cards_str) -> List['Card']:
""" """
从字符串解析多张牌,例如 "AsKs AhAdAc6s7s" 从字符串解析多张牌,例如 "AsKs AhAdAc6s7s"
""" """
@@ -142,8 +135,4 @@ class Card:
i += 2 i += 2
else: else:
raise ValueError(f"Invalid card format at position {i}") raise ValueError(f"Invalid card format at position {i}")
result = [] return [cls.createCard(card_str) for card_str in card_strings]
for card_str in card_strings:
card_tmp = cls.create_card(card_str)
result.append(card_tmp)
return result

View File

@@ -1,7 +1,3 @@
"""
Hand evaluator module for poker game
"""
from typing import List, Tuple, Dict from typing import List, Tuple, Dict
from collections import Counter from collections import Counter
from itertools import combinations from itertools import combinations
@@ -10,14 +6,10 @@ from .hand_ranking import HandRanking, HandType
class HandEvaluator: class HandEvaluator:
"""
手牌评估器类,用于评估扑克手牌
"""
@staticmethod @staticmethod
def evaluate_hand(cards) -> HandRanking: def evaluateHand(cards) -> HandRanking:
""" """
从7张牌中评估出最好的5张牌组合 从7张牌中出最好的5张牌组合
""" """
if len(cards) != 7: if len(cards) != 7:
raise ValueError(f"Expected 7 cards, got {len(cards)}") raise ValueError(f"Expected 7 cards, got {len(cards)}")
@@ -25,23 +17,19 @@ class HandEvaluator:
best_ranking = None best_ranking = None
best_cards = None best_cards = None
# 尝试所有可能的5张牌组合 # 所有可能的5张牌组合
for five_cards in combinations(cards, 5): for five_cards in combinations(cards, 5):
ranking = HandEvaluator._evaluate_five_cards(list(five_cards)) ranking = HandEvaluator.evaluate5Cards(list(five_cards))
if best_ranking is None or ranking > best_ranking: if best_ranking is None or ranking > best_ranking:
best_ranking = ranking best_ranking = ranking
best_cards = list(five_cards) best_cards = list(five_cards)
# 更新最佳ranking的cards
best_ranking.cards = best_cards best_ranking.cards = best_cards
return best_ranking return best_ranking
@staticmethod @staticmethod
def _evaluate_five_cards(cards: List[Card]) -> HandRanking: def evaluate5Cards(cards) -> HandRanking:
"""
评估5张牌的手牌类型
"""
if len(cards) != 5: if len(cards) != 5:
raise ValueError(f"Expected 5 cards, got {len(cards)}") raise ValueError(f"Expected 5 cards, got {len(cards)}")
@@ -54,11 +42,11 @@ class HandEvaluator:
rank_counts = Counter(ranks) rank_counts = Counter(ranks)
count_values = sorted(rank_counts.values(), reverse=True) count_values = sorted(rank_counts.values(), reverse=True)
# 检查是否是同花 # 同花
is_flush = len(set(suits)) == 1 is_flush = len(set(suits)) == 1
# 检查是否是顺子 # 顺子
is_straight, straight_high = HandEvaluator._is_straight(ranks) is_straight, straight_high = HandEvaluator._isStraight(ranks)
# 根据牌型返回相应的HandRanking # 根据牌型返回相应的HandRanking
if is_straight and is_flush: if is_straight and is_flush:
@@ -102,14 +90,10 @@ class HandEvaluator:
return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards) return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards)
@staticmethod @staticmethod
def _is_straight(ranks: List[Rank]) -> Tuple[bool, Rank]: def _isStraight(ranks: List[Rank]) -> Tuple[bool, Rank]:
"""
检查是否是顺子,返回(是否是顺子, 最高牌)
"""
# 排序点数值
values = sorted([rank.numeric_value for rank in ranks], reverse=True) values = sorted([rank.numeric_value for rank in ranks], reverse=True)
# 检查常规顺子
is_regular_straight = True is_regular_straight = True
for i in range(1, len(values)): for i in range(1, len(values)):
if values[i-1] - values[i] != 1: if values[i-1] - values[i] != 1:
@@ -125,8 +109,7 @@ class HandEvaluator:
break break
return True, highest_rank return True, highest_rank
# 检查A-2-3-4-5的特殊顺子轮子
if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2 if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2
return True, Rank.FIVE # 在轮子中5是最高牌 return True, Rank.FIVE
return False, None return False, None

View File

@@ -1,16 +1,9 @@
"""
Hand ranking module for poker game
"""
from enum import Enum from enum import Enum
from typing import List, Tuple from typing import List, Tuple
from .card import Card, Rank from .card import Card, Rank
class HandType(Enum): class HandType(Enum):
"""
手牌类型枚举,按强度排序
"""
HIGH_CARD = (1, "High Card") HIGH_CARD = (1, "High Card")
ONE_PAIR = (2, "Pair") ONE_PAIR = (2, "Pair")
TWO_PAIR = (3, "Two Pair") TWO_PAIR = (3, "Two Pair")
@@ -29,44 +22,16 @@ class HandType(Enum):
obj.type_name = name obj.type_name = name
return obj return obj
def __str__(self):
return self.type_name
def __lt__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength < other.strength
def __le__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength <= other.strength
def __gt__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength > other.strength
def __ge__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength >= other.strength
class HandRanking: class HandRanking:
"""
手牌排名类,包含手牌类型和关键牌
"""
def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]): def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]):
self.hand_type = hand_type self.hand_type = hand_type
self.key_ranks = key_ranks # 用于比较的关键点数 self.key_ranks = key_ranks # 用于比较的关键点数
self.cards = cards # 组成这个排名的5张牌 self.cards = cards # 组成这个ranking的5张牌
def __str__(self): def __str__(self):
"""
返回手牌排名的字符串表示
"""
if self.hand_type == HandType.FOUR_OF_A_KIND: if self.hand_type == HandType.FOUR_OF_A_KIND:
return f"Quad({self.key_ranks[0].symbol})" return f"Quad({self.key_ranks[0].symbol})"
elif self.hand_type == HandType.FULL_HOUSE: elif self.hand_type == HandType.FULL_HOUSE:
@@ -88,38 +53,5 @@ class HandRanking:
return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})" return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})"
elif self.hand_type == HandType.ONE_PAIR: elif self.hand_type == HandType.ONE_PAIR:
return f"Pair({self.key_ranks[0].symbol})" return f"Pair({self.key_ranks[0].symbol})"
else: # HIGH_CARD else:
return f"High Card({self.key_ranks[0].symbol})" return f"High Card({self.key_ranks[0].symbol})"
def __repr__(self):
return f"HandRanking({self.hand_type}, {[r.symbol for r in self.key_ranks]})"
def __eq__(self, other):
if not isinstance(other, HandRanking):
return False
return (self.hand_type == other.hand_type and
self.key_ranks == other.key_ranks)
def __lt__(self, other):
if not isinstance(other, HandRanking):
return NotImplemented
# 首先比较手牌类型
if self.hand_type != other.hand_type:
return self.hand_type < other.hand_type
# 如果手牌类型相同,比较关键点数
for self_rank, other_rank in zip(self.key_ranks, other.key_ranks):
if self_rank != other_rank:
return self_rank < other_rank
return False # 完全相等
def __le__(self, other):
return self == other or self < other
def __gt__(self, other):
return not self <= other
def __ge__(self, other):
return not self < other

View File

@@ -7,132 +7,106 @@ from .card import Card, Rank, Suit
from .hand_evaluator import HandEvaluator from .hand_evaluator import HandEvaluator
from .hand_ranking import HandRanking, HandType from .hand_ranking import HandRanking, HandType
def parseTurnInput(input_text) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
def calEmd(Hist1: List[Union[int, float]],
Hist2: List[Union[int, float]]) -> float:
return wasserstein_distance(Hist1, Hist2)
def parseTurnInput(input_text: str) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
"""
解析TURN阶段输入
格式: "As Ks\n7s 6s\n8h 9d 9c Qh"
返回: (手牌1, 手牌2, 公共牌)
"""
lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()] lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()]
if len(lines) != 3: if len(lines) != 3:
print("输入必须包含恰好3行手牌1手牌2公共牌") print("输入必须包含恰好3行hand1hand2board cards")
return None, None, None return None, None, None
hand1_cards = Card.parse_cards(lines[0]) hand1_cards = Card.parseCards(lines[0])
hand2_cards = Card.parse_cards(lines[1]) hand2_cards = Card.parseCards(lines[1])
board_cards = Card.parse_cards(lines[2]) board_cards = Card.parseCards(lines[2])
if len(hand1_cards) != 2: if len(hand1_cards) != 2:
print("手牌1必须包含恰好2张牌") print(f"{hand1_cards} should contain exactly 2 cards")
return None, None, None return None, None, None
if len(hand2_cards) != 2: if len(hand2_cards) != 2:
print("手牌2必须包含恰好2张牌") print(f"{hand2_cards} should contain exactly 2 cards")
return None, None, None return None, None, None
if len(board_cards) != 4: if len(board_cards) != 4:
print("公共牌必须包含恰好4张牌TURN阶段") print(f"{board_cards} should contain exactly 4 cards")
return None, None, None return None, None, None
return hand1_cards, hand2_cards, board_cards return hand1_cards, hand2_cards, board_cards
def calHandEquityHist(hole_cards, board_cards) -> List[float]: def calHandEquityHist(hole_cards, board_cards) -> List[int]:
"""
计算TURN阶段手牌的胜率分布
基于每个可能河牌下的胜率创建分布,而不是绝对强度排名
"""
all_cards = hole_cards + board_cards # 总共6张牌 all_cards = hole_cards + board_cards # 总共6张牌
# 创建剩余牌组
used_cards = set(all_cards) used_cards = set(all_cards)
remaining_deck = [] remaining_cards = []
for rank in Rank: for rank in Rank:
for suit in Suit: for suit in Suit:
card = Card(rank, suit) card = Card(rank, suit)
if card not in used_cards: if card not in used_cards:
remaining_deck.append(card) remaining_cards.append(card)
# 为每个可能的河牌计算该手牌的胜率 # 为每个可能的河牌计算该手牌的胜率
winrates = [] winrates = []
for river_card in remaining_deck: for river_card in remaining_cards:
# 当前手牌在这个河牌下的最终七张牌 # 当前手牌在这个河牌下的最终七张牌
current_seven_cards = all_cards + [river_card] player1_7cards = all_cards + [river_card]
current_hand_ranking = HandEvaluator.evaluate_hand(current_seven_cards) player1_hand_ranking = HandEvaluator.evaluateHand(player1_7cards)
# 计算对抗所有可能对手牌的胜率 # 计算对抗所有可能对手牌的胜率
wins = 0 wins = 0
total_opponents = 0 total_wins = 0
# 生成所有可能的对手牌组合 # 生成所有可能的对手牌组合
used_cards_with_river = used_cards | {river_card} available_cards = [card for card in remaining_cards if card != river_card]
available_cards = [card for card in remaining_deck if card != river_card]
for i in range(len(available_cards)): for i in range(len(available_cards)):
for j in range(i + 1, len(available_cards)): for j in range(i + 1, len(available_cards)):
opponent_cards = [available_cards[i], available_cards[j]] player2_cards = [available_cards[i], available_cards[j]]
opponent_seven_cards = opponent_cards + board_cards + [river_card] player2_7cards = player2_cards + board_cards + [river_card]
opponent_hand_ranking = HandEvaluator.evaluate_hand(opponent_seven_cards) player2_hand_ranking = HandEvaluator.evaluateHand(player2_7cards)
total_opponents += 1 total_wins += 1
if current_hand_ranking > opponent_hand_ranking: if player1_hand_ranking > player2_hand_ranking:
wins += 1 wins += 1
elif current_hand_ranking == opponent_hand_ranking: elif player1_hand_ranking == player2_hand_ranking:
wins += 0.5 # 平局 wins += 0.5 # 平局
winrate = wins / total_opponents if total_opponents > 0 else 0.0 winrate = wins / total_wins if total_wins > 0 else 0.0
winrates.append(winrate) winrates.append(winrate)
num_bins = 30 num_bins = 30
hist = [0.0] * num_bins hist = [0] * num_bins
for winrate in winrates: for winrate in winrates:
bin_index = min(int(winrate * num_bins), num_bins - 1) bin_index = min(int(winrate * num_bins), num_bins - 1)
hist[bin_index] += 1.0 hist[bin_index] += 1
total = sum(hist)
if total > 0:
hist = [x / total for x in hist]
return hist return hist
def calPokerEmdTurn(input_text: str) -> Union[float, None]: def calPokerEmdTurn(input_text) -> Union[float, None]:
"""
计算TURN阶段两手牌的EMD距离
"""
hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text) hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text)
# 检查解析是否成功
if hand1_cards is None or hand2_cards is None or board_cards is None: if hand1_cards is None or hand2_cards is None or board_cards is None:
return None return None
# 计算胜率分布 # 计算胜率分布
dist1 = calHandEquityHist(hand1_cards, board_cards) hist1 = calHandEquityHist(hand1_cards, board_cards)
dist2 = calHandEquityHist(hand2_cards, board_cards) hist2 = calHandEquityHist(hand2_cards, board_cards)
# 计算EMD距离 # 计算EMD距离
emd_distance = calEmd(dist1, dist2) emd_distance = wasserstein_distance(hist1, hist2)
return emd_distance return emd_distance
def runTask3(): def runTask3(input_str):
"""运行任务3的示例输入""" if input_str is None:
example_input = """ input_str = """
AsKs AsKs
7s6s 7s6s
8h 9d 9c Qh""" 8h9d9cQh
"""
try: try:
distance = calPokerEmdTurn(example_input) distance = calPokerEmdTurn(input_str)
return distance return distance
except Exception as e: except Exception as e:
print(f"错误: {e}") print(f"错误: {e}")

View File

@@ -1,100 +0,0 @@
"""测试任务3 - 使用poker_task1进行扑克牌EMD距离计算"""
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# 从poker_task3.task3模块导入函数
from poker_task3.task3 import parseTurnInput, calHandEquityHist, calPokerEmdTurn
from poker_task3.card import Card
def test_parse_turn_input():
"""测试TURN阶段输入解析"""
input_text = """
As Ks
7s 6s
8h 9d 9c Qh
"""
hand1, hand2, board = parseTurnInput(input_text)
assert len(hand1) == 2
assert len(hand2) == 2
assert len(board) == 4
assert str(hand1[0]) == "As"
assert str(hand1[1]) == "Ks"
assert str(hand2[0]) == "7s"
assert str(hand2[1]) == "6s"
print("✓ TURN输入解析测试通过")
def test_hand_equity_Hist():
"""测试手牌胜率分布计算"""
hole_cards = Card.parse_cards("As Ks")
board_cards = Card.parse_cards("8h 9d 9c Qh")
Hist = calHandEquityHist(hole_cards, board_cards)
assert len(Hist) == 30 # 胜率分布的区间数量改进为30个区间
assert all(isinstance(x, (int, float)) for x in Hist)
assert all(x >= 0 for x in Hist) # 所有值都应该是非负的
assert abs(sum(Hist) - 1.0) < 1e-10 # 分布总和应该为1标准化后
print("✓ 手牌胜率分布测试通过")
def test_poker_emd_calculation():
"""测试主要的EMD计算函数"""
input_text = """As Ks
7s 6s
8h 9d 9c Qh"""
distance = calPokerEmdTurn(input_text)
assert isinstance(distance, (int, float))
assert distance >= 0 # EMD总是非负的
print(f"✓ 扑克牌EMD计算测试通过 (距离: {distance:.3f})")
def test_different_hand_strengths():
"""测试不同强度手牌的EMD"""
# 强牌 vs 弱牌
input_text = """As Ks
2c 3d
8h 9d 9c Qh"""
distance = calPokerEmdTurn(input_text)
assert isinstance(distance, (int, float))
assert distance >= 0 # 应该有非负距离
print(f"✓ 不同手牌强度测试通过 (距离: {distance:.3f})")
def test_error_handling():
"""测试无效输入的错误处理"""
# 错误的行数
result = parseTurnInput("As Ks\n7s 6s")
assert result == (None, None, None), "应该返回(None, None, None)"
# 手牌中牌数错误
result = parseTurnInput("As Ks Qs\n7s 6s\n8h 9d 9c Qh")
assert result == (None, None, None), "应该返回(None, None, None)"
# 公共牌数错误
result = parseTurnInput("As Ks\n7s 6s\n8h 9d 9c")
assert result == (None, None, None), "应该返回(None, None, None)"
print("✓ 错误处理测试通过")
if __name__ == "__main__":
test_parse_turn_input()
test_hand_equity_Hist()
test_poker_emd_calculation()
test_different_hand_strengths()
test_error_handling()
print("所有任务3测试通过! ✓")