Files
aiqs/tests/test_mcp.py
2025-11-28 15:06:54 +08:00

203 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import json
from src.mcp.mem import ConvoStore
from src.tools.case_tools import CaseTools
def test_conv_store():
store = ConvoStore()
test_session = "CASE_SESSION_001"
store.add_conv(
session_id=test_session,
user_msg="这个案件的原告是谁?",
ai_msg="根据案件文档,原告是张三。",
metadata={"case_id": "CASE_001"}
)
store.add_conv(
session_id=test_session,
user_msg="借款金额是多少?",
ai_msg="借款本金为50万元约定年利率10%",
metadata={"case_id": "CASE_001"}
)
store.add_conv(
session_id=test_session,
user_msg="被告的答辩意见是什么?",
ai_msg="被告称该笔款项为合伙投资款,非借款,并称借条系伪造。",
metadata={"case_id": "CASE_001"}
)
print(f" 对话存储成功")
print(f" 会话ID: {test_session}")
print(f" 对话轮数: 3")
return True
def test_conv_search():
store = ConvoStore()
test_session = "CASE_SESSION_001"
queries = [
"原告信息",
"借了多少钱",
"被告怎么说"
]
for query in queries:
results = store.search_conv(
session_id=test_session,
query=query,
limit=2
)
print(f"\n 查询: {query}")
print(f" 结果数: {len(results)}")
if results:
print(f" 最高相关度: {results[0]['score']:.3f}")
print(f" 匹配对话: Q: {results[0]['user'][:30]}...")
assert len(results) > 0
print(f"\n 检索成功")
return True
def test_conv_all():
store = ConvoStore()
test_session = "CASE_SESSION_001"
all_convs = store.get_all_conv(
session_id=test_session,
limit=100
)
print(f" 会话获取成功")
print(f" 会话ID: {test_session}")
print(f" 对话总数: {len(all_convs)}")
if all_convs:
print(f"\n 最近对话:")
for i, conv in enumerate(all_convs[:2], 1):
print(f" {i}. Q: {conv['user_msg'][:40]}...")
print(f" A: {conv['ai_msg'][:40]}...")
assert len(all_convs) >= 3
return True
def test_case_tools_metadata():
tools = CaseTools()
result = tools.get_case_metadata("CASE_001")
print(f" 查询完成")
print(f" 案件ID: CASE_001")
if "error" in result:
print(f" 状态: {result['error']}")
print(f" 说明: {result.get('description', '')}")
else:
print(f" 文档数: {result.get('doc_cnt', 0)}")
print(f" 片段数: {result.get('chunk_cnt', 0)}")
print(f" 状态: {result.get('status', '')}")
assert isinstance(result, dict)
return True
def test_build_rag_prompt():
tools = CaseTools()
mock_context = """<documents>
<document index="1" source="complaint.txt">
原告: 张三
被告: 李四
借款金额: 50万元
</document>
<document index="2" source="defense.txt">
被告答辩: 该笔款项为合伙投资款
</document>
</documents>"""
result = tools.build_rag_prompt(
query="原告和被告分别是谁?",
context=mock_context,
session_id="CASE_SESSION_001",
max_history=2
)
print(f" prompt构建成功")
print(f" 消息总数: {result['prompt_stats']['total_messages']}")
print(f" 包含上下文: {result['prompt_stats']['has_context']}")
print(f" 包含历史: {result['prompt_stats']['has_history']}")
print(f" 历史来源: {result['prompt_stats']['history_source']}")
print(f" 历史轮数: {result['prompt_stats']['history_length']}")
assert len(result['messages']) > 0
assert result['prompt_stats']['has_context']
return True
def test_cleanup():
store = ConvoStore()
test_session = "CASE_SESSION_001"
all_convs = store.get_all_conv(session_id=test_session, limit=1000)
print(f" 清理前对话数: {len(all_convs)}")
count = store.delete_by_session(test_session)
print(f" 清理完成")
print(f" 测试会话: {test_session}")
print(f" 删除对话数: {count}")
remaining = store.get_all_conv(session_id=test_session, limit=1000)
print(f" 清理后对话数: {len(remaining)}")
assert len(remaining) == 0
return True
def run_all_tests():
print("\n" + "="*60)
print("MCP 测试")
print("="*60)
tests = [
("对话存储", test_conv_store),
("对话检索", test_conv_search),
("获取全部对话", test_conv_all),
("获取案件元数据", test_case_tools_metadata),
("构建 RAG Prompt", test_build_rag_prompt),
("清理测试数据", test_cleanup)
]
passed = 0
failed = 0
for name, test_func in tests:
if test_func():
passed += 1
else:
failed += 1
print(f"\n 测试失败: {name}")
print("\n" + "="*60)
print(f"结果: {passed} 通过, {failed} 失败")
print("="*60)
return failed == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)