This commit is contained in:
2025-11-28 15:06:54 +08:00
commit 0bb8b4e13b
21 changed files with 1744 additions and 0 deletions

2
tests/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""tests package"""

202
tests/test_mcp.py Normal file
View File

@@ -0,0 +1,202 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import json
from src.mcp.mem import ConvoStore
from src.tools.case_tools import CaseTools
def test_conv_store():
    """Write three question/answer turns into the shared test session.

    Every turn goes through ``ConvoStore.add_conv`` with the same ``case_id``
    metadata. There is no assertion here: a failure shows up as an exception
    raised by the store.

    Returns:
        True, so the script runner can count the test as passed.
    """
    convo_store = ConvoStore()
    test_session = "CASE_SESSION_001"
    # (user message, assistant message) pairs, stored in this order.
    turns = (
        ("这个案件的原告是谁?", "根据案件文档,原告是张三。"),
        ("借款金额是多少?", "借款本金为50万元约定年利率10%"),
        ("被告的答辩意见是什么?", "被告称该笔款项为合伙投资款,非借款,并称借条系伪造。"),
    )
    for question, answer in turns:
        convo_store.add_conv(
            session_id=test_session,
            user_msg=question,
            ai_msg=answer,
            # Build a fresh dict per call so no metadata object is shared
            # between stored turns.
            metadata={"case_id": "CASE_001"},
        )

    print(" 对话存储成功")
    print(f" 会话ID: {test_session}")
    print(" 对话轮数: 3")
    return True
def test_conv_search():
    """Search the test session with three queries and require hits for each.

    For every query the number of hits is printed, followed by the score and
    a 30-character preview of the first hit's ``user`` field when there is
    at least one hit.

    Returns:
        True when every query produced at least one result.

    Raises:
        AssertionError: if a query returns no results.
    """
    convo_store = ConvoStore()
    test_session = "CASE_SESSION_001"

    for query in ("原告信息", "借了多少钱", "被告怎么说"):
        hits = convo_store.search_conv(
            session_id=test_session,
            query=query,
            limit=2,
        )
        print(f"\n 查询: {query}")
        print(f" 结果数: {len(hits)}")
        if hits:
            print(f" 最高相关度: {hits[0]['score']:.3f}")
            print(f" 匹配对话: Q: {hits[0]['user'][:30]}...")
        # NOTE(review): the check is taken to apply to every query (inside the
        # loop); the original layout lost its indentation — confirm.
        assert len(hits) > 0

    print("\n 检索成功")
    return True
def test_conv_all():
    """Fetch up to 100 conversations of the test session and check the count.

    Prints the total and, when anything came back, a 40-character preview of
    the question and answer of the first two returned entries.

    Returns:
        True when at least three conversations were returned.

    Raises:
        AssertionError: if fewer than three conversations are returned.
    """
    convo_store = ConvoStore()
    test_session = "CASE_SESSION_001"
    history = convo_store.get_all_conv(session_id=test_session, limit=100)

    print(" 会话获取成功")
    print(f" 会话ID: {test_session}")
    print(f" 对话总数: {len(history)}")

    if history:
        print("\n 最近对话:")
        for position, turn in enumerate(history[:2], 1):
            print(f" {position}. Q: {turn['user_msg'][:40]}...")
            print(f" A: {turn['ai_msg'][:40]}...")

    # Three turns are expected, matching what the storage test writes.
    assert len(history) >= 3
    return True
def test_case_tools_metadata():
    """Request metadata for ``CASE_001`` and check that a dict comes back.

    Both outcomes are accepted: a payload containing an ``error`` key is
    reported with its description, any other payload is reported with its
    document count, chunk count and status.

    Returns:
        True when the returned value is a ``dict``.

    Raises:
        AssertionError: if the returned value is not a ``dict``.
    """
    case_tools = CaseTools()
    metadata = case_tools.get_case_metadata("CASE_001")

    print(" 查询完成")
    print(" 案件ID: CASE_001")

    if "error" not in metadata:
        print(f" 文档数: {metadata.get('doc_cnt', 0)}")
        print(f" 片段数: {metadata.get('chunk_cnt', 0)}")
        print(f" 状态: {metadata.get('status', '')}")
    else:
        print(f" 状态: {metadata['error']}")
        print(f" 说明: {metadata.get('description', '')}")

    assert isinstance(metadata, dict)
    return True
def test_build_rag_prompt():
    """Build a RAG prompt from a hand-written context and inspect its stats.

    A two-document ``<documents>`` snippet is passed as context together with
    a question, the test session id and ``max_history=2``. The returned
    ``prompt_stats`` values are printed.

    Returns:
        True when the prompt has at least one message and reports that it
        contains context.

    Raises:
        AssertionError: if ``messages`` is empty or ``has_context`` is falsy.
    """
    case_tools = CaseTools()
    # The content lines stay at column 0 so the literal is passed unchanged.
    mock_context = """<documents>
<document index="1" source="complaint.txt">
原告: 张三
被告: 李四
借款金额: 50万元
</document>
<document index="2" source="defense.txt">
被告答辩: 该笔款项为合伙投资款
</document>
</documents>"""

    prompt = case_tools.build_rag_prompt(
        query="原告和被告分别是谁?",
        context=mock_context,
        session_id="CASE_SESSION_001",
        max_history=2,
    )

    print(" prompt构建成功")
    stats = prompt['prompt_stats']
    print(f" 消息总数: {stats['total_messages']}")
    print(f" 包含上下文: {stats['has_context']}")
    print(f" 包含历史: {stats['has_history']}")
    print(f" 历史来源: {stats['history_source']}")
    print(f" 历史轮数: {stats['history_length']}")

    assert len(prompt['messages']) > 0
    assert prompt['prompt_stats']['has_context']
    return True
def test_cleanup():
    """Delete every conversation of the test session and verify none remain.

    The conversation count is printed before and after the deletion, along
    with the value returned by ``delete_by_session``.

    Returns:
        True when the session is empty after the deletion.

    Raises:
        AssertionError: if conversations are still returned afterwards.
    """
    convo_store = ConvoStore()
    test_session = "CASE_SESSION_001"

    before = convo_store.get_all_conv(session_id=test_session, limit=1000)
    print(f" 清理前对话数: {len(before)}")

    deleted = convo_store.delete_by_session(test_session)
    print(" 清理完成")
    print(f" 测试会话: {test_session}")
    print(f" 删除对话数: {deleted}")

    after = convo_store.get_all_conv(session_id=test_session, limit=1000)
    print(f" 清理后对话数: {len(after)}")
    assert len(after) == 0
    return True
def run_all_tests():
    """Run every MCP test in order and print a pass/fail summary.

    Each test either returns a truthy value or raises (typically an
    ``AssertionError``). An exception is caught here and counted as a
    failure, so the remaining tests still run — including the final cleanup
    of test data — and the summary is always printed.

    Returns:
        bool: True when no test failed, False otherwise.
    """
    # Imported locally so the runner does not depend on a file-level import.
    import traceback

    print("\n" + "="*60)
    print("MCP 测试")
    print("="*60)
    # Order matters: later tests read the data written by the first one and
    # the last entry removes it again.
    tests = [
        ("对话存储", test_conv_store),
        ("对话检索", test_conv_search),
        ("获取全部对话", test_conv_all),
        ("获取案件元数据", test_case_tools_metadata),
        ("构建 RAG Prompt", test_build_rag_prompt),
        ("清理测试数据", test_cleanup),
    ]
    passed = 0
    failed = 0
    for name, test_func in tests:
        try:
            ok = test_func()
        except Exception as exc:  # top-level boundary: report, then continue
            ok = False
            print(f"\n 错误: {type(exc).__name__}: {exc}")
            # A bare ``assert`` carries no message, so show where it failed.
            traceback.print_exc()
        if ok:
            passed += 1
        else:
            failed += 1
            print(f"\n 测试失败: {name}")
    print("\n" + "="*60)
    print(f"结果: {passed} 通过, {failed} 失败")
    print("="*60)
    return failed == 0
if __name__ == "__main__":
    # Script entry point: the process exit status is 0 when the whole suite
    # passed and 1 otherwise.
    success = run_all_tests()
    sys.exit(1 if not success else 0)

182
tests/test_rag.py Normal file
View File

@@ -0,0 +1,182 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.rag.loader import DocLoader
from src.rag.retriever import Retriever
from src.rag.store import QdrantVecStore
from src.rag.embeddings import embedding_client
def test_connect():
    """Open the Qdrant vector store and sanity-check its statistics.

    Prints the collection name, document count and vector size reported by
    ``get_stats``.

    Returns:
        True when a collection name is present and the document count is
        not negative.

    Raises:
        AssertionError: if either check fails.
    """
    vec_store = QdrantVecStore()
    info = vec_store.get_stats()

    print(" Qdrant连接成功")
    print(f" 集合: {info['collection_name']}")
    print(f" 文档数: {info['total_documents']}")
    print(f" 向量维度: {info['vector_size']}")

    assert info['collection_name'] is not None
    assert info['total_documents'] >= 0
    return True
def test_embed():
    """Embed two short texts and check the shape of the result.

    The number of returned vectors is verified before the first vector is
    inspected, so an empty result is reported as a failed assertion rather
    than as an ``IndexError`` raised by ``vectors[0]``.

    Returns:
        True when one vector per input text is returned and the first
        vector has 1536 components.

    Raises:
        AssertionError: if the vector count or the dimension is wrong.
    """
    texts = ["测试文本1", "测试文本2"]
    vectors = embedding_client.embed(texts)

    print(" txt1,txt2 embed成功")
    print(f" 输入文本数: {len(texts)}")
    print(f" 输出向量数: {len(vectors)}")
    # Must hold before vectors[0] is touched below.
    assert len(vectors) == len(texts)

    print(f" 向量维度: {len(vectors[0])}")
    # NOTE(review): 1536 is hard-coded; presumably the dimension of the
    # configured embedding model — confirm if the model changes.
    assert len(vectors[0]) == 1536
    return True
def test_doc_load():
    """Ingest a sample case document and check that chunk ids come back.

    The text is passed to ``DocLoader.loader_text`` (loader configured with
    ``chunk_size=200`` and ``chunk_overlap=50``) under case ``CASE_001``.
    The non-empty check runs before the first id is previewed, so an empty
    result is reported as a failed assertion rather than as an
    ``IndexError`` raised by ``ids[0]``.

    Returns:
        True when at least one chunk id is returned.

    Raises:
        AssertionError: if no chunk ids are returned.
    """
    loader = DocLoader(chunk_size=200, chunk_overlap=50)
    # The content lines stay at column 0 so the literal is passed unchanged.
    test_text = """
案件编号: CASE_001
案件名称: 测试案件 - 借款合同纠纷
原告: 张三
被告: 李四
案情简介:
原告张三于2023年1月向被告李四出借人民币50万元约定年利率10%
借款期限为1年。到期后被告未按约定归还本金及利息。
原告诉讼请求:
1. 判令被告归还借款本金50万元
2. 支付利息5万元
3. 承担本案诉讼费用
被告答辩:
该笔款项为合伙投资款,非借款。原告提供的借条系伪造。
"""
    ids = loader.loader_text(
        text=test_text,
        case_id="CASE_001",
        source="test_complaint.txt",
    )

    print(" doc入库成功")
    print(" 案件ID: CASE_001")
    print(f" 文档片段数: {len(ids)}")
    # Must hold before ids[0] is touched below.
    assert len(ids) > 0
    print(f" 片段ID示例: {ids[0][:8]}...")
    return True
def test_doc_search():
    """Retrieve chunks of ``CASE_001`` for three questions and check the hits.

    For every question the number of hits is printed, followed by the score
    and a 50-character preview of the first hit when there is one.

    Returns:
        True when every question produced at least one hit whose first
        score is positive.

    Raises:
        AssertionError: if a question has no hits or a non-positive top score.
    """
    retriever = Retriever()

    for query in ("原告是谁?", "借款金额是多少?", "被告的答辩意见是什么?"):
        hits = retriever.retrieve(
            query=query,
            case_id="CASE_001",
            top_k=2,
        )
        print(f"\n 查询: {query}")
        print(f" 结果数: {len(hits)}")
        if hits:
            print(f" 最高相似度: {hits[0]['score']:.3f}")
            print(f" 片段预览: {hits[0]['text'][:50]}...")
        # NOTE(review): both checks are taken to apply to every query (inside
        # the loop); the original layout lost its indentation — confirm.
        assert len(hits) > 0
        assert hits[0]['score'] > 0

    print("\n 检索成功")
    return True
def test_text_format():
    """Format retrieved chunks into a context string and check its wrapper.

    Up to three chunks of ``CASE_001`` are retrieved and handed to
    ``Retriever.format_context``; the first 200 characters of the output are
    printed as a preview.

    Returns:
        True when the formatted text is non-empty and contains both the
        opening and the closing ``documents`` tag.

    Raises:
        AssertionError: if a tag is missing or the text is empty.
    """
    retriever = Retriever()
    hits = retriever.retrieve(
        query="案件的原被告信息",
        case_id="CASE_001",
        top_k=3,
    )
    context_text = retriever.format_context(hits)

    print(" 格式化成功")
    print(f" 原始结果数: {len(hits)}")
    print(f" 格式化长度: {len(context_text)} 字符")
    print("\n 格式化预览:")
    print(f" {context_text[:200]}...")

    assert "<documents>" in context_text
    assert "</documents>" in context_text
    assert len(context_text) > 0
    return True
def test_cleanup():
    """Delete the documents of ``CASE_001`` and report what is left.

    ``delete_by_case`` may return ``None``; in that case only a notice is
    printed. Otherwise the count is printed and must not be negative.

    Returns:
        True once the deletion ran and the remaining total was printed.

    Raises:
        AssertionError: if a negative deletion count is returned.
    """
    vec_store = QdrantVecStore()
    deleted = vec_store.delete_by_case("CASE_001")

    print(" 清理完成")
    if deleted is None:
        print(" 删除操作已执行(未返回计数)")
    else:
        print(f" 删除文档数: {deleted}")
        assert deleted >= 0

    info = vec_store.get_stats()
    print(f" 剩余文档数: {info['total_documents']}")
    return True
def run_all_tests():
    """Run every RAG test in order and print a pass/fail summary.

    Each test either returns a truthy value or raises (typically an
    ``AssertionError``). An exception is caught here and counted as a
    failure, so the remaining tests still run — including the final cleanup
    of test data — and the summary is always printed.

    Returns:
        bool: True when no test failed, False otherwise.
    """
    # Imported locally so the runner does not depend on a file-level import.
    import traceback

    print("\n" + "="*60)
    print("RAG测试")
    print("="*60)
    # Order matters: retrieval and formatting read the document ingested
    # earlier, and the last entry removes it again.
    tests = [
        ("Qdrant 连接", test_connect),
        ("嵌入客户端", test_embed),
        ("文档入库", test_doc_load),
        ("文档检索", test_doc_search),
        # Label corrected from "上下文清除" (context clearing): the test
        # formats retrieved context, it does not clear anything.
        ("上下文格式化", test_text_format),
        ("清理数据", test_cleanup),
    ]
    passed = 0
    failed = 0
    for name, test_func in tests:
        try:
            ok = test_func()
        except Exception as exc:  # top-level boundary: report, then continue
            ok = False
            print(f"\n 错误: {type(exc).__name__}: {exc}")
            # A bare ``assert`` carries no message, so show where it failed.
            traceback.print_exc()
        if ok:
            passed += 1
        else:
            print(f"\n 测试失败: {name}")
            failed += 1
    print("\n" + "="*60)
    print(f"结果: {passed} 通过, {failed} 失败")
    print("="*60)
    return failed == 0
if __name__ == "__main__":
    # Script entry point: the process exit status is 0 when the whole suite
    # passed and 1 otherwise.
    success = run_all_tests()
    sys.exit(1 if not success else 0)