aiqs
This commit is contained in:
202
tests/test_mcp.py
Normal file
202
tests/test_mcp.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import json
|
||||
from src.mcp.mem import ConvoStore
|
||||
from src.tools.case_tools import CaseTools
|
||||
|
||||
|
||||
def test_conv_store():
|
||||
store = ConvoStore()
|
||||
|
||||
test_session = "CASE_SESSION_001"
|
||||
|
||||
store.add_conv(
|
||||
session_id=test_session,
|
||||
user_msg="这个案件的原告是谁?",
|
||||
ai_msg="根据案件文档,原告是张三。",
|
||||
metadata={"case_id": "CASE_001"}
|
||||
)
|
||||
|
||||
store.add_conv(
|
||||
session_id=test_session,
|
||||
user_msg="借款金额是多少?",
|
||||
ai_msg="借款本金为50万元,约定年利率10%。",
|
||||
metadata={"case_id": "CASE_001"}
|
||||
)
|
||||
|
||||
store.add_conv(
|
||||
session_id=test_session,
|
||||
user_msg="被告的答辩意见是什么?",
|
||||
ai_msg="被告称该笔款项为合伙投资款,非借款,并称借条系伪造。",
|
||||
metadata={"case_id": "CASE_001"}
|
||||
)
|
||||
|
||||
print(f" 对话存储成功")
|
||||
print(f" 会话ID: {test_session}")
|
||||
print(f" 对话轮数: 3")
|
||||
return True
|
||||
|
||||
|
||||
def test_conv_search():
|
||||
|
||||
store = ConvoStore()
|
||||
test_session = "CASE_SESSION_001"
|
||||
|
||||
queries = [
|
||||
"原告信息",
|
||||
"借了多少钱",
|
||||
"被告怎么说"
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
results = store.search_conv(
|
||||
session_id=test_session,
|
||||
query=query,
|
||||
limit=2
|
||||
)
|
||||
|
||||
print(f"\n 查询: {query}")
|
||||
print(f" 结果数: {len(results)}")
|
||||
|
||||
if results:
|
||||
print(f" 最高相关度: {results[0]['score']:.3f}")
|
||||
print(f" 匹配对话: Q: {results[0]['user'][:30]}...")
|
||||
|
||||
assert len(results) > 0
|
||||
|
||||
print(f"\n 检索成功")
|
||||
return True
|
||||
|
||||
|
||||
def test_conv_all():
|
||||
|
||||
store = ConvoStore()
|
||||
test_session = "CASE_SESSION_001"
|
||||
|
||||
all_convs = store.get_all_conv(
|
||||
session_id=test_session,
|
||||
limit=100
|
||||
)
|
||||
|
||||
print(f" 会话获取成功")
|
||||
print(f" 会话ID: {test_session}")
|
||||
print(f" 对话总数: {len(all_convs)}")
|
||||
|
||||
if all_convs:
|
||||
print(f"\n 最近对话:")
|
||||
for i, conv in enumerate(all_convs[:2], 1):
|
||||
print(f" {i}. Q: {conv['user_msg'][:40]}...")
|
||||
print(f" A: {conv['ai_msg'][:40]}...")
|
||||
|
||||
assert len(all_convs) >= 3
|
||||
return True
|
||||
|
||||
def test_case_tools_metadata():
|
||||
|
||||
tools = CaseTools()
|
||||
result = tools.get_case_metadata("CASE_001")
|
||||
|
||||
print(f" 查询完成")
|
||||
print(f" 案件ID: CASE_001")
|
||||
|
||||
if "error" in result:
|
||||
print(f" 状态: {result['error']}")
|
||||
print(f" 说明: {result.get('description', '')}")
|
||||
else:
|
||||
print(f" 文档数: {result.get('doc_cnt', 0)}")
|
||||
print(f" 片段数: {result.get('chunk_cnt', 0)}")
|
||||
print(f" 状态: {result.get('status', '')}")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
return True
|
||||
|
||||
|
||||
def test_build_rag_prompt():
|
||||
tools = CaseTools()
|
||||
|
||||
mock_context = """<documents>
|
||||
<document index="1" source="complaint.txt">
|
||||
原告: 张三
|
||||
被告: 李四
|
||||
借款金额: 50万元
|
||||
</document>
|
||||
<document index="2" source="defense.txt">
|
||||
被告答辩: 该笔款项为合伙投资款
|
||||
</document>
|
||||
</documents>"""
|
||||
|
||||
result = tools.build_rag_prompt(
|
||||
query="原告和被告分别是谁?",
|
||||
context=mock_context,
|
||||
session_id="CASE_SESSION_001",
|
||||
max_history=2
|
||||
)
|
||||
|
||||
print(f" prompt构建成功")
|
||||
print(f" 消息总数: {result['prompt_stats']['total_messages']}")
|
||||
print(f" 包含上下文: {result['prompt_stats']['has_context']}")
|
||||
print(f" 包含历史: {result['prompt_stats']['has_history']}")
|
||||
print(f" 历史来源: {result['prompt_stats']['history_source']}")
|
||||
print(f" 历史轮数: {result['prompt_stats']['history_length']}")
|
||||
|
||||
assert len(result['messages']) > 0
|
||||
assert result['prompt_stats']['has_context']
|
||||
return True
|
||||
|
||||
def test_cleanup():
|
||||
store = ConvoStore()
|
||||
test_session = "CASE_SESSION_001"
|
||||
|
||||
all_convs = store.get_all_conv(session_id=test_session, limit=1000)
|
||||
print(f" 清理前对话数: {len(all_convs)}")
|
||||
|
||||
count = store.delete_by_session(test_session)
|
||||
|
||||
print(f" 清理完成")
|
||||
print(f" 测试会话: {test_session}")
|
||||
print(f" 删除对话数: {count}")
|
||||
|
||||
remaining = store.get_all_conv(session_id=test_session, limit=1000)
|
||||
print(f" 清理后对话数: {len(remaining)}")
|
||||
|
||||
assert len(remaining) == 0
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
print("\n" + "="*60)
|
||||
print("MCP 测试")
|
||||
print("="*60)
|
||||
|
||||
tests = [
|
||||
("对话存储", test_conv_store),
|
||||
("对话检索", test_conv_search),
|
||||
("获取全部对话", test_conv_all),
|
||||
("获取案件元数据", test_case_tools_metadata),
|
||||
("构建 RAG Prompt", test_build_rag_prompt),
|
||||
("清理测试数据", test_cleanup)
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
if test_func():
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
print(f"\n 测试失败: {name}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print(f"结果: {passed} 通过, {failed} 失败")
|
||||
print("="*60)
|
||||
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user