# RAG pipeline smoke tests (source listing: Python, 183 lines, 4.5 KiB).
import sys
|
||
from pathlib import Path
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
||
from src.rag.loader import DocLoader
|
||
from src.rag.retriever import Retriever
|
||
from src.rag.store import QdrantVecStore
|
||
from src.rag.embeddings import embedding_client
|
||
|
||
|
||
def test_connect():
    """Check that the Qdrant store answers and reports usable statistics.

    Builds a ``QdrantVecStore``, reads ``get_stats()``, prints the collection
    name, document count and vector size, then asserts that the collection
    name is not ``None`` and the document count is non-negative.

    Returns:
        True when both assertions hold; a failed assertion raises instead.
    """
    info = QdrantVecStore().get_stats()

    print(" Qdrant连接成功")
    # Keys are read one at a time, in the same order the lines are printed.
    for label, key in (
        (" 集合", 'collection_name'),
        (" 文档数", 'total_documents'),
        (" 向量维度", 'vector_size'),
    ):
        print(f"{label}: {info[key]}")

    assert info['collection_name'] is not None
    assert info['total_documents'] >= 0
    return True
|
||
|
||
|
||
def test_embed():
    """Embed two sample strings and validate the shape of the result.

    Calls ``embedding_client.embed`` with two texts, asserts that one vector
    comes back per input and that the first vector has the expected
    dimension, then prints a short report.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: If the vector count or the vector dimension is wrong.
    """
    # Dimension this suite expects from the embedding client — presumably
    # tied to the configured embedding model; TODO confirm.
    expected_dim = 1536

    texts = ["测试文本1", "测试文本2"]
    vectors = embedding_client.embed(texts)

    # Validate before indexing: an empty result would otherwise surface as an
    # opaque IndexError on vectors[0] instead of a readable assertion failure.
    assert len(vectors) == len(texts), (
        f"expected {len(texts)} vectors, got {len(vectors)}"
    )
    actual_dim = len(vectors[0])
    assert actual_dim == expected_dim, (
        f"expected vector dimension {expected_dim}, got {actual_dim}"
    )

    # Report success only after the checks above have passed.
    print(" txt1,txt2 embed成功")
    print(f" 输入文本数: {len(texts)}")
    print(f" 输出向量数: {len(vectors)}")
    print(f" 向量维度: {actual_dim}")
    return True
|
||
|
||
|
||
def test_doc_load():
    """Ingest a sample complaint for case ``CASE_001`` through ``DocLoader``.

    Creates a loader with ``chunk_size=200`` and ``chunk_overlap=50``, passes
    a short loan-dispute text to ``loader_text`` and checks that at least one
    chunk ID is returned before printing a report.

    Returns:
        True when at least one chunk ID was returned.

    Raises:
        AssertionError: If the loader returned no chunk IDs.
    """
    case_id = "CASE_001"
    loader = DocLoader(chunk_size=200, chunk_overlap=50)

    # Content lines are kept flush-left so that no source indentation leaks
    # into the text handed to the loader.
    test_text = """
案件编号: CASE_001
案件名称: 测试案件 - 借款合同纠纷

原告: 张三
被告: 李四

案情简介:
原告张三于2023年1月向被告李四出借人民币50万元,约定年利率10%,
借款期限为1年。到期后被告未按约定归还本金及利息。

原告诉讼请求:
1. 判令被告归还借款本金50万元
2. 支付利息5万元
3. 承担本案诉讼费用

被告答辩:
该笔款项为合伙投资款,非借款。原告提供的借条系伪造。
"""

    ids = loader.loader_text(
        text=test_text,
        case_id=case_id,
        source="test_complaint.txt",
    )

    # Check before touching ids[0]: an empty result would otherwise raise
    # IndexError in the report below instead of a clear assertion failure.
    assert len(ids) > 0, "loader returned no chunk IDs"

    print(" doc入库成功")
    print(f" 案件ID: {case_id}")
    print(f" 文档片段数: {len(ids)}")
    print(f" 片段ID示例: {ids[0][:8]}...")
    return True
|
||
|
||
|
||
def test_doc_search():
    """Run three sample questions against case ``CASE_001`` via ``Retriever``.

    For each question the top two results are requested, a short report is
    printed, and the test asserts that at least one result came back and that
    the first result's ``score`` is positive.

    Returns:
        True when every question passes its assertions.
    """
    searcher = Retriever()

    for question in (
        "原告是谁?",
        "借款金额是多少?",
        "被告的答辩意见是什么?",
    ):
        hits = searcher.retrieve(query=question, case_id="CASE_001", top_k=2)

        print(f"\n 查询: {question}")
        print(f" 结果数: {len(hits)}")

        # The preview is guarded so an empty result reaches the assertion
        # below rather than failing while printing.
        if hits:
            print(f" 最高相似度: {hits[0]['score']:.3f}")
            print(f" 片段预览: {hits[0]['text'][:50]}...")

        assert len(hits) > 0
        assert hits[0]['score'] > 0

    print("\n 检索成功")
    return True
|
||
|
||
|
||
def test_text_format():
    """Retrieve results for ``CASE_001`` and check ``format_context`` output.

    Requests the top three results for one query, formats them with
    ``Retriever.format_context``, prints a preview of the first 200
    characters, and asserts that the formatted text contains both the
    ``<documents>`` and ``</documents>`` tags and is non-empty.

    Returns:
        True when all assertions hold.
    """
    rag = Retriever()
    hits = rag.retrieve(query="案件的原被告信息", case_id="CASE_001", top_k=3)
    context = rag.format_context(hits)

    print(" 格式化成功")
    print(f" 原始结果数: {len(hits)}")
    print(f" 格式化长度: {len(context)} 字符")
    print("\n 格式化预览:")
    print(f" {context[:200]}...")

    # Opening tag is checked before the closing tag.
    for tag in ("<documents>", "</documents>"):
        assert tag in context
    assert len(context) > 0
    return True
|
||
|
||
|
||
def test_cleanup():
    """Delete the ``CASE_001`` documents and print what is left in the store.

    Calls ``delete_by_case("CASE_001")`` on a ``QdrantVecStore``. When a count
    is returned it is printed and asserted to be non-negative; when ``None``
    is returned a note is printed instead. The remaining document total from
    ``get_stats()`` is printed last.

    Returns:
        True once the cleanup and the report have completed.
    """
    vec_store = QdrantVecStore()
    removed = vec_store.delete_by_case("CASE_001")

    print(" 清理完成")
    if removed is None:
        # The delete call gave back no count to verify.
        print(" 删除操作已执行(未返回计数)")
    else:
        print(f" 删除文档数: {removed}")
        assert removed >= 0

    remaining = vec_store.get_stats()['total_documents']
    print(f" 剩余文档数: {remaining}")

    return True
|
||
|
||
|
||
def run_all_tests(tests=None):
    """Run the RAG smoke tests in order and print a pass/fail summary.

    A test counts as passed when it returns a truthy value. A falsy return or
    a raised exception (including a failed ``assert``) counts as failed; the
    remaining tests still run, so the summary is always printed and the final
    cleanup step is reached even after an earlier failure.

    Args:
        tests: Optional iterable of ``(display_name, callable)`` pairs to run
            instead of the built-in suite. Each callable takes no arguments.

    Returns:
        True when no test failed, False otherwise.
    """
    if tests is None:
        # Order matters: the retrieval tests query the CASE_001 documents
        # that the ingestion test loads, and cleanup deletes them last.
        tests = [
            ("Qdrant 连接", test_connect),
            ("嵌入客户端", test_embed),
            ("文档入库", test_doc_load),
            ("文档检索", test_doc_search),
            ("上下文格式化", test_text_format),
            ("清理数据", test_cleanup),
        ]

    print("\n" + "=" * 60)
    print("RAG测试")
    print("=" * 60)

    passed = 0
    failed = 0

    for name, test_func in tests:
        error = None
        try:
            ok = bool(test_func())
        except Exception as exc:
            # Runner boundary: record the failure and keep going so that one
            # broken test cannot skip the tests that follow it.
            ok = False
            error = exc

        if ok:
            passed += 1
            continue

        print(f"\n 测试失败: {name}")
        if error is not None:
            print(f" {type(error).__name__}: {error}")
        failed += 1

    print("\n" + "=" * 60)
    print(f"结果: {passed} 通过, {failed} 失败")
    print("=" * 60)

    return failed == 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Hand the suite outcome to the shell: 0 when everything passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)
|