This commit is contained in:
2025-09-23 09:15:05 +08:00
parent 6e4973da7a
commit e8a705c467
7 changed files with 72 additions and 353 deletions

View File

@@ -1,24 +1,7 @@
# Poker Task 3: EMD Distance Calculation on TURN
# Poker Task3
这是扑克Task 3的实现专注于计算TURN阶段两手牌力分布的EMD (Earth Mover's Distance) 距离。
这是Task3的实现计算TURN阶段的EMD (Earth Mover's Distance) 距离。
## 功能特点
- **TURN阶段扑克分析**: 基于已知的公共牌和底牌计算每手牌在所有可能RIVER牌下的牌力分布
- **EMD距离计算**: 使用scipy的wasserstein_distance函数计算两个牌力分布的EMD距离
- **专业扑克评估**: 集成poker_task1的专业扑克牌评估系统支持所有标准扑克手牌类型
## 安装和使用
### 安装依赖
```bash
# 使用uv管理依赖推荐
uv install
# 或使用pip
pip install scipy
```
### 运行程序
@@ -32,57 +15,9 @@ uv run python main.py
### 输入格式
程序接受以下格式的输入:
```
玩家1底牌两张
玩家2底牌两张
公共牌TURN阶段4张牌
```
示例:
```
As Ks
7s 6s
8h 9d 9c Qh
AsKs
7s6s
8h9d9cQh
```
## 运行测试
```bash
# 运行所有测试
uv run python -m pytest tests/ -v
# 或
python -m pytest tests/ -v
```
## 项目结构
```
poker_task3/
├── poker_task3/ # 主要代码包
│ ├── card.py # 扑克牌类定义
│ ├── hand_evaluator.py # 手牌评估器
│ ├── hand_ranking.py # 手牌排名系统
│ ├── emd.py # EMD距离计算
│ ├── task3.py # Task 3主要实现
│ └── __init__.py
├── tests/ # 测试文件
│ └── test_task3.py # Task 3测试
├── main.py # 程序入口
└── README.md # 项目文档
```
## 算法原理
1. **牌力评估**: 对于每手底牌遍历所有可能的RIVER牌剩余47张牌
2. **分布构建**: 统计每种手牌类型High Card到Royal Flush的出现次数
3. **EMD计算**: 使用scipy.stats.wasserstein_distance计算两个分布的EMD距离
## 测试覆盖
- 输入解析测试
- 牌力分布计算测试
- EMD距离计算测试
- 错误处理测试
- 不同牌力强度对比测试

12
main.py
View File

@@ -4,10 +4,16 @@ from poker_task3.task3 import runTask3
def main():
"""task3的主入口点"""
print("=== task3扑克牌TURN阶段EMD距离计算 ===")
distance = runTask3()
print("请输入三行数据:")
print("第一行hand1如 AsKs")
print("第二行hand2如 7s6s")
print("第三行board cards如 8h9d9cQh")
lines = []
for i in range(3):
lines.append(input())
input_text = "\n".join(lines)
distance = runTask3(input_text)
if distance is not None:
print(f"EMD距离: {distance}")

View File

@@ -1,7 +1,3 @@
"""
Card module for poker game
"""
from enum import Enum
from typing import List, Tuple, Optional
@@ -87,10 +83,7 @@ class Card:
return self.rank.numeric_value < other.rank.numeric_value
@classmethod
def create_card(cls, card_str: str) -> 'Card':
"""
从字符串创建Card对象例如 "As", "Kh", "2c"
"""
def createCard(cls, card_str) -> 'Card':
if len(card_str) != 2:
raise ValueError(f"Invalid card string: {card_str}")
@@ -120,7 +113,7 @@ class Card:
return cls(rank, suit)
@classmethod
def parse_cards(cls, cards_str: str) -> List['Card']:
def parseCards(cls, cards_str) -> List['Card']:
"""
从字符串解析多张牌,例如 "AsKs AhAdAc6s7s"
"""
@@ -142,8 +135,4 @@ class Card:
i += 2
else:
raise ValueError(f"Invalid card format at position {i}")
result = []
for card_str in card_strings:
card_tmp = cls.create_card(card_str)
result.append(card_tmp)
return result
return [cls.createCard(card_str) for card_str in card_strings]

View File

@@ -1,7 +1,3 @@
"""
Hand evaluator module for poker game
"""
from typing import List, Tuple, Dict
from collections import Counter
from itertools import combinations
@@ -10,14 +6,10 @@ from .hand_ranking import HandRanking, HandType
class HandEvaluator:
"""
手牌评估器类,用于评估扑克手牌
"""
@staticmethod
def evaluate_hand(cards) -> HandRanking:
def evaluateHand(cards) -> HandRanking:
"""
从7张牌中评估出最好的5张牌组合
从7张牌中出最好的5张牌组合
"""
if len(cards) != 7:
raise ValueError(f"Expected 7 cards, got {len(cards)}")
@@ -25,23 +17,19 @@ class HandEvaluator:
best_ranking = None
best_cards = None
# 尝试所有可能的5张牌组合
# 所有可能的5张牌组合
for five_cards in combinations(cards, 5):
ranking = HandEvaluator._evaluate_five_cards(list(five_cards))
ranking = HandEvaluator.evaluate5Cards(list(five_cards))
if best_ranking is None or ranking > best_ranking:
best_ranking = ranking
best_cards = list(five_cards)
# 更新最佳ranking的cards
best_ranking.cards = best_cards
return best_ranking
@staticmethod
def _evaluate_five_cards(cards: List[Card]) -> HandRanking:
"""
评估5张牌的手牌类型
"""
def evaluate5Cards(cards) -> HandRanking:
if len(cards) != 5:
raise ValueError(f"Expected 5 cards, got {len(cards)}")
@@ -54,11 +42,11 @@ class HandEvaluator:
rank_counts = Counter(ranks)
count_values = sorted(rank_counts.values(), reverse=True)
# 检查是否是同花
# 同花
is_flush = len(set(suits)) == 1
# 检查是否是顺子
is_straight, straight_high = HandEvaluator._is_straight(ranks)
# 顺子
is_straight, straight_high = HandEvaluator._isStraight(ranks)
# 根据牌型返回相应的HandRanking
if is_straight and is_flush:
@@ -102,14 +90,10 @@ class HandEvaluator:
return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards)
@staticmethod
def _is_straight(ranks: List[Rank]) -> Tuple[bool, Rank]:
"""
检查是否是顺子,返回(是否是顺子, 最高牌)
"""
# 排序点数值
def _isStraight(ranks: List[Rank]) -> Tuple[bool, Rank]:
values = sorted([rank.numeric_value for rank in ranks], reverse=True)
# 检查常规顺子
is_regular_straight = True
for i in range(1, len(values)):
if values[i-1] - values[i] != 1:
@@ -125,8 +109,7 @@ class HandEvaluator:
break
return True, highest_rank
# 检查A-2-3-4-5的特殊顺子轮子
if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2
return True, Rank.FIVE # 在轮子中5是最高牌
return True, Rank.FIVE
return False, None

View File

@@ -1,16 +1,9 @@
"""
Hand ranking module for poker game
"""
from enum import Enum
from typing import List, Tuple
from .card import Card, Rank
class HandType(Enum):
"""
手牌类型枚举,按强度排序
"""
HIGH_CARD = (1, "High Card")
ONE_PAIR = (2, "Pair")
TWO_PAIR = (3, "Two Pair")
@@ -29,44 +22,16 @@ class HandType(Enum):
obj.type_name = name
return obj
def __str__(self):
return self.type_name
def __lt__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength < other.strength
def __le__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength <= other.strength
def __gt__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength > other.strength
def __ge__(self, other):
if not isinstance(other, HandType):
return NotImplemented
return self.strength >= other.strength
class HandRanking:
"""
手牌排名类,包含手牌类型和关键牌
"""
def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]):
self.hand_type = hand_type
self.key_ranks = key_ranks # 用于比较的关键点数
self.cards = cards # 组成这个排名的5张牌
self.cards = cards # 组成这个ranking的5张牌
def __str__(self):
"""
返回手牌排名的字符串表示
"""
if self.hand_type == HandType.FOUR_OF_A_KIND:
return f"Quad({self.key_ranks[0].symbol})"
elif self.hand_type == HandType.FULL_HOUSE:
@@ -88,38 +53,5 @@ class HandRanking:
return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})"
elif self.hand_type == HandType.ONE_PAIR:
return f"Pair({self.key_ranks[0].symbol})"
else: # HIGH_CARD
else:
return f"High Card({self.key_ranks[0].symbol})"
def __repr__(self):
return f"HandRanking({self.hand_type}, {[r.symbol for r in self.key_ranks]})"
def __eq__(self, other):
if not isinstance(other, HandRanking):
return False
return (self.hand_type == other.hand_type and
self.key_ranks == other.key_ranks)
def __lt__(self, other):
if not isinstance(other, HandRanking):
return NotImplemented
# 首先比较手牌类型
if self.hand_type != other.hand_type:
return self.hand_type < other.hand_type
# 如果手牌类型相同,比较关键点数
for self_rank, other_rank in zip(self.key_ranks, other.key_ranks):
if self_rank != other_rank:
return self_rank < other_rank
return False # 完全相等
def __le__(self, other):
return self == other or self < other
def __gt__(self, other):
return not self <= other
def __ge__(self, other):
return not self < other

View File

@@ -7,132 +7,106 @@ from .card import Card, Rank, Suit
from .hand_evaluator import HandEvaluator
from .hand_ranking import HandRanking, HandType
def calEmd(Hist1: List[Union[int, float]],
Hist2: List[Union[int, float]]) -> float:
return wasserstein_distance(Hist1, Hist2)
def parseTurnInput(input_text: str) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
"""
解析TURN阶段输入
格式: "As Ks\n7s 6s\n8h 9d 9c Qh"
返回: (手牌1, 手牌2, 公共牌)
"""
def parseTurnInput(input_text) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]:
lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()]
if len(lines) != 3:
print("输入必须包含恰好3行手牌1手牌2公共牌")
print("输入必须包含恰好3行hand1hand2board cards")
return None, None, None
hand1_cards = Card.parse_cards(lines[0])
hand2_cards = Card.parse_cards(lines[1])
board_cards = Card.parse_cards(lines[2])
hand1_cards = Card.parseCards(lines[0])
hand2_cards = Card.parseCards(lines[1])
board_cards = Card.parseCards(lines[2])
if len(hand1_cards) != 2:
print("手牌1必须包含恰好2张牌")
print(f"{hand1_cards} should contain exactly 2 cards")
return None, None, None
if len(hand2_cards) != 2:
print("手牌2必须包含恰好2张牌")
print(f"{hand2_cards} should contain exactly 2 cards")
return None, None, None
if len(board_cards) != 4:
print("公共牌必须包含恰好4张牌TURN阶段")
print(f"{board_cards} should contain exactly 4 cards")
return None, None, None
return hand1_cards, hand2_cards, board_cards
def calHandEquityHist(hole_cards, board_cards) -> List[float]:
"""
计算TURN阶段手牌的胜率分布
基于每个可能河牌下的胜率创建分布,而不是绝对强度排名
"""
def calHandEquityHist(hole_cards, board_cards) -> List[int]:
all_cards = hole_cards + board_cards # 总共6张牌
# 创建剩余牌组
used_cards = set(all_cards)
remaining_deck = []
remaining_cards = []
for rank in Rank:
for suit in Suit:
card = Card(rank, suit)
if card not in used_cards:
remaining_deck.append(card)
remaining_cards.append(card)
# 为每个可能的河牌计算该手牌的胜率
winrates = []
for river_card in remaining_deck:
for river_card in remaining_cards:
# 当前手牌在这个河牌下的最终七张牌
current_seven_cards = all_cards + [river_card]
current_hand_ranking = HandEvaluator.evaluate_hand(current_seven_cards)
player1_7cards = all_cards + [river_card]
player1_hand_ranking = HandEvaluator.evaluateHand(player1_7cards)
# 计算对抗所有可能对手牌的胜率
wins = 0
total_opponents = 0
total_wins = 0
# 生成所有可能的对手牌组合
used_cards_with_river = used_cards | {river_card}
available_cards = [card for card in remaining_deck if card != river_card]
# 生成所有可能的对手牌组合
available_cards = [card for card in remaining_cards if card != river_card]
for i in range(len(available_cards)):
for j in range(i + 1, len(available_cards)):
opponent_cards = [available_cards[i], available_cards[j]]
opponent_seven_cards = opponent_cards + board_cards + [river_card]
opponent_hand_ranking = HandEvaluator.evaluate_hand(opponent_seven_cards)
player2_cards = [available_cards[i], available_cards[j]]
player2_7cards = player2_cards + board_cards + [river_card]
player2_hand_ranking = HandEvaluator.evaluateHand(player2_7cards)
total_opponents += 1
if current_hand_ranking > opponent_hand_ranking:
total_wins += 1
if player1_hand_ranking > player2_hand_ranking:
wins += 1
elif current_hand_ranking == opponent_hand_ranking:
elif player1_hand_ranking == player2_hand_ranking:
wins += 0.5 # 平局
winrate = wins / total_opponents if total_opponents > 0 else 0.0
winrate = wins / total_wins if total_wins > 0 else 0.0
winrates.append(winrate)
num_bins = 30
hist = [0.0] * num_bins
hist = [0] * num_bins
for winrate in winrates:
bin_index = min(int(winrate * num_bins), num_bins - 1)
hist[bin_index] += 1.0
total = sum(hist)
if total > 0:
hist = [x / total for x in hist]
hist[bin_index] += 1
return hist
def calPokerEmdTurn(input_text: str) -> Union[float, None]:
"""
计算TURN阶段两手牌的EMD距离
"""
def calPokerEmdTurn(input_text) -> Union[float, None]:
hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text)
# 检查解析是否成功
if hand1_cards is None or hand2_cards is None or board_cards is None:
return None
# 计算胜率分布
dist1 = calHandEquityHist(hand1_cards, board_cards)
dist2 = calHandEquityHist(hand2_cards, board_cards)
hist1 = calHandEquityHist(hand1_cards, board_cards)
hist2 = calHandEquityHist(hand2_cards, board_cards)
# 计算EMD距离
emd_distance = calEmd(dist1, dist2)
emd_distance = wasserstein_distance(hist1, hist2)
return emd_distance
def runTask3():
"""运行任务3的示例输入"""
example_input = """
As Ks
7s 6s
8h 9d 9c Qh"""
def runTask3(input_str):
if input_str is None:
input_str = """
AsKs
7s6s
8h9d9cQh
"""
try:
distance = calPokerEmdTurn(example_input)
distance = calPokerEmdTurn(input_str)
return distance
except Exception as e:
print(f"错误: {e}")

View File

@@ -1,100 +0,0 @@
"""测试任务3 - 使用poker_task1进行扑克牌EMD距离计算"""
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# 从poker_task3.task3模块导入函数
from poker_task3.task3 import parseTurnInput, calHandEquityHist, calPokerEmdTurn
from poker_task3.card import Card
def test_parse_turn_input():
"""测试TURN阶段输入解析"""
input_text = """
As Ks
7s 6s
8h 9d 9c Qh
"""
hand1, hand2, board = parseTurnInput(input_text)
assert len(hand1) == 2
assert len(hand2) == 2
assert len(board) == 4
assert str(hand1[0]) == "As"
assert str(hand1[1]) == "Ks"
assert str(hand2[0]) == "7s"
assert str(hand2[1]) == "6s"
print("✓ TURN输入解析测试通过")
def test_hand_equity_Hist():
"""测试手牌胜率分布计算"""
hole_cards = Card.parse_cards("As Ks")
board_cards = Card.parse_cards("8h 9d 9c Qh")
Hist = calHandEquityHist(hole_cards, board_cards)
assert len(Hist) == 30 # 胜率分布的区间数量改进为30个区间
assert all(isinstance(x, (int, float)) for x in Hist)
assert all(x >= 0 for x in Hist) # 所有值都应该是非负的
assert abs(sum(Hist) - 1.0) < 1e-10 # 分布总和应该为1标准化后
print("✓ 手牌胜率分布测试通过")
def test_poker_emd_calculation():
"""测试主要的EMD计算函数"""
input_text = """As Ks
7s 6s
8h 9d 9c Qh"""
distance = calPokerEmdTurn(input_text)
assert isinstance(distance, (int, float))
assert distance >= 0 # EMD总是非负的
print(f"✓ 扑克牌EMD计算测试通过 (距离: {distance:.3f})")
def test_different_hand_strengths():
"""测试不同强度手牌的EMD"""
# 强牌 vs 弱牌
input_text = """As Ks
2c 3d
8h 9d 9c Qh"""
distance = calPokerEmdTurn(input_text)
assert isinstance(distance, (int, float))
assert distance >= 0 # 应该有非负距离
print(f"✓ 不同手牌强度测试通过 (距离: {distance:.3f})")
def test_error_handling():
"""测试无效输入的错误处理"""
# 错误的行数
result = parseTurnInput("As Ks\n7s 6s")
assert result == (None, None, None), "应该返回(None, None, None)"
# 手牌中牌数错误
result = parseTurnInput("As Ks Qs\n7s 6s\n8h 9d 9c Qh")
assert result == (None, None, None), "应该返回(None, None, None)"
# 公共牌数错误
result = parseTurnInput("As Ks\n7s 6s\n8h 9d 9c")
assert result == (None, None, None), "应该返回(None, None, None)"
print("✓ 错误处理测试通过")
if __name__ == "__main__":
test_parse_turn_input()
test_hand_equity_Hist()
test_poker_emd_calculation()
test_different_hand_strengths()
test_error_handling()
print("所有任务3测试通过! ✓")