Files
aiqs/tests/test_rag.py
2025-11-28 15:06:54 +08:00

183 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.rag.loader import DocLoader
from src.rag.retriever import Retriever
from src.rag.store import QdrantVecStore
from src.rag.embeddings import embedding_client
def test_connect():
    """Check that the Qdrant vector store is reachable and reports sane stats.

    Prints the collection name, document count and vector size returned by
    ``QdrantVecStore.get_stats()``. Returns True so the ``run_all_tests``
    runner in this module counts it as passed; a failed check raises
    AssertionError.
    """
    stats = QdrantVecStore().get_stats()
    print(" Qdrant连接成功")
    # (label, stats key) pairs, printed in a fixed order.
    for label, key in (
        ("集合", "collection_name"),
        ("文档数", "total_documents"),
        ("向量维度", "vector_size"),
    ):
        print(f" {label}: {stats[key]}")
    assert stats['collection_name'] is not None
    assert stats['total_documents'] >= 0
    return True
def test_embed():
    """Embed two short texts and check the number and size of the vectors.

    Returns True so the ``run_all_tests`` runner in this module counts it as
    passed; a failed check raises AssertionError.
    """
    # NOTE(review): 1536 is assumed to be the configured embedding model's
    # output dimension -- confirm it matches the store's ``vector_size``.
    expected_dim = 1536
    texts = ["测试文本1", "测试文本2"]
    vectors = embedding_client.embed(texts)
    # Check the count before indexing vectors[0], so an empty response fails
    # with a clear assertion instead of an IndexError.
    assert len(vectors) == len(texts), (
        f"expected {len(texts)} vectors, got {len(vectors)}"
    )
    print(" txt1,txt2 embed成功")
    print(f" 输入文本数: {len(texts)}")
    print(f" 输出向量数: {len(vectors)}")
    print(f" 向量维度: {len(vectors[0])}")
    # Every vector, not only the first, must have the expected dimension.
    assert all(len(vector) == expected_dim for vector in vectors)
    return True
def test_doc_load():
    """Chunk and ingest a sample loan-dispute case document under CASE_001.

    Loads the text with ``DocLoader(chunk_size=200, chunk_overlap=50)`` and
    checks that at least one chunk id is returned. test_doc_search and
    test_text_format later query CASE_001, and test_cleanup deletes it. The
    search test appears to rely on this data being present, so run this
    first.

    Returns True so the ``run_all_tests`` runner in this module counts it as
    passed; a failed check raises AssertionError.
    """
    loader = DocLoader(chunk_size=200, chunk_overlap=50)
    test_text = """
案件编号: CASE_001
案件名称: 测试案件 - 借款合同纠纷
原告: 张三
被告: 李四
案情简介:
原告张三于2023年1月向被告李四出借人民币50万元约定年利率10%
借款期限为1年。到期后被告未按约定归还本金及利息。
原告诉讼请求:
1. 判令被告归还借款本金50万元
2. 支付利息5万元
3. 承担本案诉讼费用
被告答辩:
该笔款项为合伙投资款,非借款。原告提供的借条系伪造。
"""
    ids = loader.loader_text(
        text=test_text,
        case_id="CASE_001",
        source="test_complaint.txt",
    )
    # Assert before reporting success or touching ids[0], so an empty result
    # fails with a clear message rather than an IndexError.
    assert len(ids) > 0, "loader_text returned no chunk ids"
    print(" doc入库成功")
    print(" 案件ID: CASE_001")
    print(f" 文档片段数: {len(ids)}")
    # str() keeps the preview working even if ids are not plain strings
    # (e.g. UUID objects) -- the id type is not visible here, TODO confirm.
    print(f" 片段ID示例: {str(ids[0])[:8]}...")
    return True
def test_doc_search():
    """Query CASE_001 with several questions and check each gets a hit.

    For every question, requests the top 2 chunks and prints the best
    score and a text preview. Asserts at least one result with a
    positive score. Returns True so the ``run_all_tests`` runner in
    this module counts it as passed.
    """
    retriever = Retriever()
    for question in ("原告是谁?", "借款金额是多少?", "被告的答辩意见是什么?"):
        hits = retriever.retrieve(query=question, case_id="CASE_001", top_k=2)
        print(f"\n 查询: {question}")
        print(f" 结果数: {len(hits)}")
        if hits:
            best = hits[0]
            print(f" 最高相似度: {best['score']:.3f}")
            print(f" 片段预览: {best['text'][:50]}...")
        assert len(hits) > 0
        assert hits[0]['score'] > 0
    print("\n 检索成功")
    return True
def test_text_format():
    """Retrieve CASE_001 chunks and check the formatted LLM context.

    Passes the top 3 results through ``Retriever.format_context`` and
    asserts the output is non-empty and wrapped in
    ``<documents>``/``</documents>`` tags. Returns True so the
    ``run_all_tests`` runner in this module counts it as passed.
    """
    retriever = Retriever()
    hits = retriever.retrieve(query="案件的原被告信息", case_id="CASE_001", top_k=3)
    context = retriever.format_context(hits)
    print(" 格式化成功")
    print(f" 原始结果数: {len(hits)}")
    print(f" 格式化长度: {len(context)} 字符")
    print("\n 格式化预览:")
    print(f" {context[:200]}...")
    for tag in ("<documents>", "</documents>"):
        assert tag in context
    assert len(context) > 0
    return True
def test_cleanup():
    """Delete all CASE_001 documents and report what remains in the store.

    ``delete_by_case`` may return a deletion count or None. A returned
    count must be non-negative. Returns True so the ``run_all_tests``
    runner in this module counts it as passed.
    """
    store = QdrantVecStore()
    deleted = store.delete_by_case("CASE_001")
    print(" 清理完成")
    if deleted is None:
        print(" 删除操作已执行(未返回计数)")
    else:
        print(f" 删除文档数: {deleted}")
        assert deleted >= 0
    remaining = store.get_stats()['total_documents']
    print(f" 剩余文档数: {remaining}")
    return True
def run_all_tests():
    """Run every RAG smoke test in order and print a pass/fail summary.

    Each test runs in isolation. An exception, including a failed
    ``assert``, is caught, its traceback printed, and the test counted as
    failed. The remaining tests, in particular ``test_cleanup``, still run,
    so CASE_001 data is not left behind.

    A test passes unless it raises or explicitly returns ``False``. This
    means pytest-style tests that return None are also counted as passed.

    Returns:
        True if no test failed, False otherwise.
    """
    import traceback

    print("\n" + "="*60)
    print("RAG测试")
    print("="*60)
    tests = [
        ("Qdrant 连接", test_connect),
        ("嵌入客户端", test_embed),
        ("文档入库", test_doc_load),
        ("文档检索", test_doc_search),
        ("上下文格式化", test_text_format),
        ("清理数据", test_cleanup),
    ]
    passed = 0
    failed = 0
    for name, test_func in tests:
        try:
            ok = test_func() is not False
        except Exception:
            # Top-level runner boundary: report the failure and keep going.
            traceback.print_exc()
            ok = False
        if ok:
            passed += 1
        else:
            print(f"\n 测试失败: {name}")
            failed += 1
    print("\n" + "="*60)
    print(f"结果: {passed} 通过, {failed} 失败")
    print("="*60)
    return failed == 0
if __name__ == "__main__":
    # Script entry point: exit status 0 only if run_all_tests() reports that
    # every test passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)