From e8a705c46712a2f4293836e64511bdad856f28f7 Mon Sep 17 00:00:00 2001 From: jianghaiying Date: Tue, 23 Sep 2025 09:15:05 +0800 Subject: [PATCH] task3 --- README.md | 77 ++---------------------- main.py | 12 +++- poker_task3/card.py | 17 +----- poker_task3/hand_evaluator.py | 41 ++++--------- poker_task3/hand_ranking.py | 72 +---------------------- poker_task3/task3.py | 106 +++++++++++++--------------------- tests/test_task3.py | 100 -------------------------------- 7 files changed, 72 insertions(+), 353 deletions(-) delete mode 100644 tests/test_task3.py diff --git a/README.md b/README.md index 35cd721..55d7d81 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,7 @@ -# Poker Task 3: EMD Distance Calculation on TURN +# Poker Task3 -这是扑克Task 3的实现,专注于计算TURN阶段两手牌力分布的EMD (Earth Mover's Distance) 距离。 +这是Task3的实现,计算TURN阶段的EMD (Earth Mover's Distance) 距离。 -## 功能特点 - -- **TURN阶段扑克分析**: 基于已知的公共牌和底牌,计算每手牌在所有可能RIVER牌下的牌力分布 -- **EMD距离计算**: 使用scipy的wasserstein_distance函数计算两个牌力分布的EMD距离 -- **专业扑克评估**: 集成poker_task1的专业扑克牌评估系统,支持所有标准扑克手牌类型 - -## 安装和使用 - -### 安装依赖 - -```bash -# 使用uv管理依赖(推荐) -uv install - -# 或使用pip -pip install scipy -``` ### 运行程序 @@ -32,57 +15,9 @@ uv run python main.py ### 输入格式 -程序接受以下格式的输入: -``` -玩家1底牌(两张) -玩家2底牌(两张) -公共牌(TURN阶段,4张牌) -``` - 示例: ``` -As Ks -7s 6s -8h 9d 9c Qh -``` - -## 运行测试 - -```bash -# 运行所有测试 -uv run python -m pytest tests/ -v - -# 或 -python -m pytest tests/ -v -``` - -## 项目结构 - -``` -poker_task3/ -├── poker_task3/ # 主要代码包 -│ ├── card.py # 扑克牌类定义 -│ ├── hand_evaluator.py # 手牌评估器 -│ ├── hand_ranking.py # 手牌排名系统 -│ ├── emd.py # EMD距离计算 -│ ├── task3.py # Task 3主要实现 -│ └── __init__.py -├── tests/ # 测试文件 -│ └── test_task3.py # Task 3测试 -├── main.py # 程序入口 -└── README.md # 项目文档 -``` - -## 算法原理 - -1. **牌力评估**: 对于每手底牌,遍历所有可能的RIVER牌(剩余47张牌) -2. **分布构建**: 统计每种手牌类型(High Card到Royal Flush)的出现次数 -3. **EMD计算**: 使用scipy.stats.wasserstein_distance计算两个分布的EMD距离 - -## 测试覆盖 - -- 输入解析测试 -- 牌力分布计算测试 -- EMD距离计算测试 -- 错误处理测试 -- 不同牌力强度对比测试 \ No newline at end of file +AsKs +7s6s +8h9d9cQh +``` \ No newline at end of file diff --git a/main.py b/main.py index 2798e3e..776c382 100644 --- a/main.py +++ b/main.py @@ -4,10 +4,16 @@ from poker_task3.task3 import runTask3 def main(): - """task3的主入口点""" print("=== task3:扑克牌TURN阶段EMD距离计算 ===") - - distance = runTask3() + print("请输入三行数据:") + print("第一行:hand1(如 AsKs)") + print("第二行:hand2(如 7s6s)") + print("第三行:board cards(如 8h9d9cQh)") + lines = [] + for i in range(3): + lines.append(input()) + input_text = "\n".join(lines) + distance = runTask3(input_text) if distance is not None: print(f"EMD距离: {distance}") diff --git a/poker_task3/card.py b/poker_task3/card.py index 52e3d4c..a419932 100644 --- a/poker_task3/card.py +++ b/poker_task3/card.py @@ -1,7 +1,3 @@ -""" -Card module for poker game -""" - from enum import Enum from typing import List, Tuple, Optional @@ -87,10 +83,7 @@ class Card: return self.rank.numeric_value < other.rank.numeric_value @classmethod - def create_card(cls, card_str: str) -> 'Card': - """ - 从字符串创建Card对象,例如 "As", "Kh", "2c" - """ + def createCard(cls, card_str) -> 'Card': if len(card_str) != 2: raise ValueError(f"Invalid card string: {card_str}") @@ -120,7 +113,7 @@ class Card: return cls(rank, suit) @classmethod - def parse_cards(cls, cards_str: str) -> List['Card']: + def parseCards(cls, cards_str) -> List['Card']: """ 从字符串解析多张牌,例如 "AsKs AhAdAc6s7s" """ @@ -142,8 +135,4 @@ class Card: i += 2 else: raise ValueError(f"Invalid card format at position {i}") - result = [] - for card_str in card_strings: - card_tmp = cls.create_card(card_str) - result.append(card_tmp) - return result \ No newline at end of file + return [cls.createCard(card_str) for card_str in card_strings] \ No newline at end of file diff --git a/poker_task3/hand_evaluator.py b/poker_task3/hand_evaluator.py index 14093cc..455f884 100644 --- a/poker_task3/hand_evaluator.py +++ b/poker_task3/hand_evaluator.py @@ -1,7 +1,3 @@ -""" -Hand evaluator module for poker game -""" - from typing import List, Tuple, Dict from collections import Counter from itertools import combinations @@ -10,14 +6,10 @@ from .hand_ranking import HandRanking, HandType class HandEvaluator: - """ - 手牌评估器类,用于评估扑克手牌 - """ - @staticmethod - def evaluate_hand(cards) -> HandRanking: + def evaluateHand(cards) -> HandRanking: """ - 从7张牌中评估出最好的5张牌组合 + 从7张牌中找出最好的5张牌组合 """ if len(cards) != 7: raise ValueError(f"Expected 7 cards, got {len(cards)}") @@ -25,23 +17,19 @@ class HandEvaluator: best_ranking = None best_cards = None - # 尝试所有可能的5张牌组合 + # 所有可能的5张牌组合 for five_cards in combinations(cards, 5): - ranking = HandEvaluator._evaluate_five_cards(list(five_cards)) + ranking = HandEvaluator.evaluate5Cards(list(five_cards)) if best_ranking is None or ranking > best_ranking: best_ranking = ranking best_cards = list(five_cards) - - # 更新最佳ranking的cards best_ranking.cards = best_cards return best_ranking @staticmethod - def _evaluate_five_cards(cards: List[Card]) -> HandRanking: - """ - 评估5张牌的手牌类型 - """ + def evaluate5Cards(cards) -> HandRanking: + if len(cards) != 5: raise ValueError(f"Expected 5 cards, got {len(cards)}") @@ -54,11 +42,11 @@ class HandEvaluator: rank_counts = Counter(ranks) count_values = sorted(rank_counts.values(), reverse=True) - # 检查是否是同花 + # 同花 is_flush = len(set(suits)) == 1 - # 检查是否是顺子 - is_straight, straight_high = HandEvaluator._is_straight(ranks) + # 顺子 + is_straight, straight_high = HandEvaluator._isStraight(ranks) # 根据牌型返回相应的HandRanking if is_straight and is_flush: @@ -102,14 +90,10 @@ class HandEvaluator: return HandRanking(HandType.HIGH_CARD, ranks, sorted_cards) @staticmethod - def _is_straight(ranks: List[Rank]) -> Tuple[bool, Rank]: - """ - 检查是否是顺子,返回(是否是顺子, 最高牌) - """ - # 排序点数值 + def _isStraight(ranks: List[Rank]) -> Tuple[bool, Rank]: + values = sorted([rank.numeric_value for rank in ranks], reverse=True) - # 检查常规顺子 is_regular_straight = True for i in range(1, len(values)): if values[i-1] - values[i] != 1: @@ -125,8 +109,7 @@ class HandEvaluator: break return True, highest_rank - # 检查A-2-3-4-5的特殊顺子(轮子) if values == [14, 5, 4, 3, 2]: # A, 5, 4, 3, 2 - return True, Rank.FIVE # 在轮子中,5是最高牌 + return True, Rank.FIVE return False, None \ No newline at end of file diff --git a/poker_task3/hand_ranking.py b/poker_task3/hand_ranking.py index 199321b..9a29768 100644 --- a/poker_task3/hand_ranking.py +++ b/poker_task3/hand_ranking.py @@ -1,16 +1,9 @@ -""" -Hand ranking module for poker game -""" - from enum import Enum from typing import List, Tuple from .card import Card, Rank class HandType(Enum): - """ - 手牌类型枚举,按强度排序 - """ HIGH_CARD = (1, "High Card") ONE_PAIR = (2, "Pair") TWO_PAIR = (3, "Two Pair") @@ -29,44 +22,16 @@ class HandType(Enum): obj.type_name = name return obj - def __str__(self): - return self.type_name - - def __lt__(self, other): - if not isinstance(other, HandType): - return NotImplemented - return self.strength < other.strength - - def __le__(self, other): - if not isinstance(other, HandType): - return NotImplemented - return self.strength <= other.strength - - def __gt__(self, other): - if not isinstance(other, HandType): - return NotImplemented - return self.strength > other.strength - - def __ge__(self, other): - if not isinstance(other, HandType): - return NotImplemented - return self.strength >= other.strength class HandRanking: - """ - 手牌排名类,包含手牌类型和关键牌 - """ def __init__(self, hand_type: HandType, key_ranks: List[Rank], cards: List[Card]): self.hand_type = hand_type self.key_ranks = key_ranks # 用于比较的关键点数 - self.cards = cards # 组成这个排名的5张牌 + self.cards = cards # 组成这个ranking的5张牌 def __str__(self): - """ - 返回手牌排名的字符串表示 - """ if self.hand_type == HandType.FOUR_OF_A_KIND: return f"Quad({self.key_ranks[0].symbol})" elif self.hand_type == HandType.FULL_HOUSE: @@ -88,38 +53,5 @@ class HandRanking: return f"Two Pair({self.key_ranks[0].symbol} and {self.key_ranks[1].symbol})" elif self.hand_type == HandType.ONE_PAIR: return f"Pair({self.key_ranks[0].symbol})" - else: # HIGH_CARD + else: return f"High Card({self.key_ranks[0].symbol})" - - def __repr__(self): - return f"HandRanking({self.hand_type}, {[r.symbol for r in self.key_ranks]})" - - def __eq__(self, other): - if not isinstance(other, HandRanking): - return False - return (self.hand_type == other.hand_type and - self.key_ranks == other.key_ranks) - - def __lt__(self, other): - if not isinstance(other, HandRanking): - return NotImplemented - - # 首先比较手牌类型 - if self.hand_type != other.hand_type: - return self.hand_type < other.hand_type - - # 如果手牌类型相同,比较关键点数 - for self_rank, other_rank in zip(self.key_ranks, other.key_ranks): - if self_rank != other_rank: - return self_rank < other_rank - - return False # 完全相等 - - def __le__(self, other): - return self == other or self < other - - def __gt__(self, other): - return not self <= other - - def __ge__(self, other): - return not self < other \ No newline at end of file diff --git a/poker_task3/task3.py b/poker_task3/task3.py index d8cc321..b2a15bf 100644 --- a/poker_task3/task3.py +++ b/poker_task3/task3.py @@ -7,132 +7,106 @@ from .card import Card, Rank, Suit from .hand_evaluator import HandEvaluator from .hand_ranking import HandRanking, HandType - -def calEmd(Hist1: List[Union[int, float]], - Hist2: List[Union[int, float]]) -> float: - return wasserstein_distance(Hist1, Hist2) - - -def parseTurnInput(input_text: str) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]: - """ - 解析TURN阶段输入 - 格式: "As Ks\n7s 6s\n8h 9d 9c Qh" - 返回: (手牌1, 手牌2, 公共牌) - """ +def parseTurnInput(input_text) -> Tuple[Union[List[Card], None], Union[List[Card], None], Union[List[Card], None]]: lines = [line.strip() for line in input_text.strip().split('\n') if line.strip()] if len(lines) != 3: - print("输入必须包含恰好3行:手牌1,手牌2,公共牌") + print("输入必须包含恰好3行:hand1,hand2,board cards") return None, None, None - hand1_cards = Card.parse_cards(lines[0]) - hand2_cards = Card.parse_cards(lines[1]) - board_cards = Card.parse_cards(lines[2]) + hand1_cards = Card.parseCards(lines[0]) + hand2_cards = Card.parseCards(lines[1]) + board_cards = Card.parseCards(lines[2]) if len(hand1_cards) != 2: - print("手牌1必须包含恰好2张牌") + print(f"{hand1_cards} should contain exactly 2 cards") return None, None, None if len(hand2_cards) != 2: - print("手牌2必须包含恰好2张牌") + print(f"{hand2_cards} should contain exactly 2 cards") return None, None, None if len(board_cards) != 4: - print("公共牌必须包含恰好4张牌(TURN阶段)") + print(f"{board_cards} should contain exactly 4 cards") return None, None, None - return hand1_cards, hand2_cards, board_cards -def calHandEquityHist(hole_cards, board_cards) -> List[float]: - """ - 计算TURN阶段手牌的胜率分布 - 基于每个可能河牌下的胜率创建分布,而不是绝对强度排名 - """ +def calHandEquityHist(hole_cards, board_cards) -> List[int]: all_cards = hole_cards + board_cards # 总共6张牌 - # 创建剩余牌组 used_cards = set(all_cards) - remaining_deck = [] + remaining_cards = [] for rank in Rank: for suit in Suit: card = Card(rank, suit) if card not in used_cards: - remaining_deck.append(card) + remaining_cards.append(card) # 为每个可能的河牌计算该手牌的胜率 winrates = [] - - for river_card in remaining_deck: + + for river_card in remaining_cards: # 当前手牌在这个河牌下的最终七张牌 - current_seven_cards = all_cards + [river_card] - current_hand_ranking = HandEvaluator.evaluate_hand(current_seven_cards) + player1_7cards = all_cards + [river_card] + player1_hand_ranking = HandEvaluator.evaluateHand(player1_7cards) # 计算对抗所有可能对手牌的胜率 wins = 0 - total_opponents = 0 + total_wins = 0 - # 生成所有可能的对手双牌组合 - used_cards_with_river = used_cards | {river_card} - available_cards = [card for card in remaining_deck if card != river_card] + # 生成所有可能的对手牌组合 + available_cards = [card for card in remaining_cards if card != river_card] for i in range(len(available_cards)): for j in range(i + 1, len(available_cards)): - opponent_cards = [available_cards[i], available_cards[j]] - opponent_seven_cards = opponent_cards + board_cards + [river_card] - opponent_hand_ranking = HandEvaluator.evaluate_hand(opponent_seven_cards) + player2_cards = [available_cards[i], available_cards[j]] + player2_7cards = player2_cards + board_cards + [river_card] + player2_hand_ranking = HandEvaluator.evaluateHand(player2_7cards) - total_opponents += 1 - if current_hand_ranking > opponent_hand_ranking: + total_wins += 1 + if player1_hand_ranking > player2_hand_ranking: wins += 1 - elif current_hand_ranking == opponent_hand_ranking: + elif player1_hand_ranking == player2_hand_ranking: wins += 0.5 # 平局 - winrate = wins / total_opponents if total_opponents > 0 else 0.0 + winrate = wins / total_wins if total_wins > 0 else 0.0 winrates.append(winrate) num_bins = 30 - hist = [0.0] * num_bins + hist = [0] * num_bins for winrate in winrates: bin_index = min(int(winrate * num_bins), num_bins - 1) - hist[bin_index] += 1.0 + hist[bin_index] += 1 - total = sum(hist) - if total > 0: - hist = [x / total for x in hist] - return hist -def calPokerEmdTurn(input_text: str) -> Union[float, None]: - """ - 计算TURN阶段两手牌的EMD距离 - """ +def calPokerEmdTurn(input_text) -> Union[float, None]: hand1_cards, hand2_cards, board_cards = parseTurnInput(input_text) - - # 检查解析是否成功 + if hand1_cards is None or hand2_cards is None or board_cards is None: return None # 计算胜率分布 - dist1 = calHandEquityHist(hand1_cards, board_cards) - dist2 = calHandEquityHist(hand2_cards, board_cards) + hist1 = calHandEquityHist(hand1_cards, board_cards) + hist2 = calHandEquityHist(hand2_cards, board_cards) # 计算EMD距离 - emd_distance = calEmd(dist1, dist2) + emd_distance = wasserstein_distance(hist1, hist2) return emd_distance -def runTask3(): - """运行任务3的示例输入""" - example_input = """ - As Ks - 7s 6s - 8h 9d 9c Qh""" - +def runTask3(input_str): + if input_str is None: + input_str = """ + AsKs + 7s6s + 8h9d9cQh + """ try: - distance = calPokerEmdTurn(example_input) + distance = calPokerEmdTurn(input_str) return distance except Exception as e: print(f"错误: {e}") diff --git a/tests/test_task3.py b/tests/test_task3.py deleted file mode 100644 index 156391e..0000000 --- a/tests/test_task3.py +++ /dev/null @@ -1,100 +0,0 @@ -"""测试任务3 - 使用poker_task1进行扑克牌EMD距离计算""" -import sys -import os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -# 从poker_task3.task3模块导入函数 -from poker_task3.task3 import parseTurnInput, calHandEquityHist, calPokerEmdTurn -from poker_task3.card import Card - - -def test_parse_turn_input(): - """测试TURN阶段输入解析""" - input_text = """ - As Ks - 7s 6s - 8h 9d 9c Qh - """ - - hand1, hand2, board = parseTurnInput(input_text) - - assert len(hand1) == 2 - assert len(hand2) == 2 - assert len(board) == 4 - - assert str(hand1[0]) == "As" - assert str(hand1[1]) == "Ks" - assert str(hand2[0]) == "7s" - assert str(hand2[1]) == "6s" - - print("✓ TURN输入解析测试通过") - - -def test_hand_equity_Hist(): - """测试手牌胜率分布计算""" - hole_cards = Card.parse_cards("As Ks") - board_cards = Card.parse_cards("8h 9d 9c Qh") - - Hist = calHandEquityHist(hole_cards, board_cards) - - assert len(Hist) == 30 # 胜率分布的区间数量(改进为30个区间) - assert all(isinstance(x, (int, float)) for x in Hist) - assert all(x >= 0 for x in Hist) # 所有值都应该是非负的 - assert abs(sum(Hist) - 1.0) < 1e-10 # 分布总和应该为1(标准化后) - - print("✓ 手牌胜率分布测试通过") - - -def test_poker_emd_calculation(): - """测试主要的EMD计算函数""" - input_text = """As Ks -7s 6s -8h 9d 9c Qh""" - - distance = calPokerEmdTurn(input_text) - - assert isinstance(distance, (int, float)) - assert distance >= 0 # EMD总是非负的 - - print(f"✓ 扑克牌EMD计算测试通过 (距离: {distance:.3f})") - - -def test_different_hand_strengths(): - """测试不同强度手牌的EMD""" - # 强牌 vs 弱牌 - input_text = """As Ks -2c 3d -8h 9d 9c Qh""" - - distance = calPokerEmdTurn(input_text) - - assert isinstance(distance, (int, float)) - assert distance >= 0 # 应该有非负距离 - - print(f"✓ 不同手牌强度测试通过 (距离: {distance:.3f})") - - -def test_error_handling(): - """测试无效输入的错误处理""" - # 错误的行数 - result = parseTurnInput("As Ks\n7s 6s") - assert result == (None, None, None), "应该返回(None, None, None)" - - # 手牌中牌数错误 - result = parseTurnInput("As Ks Qs\n7s 6s\n8h 9d 9c Qh") - assert result == (None, None, None), "应该返回(None, None, None)" - - # 公共牌数错误 - result = parseTurnInput("As Ks\n7s 6s\n8h 9d 9c") - assert result == (None, None, None), "应该返回(None, None, None)" - - print("✓ 错误处理测试通过") - - -if __name__ == "__main__": - test_parse_turn_input() - test_hand_equity_Hist() - test_poker_emd_calculation() - test_different_hand_strengths() - test_error_handling() - print("所有任务3测试通过! ✓") \ No newline at end of file