aiqs
This commit is contained in:
25
.gitignore
vendored
Normal file
25
.gitignore
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
venv/
|
||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
*.log
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
*.egg-info/
|
||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
data/
|
||||||
|
.vscode/
|
||||||
|
.claude/
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
aiqs/
|
||||||
|
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.13
|
||||||
6
main.py
Normal file
6
main.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from aiqs!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
12
pyproject.toml
Normal file
12
pyproject.toml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
[project]
|
||||||
|
name = "aiqs"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"openai==2.7.1",
|
||||||
|
"python-dotenv>=1.2.1",
|
||||||
|
"pydantic>=2.12.4",
|
||||||
|
"pydantic-settings>=2.11.0",
|
||||||
|
]
|
||||||
9
requirement.txt
Normal file
9
requirement.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
openai>=1.0.0
|
||||||
|
pydantic>=2.0.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
httpx>=0.24.0
|
||||||
|
|
||||||
|
qdrant-client>=1.7.0
|
||||||
|
tiktoken>=0.5.0
|
||||||
|
|
||||||
|
pillow>=10.0.0
|
||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
33
src/config.py
Normal file
33
src/config.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
openai_api_key: str = Field(default="", description="key")
|
||||||
|
openai_base_url: str = Field(
|
||||||
|
default="https://api.openai.com/v1",
|
||||||
|
description="url",
|
||||||
|
)
|
||||||
|
chat_model: str = Field(
|
||||||
|
default="gpt-4o-mini"
|
||||||
|
)
|
||||||
|
embedding_model: str = Field(
|
||||||
|
default="text-embedding-3-small"
|
||||||
|
)
|
||||||
|
qdrant_host: str = Field(default="localhost")
|
||||||
|
qdrant_port: int = Field(default=6333)
|
||||||
|
qdrant_collection: str = Field(
|
||||||
|
default="legal_documents"
|
||||||
|
)
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=PROJECT_ROOT / ".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
334
src/mcp/case_server.py
Normal file
334
src/mcp/case_server.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from mcp.server import Server
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
from mcp import types
|
||||||
|
import json
|
||||||
|
|
||||||
|
from src.rag.retriever import Retriever
|
||||||
|
from src.tools.case_tools import tools
|
||||||
|
from src.mcp.mem import ConvoStore
|
||||||
|
|
||||||
|
app = Server("legal-case-server")
|
||||||
|
retriever = Retriever()
|
||||||
|
conv = ConvoStore()
|
||||||
|
|
||||||
|
@app.list_tools()
|
||||||
|
async def list_tools() -> list[types.Tool]:
|
||||||
|
return [
|
||||||
|
types.Tool(
|
||||||
|
name="search_case_documents",
|
||||||
|
description="在法律案件文档中检索相关信息。用于回答关于案件事实、当事人信息、诉讼请求、判决结果等问题。",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"case_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "案件ID,如 'case_001'"
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "检索问题,如 '原告的诉讼请求是什么?'"
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "返回结果数量,默认3",
|
||||||
|
"default": 3
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["case_id", "query"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="get_case_metadata",
|
||||||
|
description="获取案件的基本信息",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"case_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "案件ID"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["case_id"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="build_rag_prompt",
|
||||||
|
description="构建标准化的 RAG Prompt,整合检索上下文和对话历史。返回可直接用于 LLM 的 messages 数组。",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "用户问题"
|
||||||
|
},
|
||||||
|
"context": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "检索到的文档上下文(XML 格式)"
|
||||||
|
},
|
||||||
|
"conversation_history": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "对话历史(可选)",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"role": {"type": "string"},
|
||||||
|
"content": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"system_prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "自定义系统提示词(可选)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query", "context"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="gen_timeline",
|
||||||
|
description="从文档中生成案件时间线,自动提取并排序时间信息",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"documents": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "检索结果文档列表",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {"type": "string"},
|
||||||
|
"metadata": {"type": "object"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["documents"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="store_conservation",
|
||||||
|
description="存储对话到会话记录中",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"session_id": {"type": "string", "description": "会话ID"},
|
||||||
|
"user_message": {"type": "string", "description": "用户消息"},
|
||||||
|
"ai_message": {"type": "string", "description": "AI回复"},
|
||||||
|
"metadata": {"type": "object", "description": "额外元数据"}
|
||||||
|
},
|
||||||
|
"required": ["session_id", "user_message", "ai_message"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="search_conversation",
|
||||||
|
description="检索会话中的相关对话记录",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"session_id": {"type": "string"},
|
||||||
|
"query": {"type": "string", "description": "搜索查询"},
|
||||||
|
"limit": {"type": "integer","default": 5}
|
||||||
|
},
|
||||||
|
"required": ["session_id", "query"]
|
||||||
|
}
|
||||||
|
),
|
||||||
|
types.Tool(
|
||||||
|
name="get_all_conversation",
|
||||||
|
description="获取所有会话记录",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"session_id": {"type": "string", "description": "会话ID"},
|
||||||
|
"limit": {"type": "integer","default": 100}
|
||||||
|
},
|
||||||
|
"required": ["session_id"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@app.list_resources()
|
||||||
|
async def list_resources() -> list[types.Resource]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@app.list_prompts()
|
||||||
|
async def list_prompts() -> list[types.Prompt]:
|
||||||
|
return [
|
||||||
|
types.Prompt(
|
||||||
|
name="analyze_case",
|
||||||
|
description="分析法律案件的标准流程",
|
||||||
|
arguments=[
|
||||||
|
types.PromptArgument(
|
||||||
|
name="case_id",
|
||||||
|
description="案件ID",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@app.call_tool()
|
||||||
|
async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
|
||||||
|
if name == "search_case_documents":
|
||||||
|
case_id = arguments["case_id"]
|
||||||
|
query = arguments["query"]
|
||||||
|
top_k = arguments.get("top_k", 3)
|
||||||
|
|
||||||
|
results = retriever.retrieve(
|
||||||
|
query=query,
|
||||||
|
case_id=case_id,
|
||||||
|
top_k=top_k
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
response = {
|
||||||
|
"status": "no_results",
|
||||||
|
"message": f"在案件 {case_id} 中未找到与 '{query}' 相关的文档"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
context = retriever.format_context(results)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"status": "success",
|
||||||
|
"case_id": case_id,
|
||||||
|
"query": query,
|
||||||
|
"results_count": len(results),
|
||||||
|
"context": context,
|
||||||
|
"documents": [
|
||||||
|
{
|
||||||
|
"index": i + 1,
|
||||||
|
"text": r["text"][:300],
|
||||||
|
"score": round(r["score"], 3),
|
||||||
|
"source": r["metadata"].get("source", "unknown")
|
||||||
|
}
|
||||||
|
for i, r in enumerate(results)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=json.dumps(response, ensure_ascii=False, indent=2)
|
||||||
|
)]
|
||||||
|
|
||||||
|
elif name == "get_case_metadata":
|
||||||
|
result = tools.get_case_metadata(**arguments)
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
)]
|
||||||
|
|
||||||
|
elif name == "build_rag_prompt":
|
||||||
|
result = tools.build_rag_prompt(**arguments)
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
)]
|
||||||
|
elif name == "gen_timeline":
|
||||||
|
result = tools.gen_timeline(**arguments)
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
)]
|
||||||
|
elif name == "store_conservation":
|
||||||
|
session_id = arguments.get("session_id")
|
||||||
|
user_msg = arguments.get("user_msg")
|
||||||
|
ai_msg = arguments.get("ai_msg")
|
||||||
|
metadata = arguments.get("metadata", {})
|
||||||
|
|
||||||
|
conv.add_conv(
|
||||||
|
session_id=session_id,
|
||||||
|
user_msg=user_msg,
|
||||||
|
ai_msg=ai_msg,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"对话以存储到会话{session_id}中。"
|
||||||
|
)]
|
||||||
|
|
||||||
|
elif name == "search_conversation":
|
||||||
|
session_id = arguments.get("session_id")
|
||||||
|
query = arguments.get("query")
|
||||||
|
limit = arguments.get("limit", 5)
|
||||||
|
|
||||||
|
results = conv.search_conv(
|
||||||
|
session_id=session_id,
|
||||||
|
query=query,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = "\n\n".join([
|
||||||
|
f"**Q**: {r['user']}\n**A**: {r['assistant']}\n(相关度: {r['score']:.2f})"
|
||||||
|
for r in results
|
||||||
|
])
|
||||||
|
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=formatted or "未找到相关对话记录。"
|
||||||
|
)]
|
||||||
|
|
||||||
|
elif name == "get_all_conversation":
|
||||||
|
session_id = arguments.get("session_id")
|
||||||
|
limit = arguments.get("limit", 100)
|
||||||
|
|
||||||
|
results = conv.get_all_conv(
|
||||||
|
session_id=session_id,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = "\n\n".join([
|
||||||
|
f"**Q**: {r['user']}\n**A**: {r['assistant']}"
|
||||||
|
for r in results
|
||||||
|
])
|
||||||
|
|
||||||
|
return [types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=formatted or "该会话暂无对话记录。"
|
||||||
|
)]
|
||||||
|
|
||||||
|
|
||||||
|
raise ValueError(f"未知工具: {name}")
|
||||||
|
|
||||||
|
@app.get_prompt()
|
||||||
|
async def get_prompt(name: str, arguments: dict) -> types.GetPromptResult:
|
||||||
|
if name == "analyze_case":
|
||||||
|
case_id = arguments["case_id"]
|
||||||
|
|
||||||
|
return types.GetPromptResult(
|
||||||
|
description=f"分析案件 {case_id}",
|
||||||
|
messages=[
|
||||||
|
types.PromptMessage(
|
||||||
|
role="user",
|
||||||
|
content=types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"请分析案件 {case_id},包括:1. 案情摘要 2. 争议焦点 3. 法律依据"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.read_resource()
|
||||||
|
async def read_resource(uri: str) -> str:
|
||||||
|
mapping = {
|
||||||
|
"case://case_001/summary": Path("d:/workspace/aiqs/data/case_001_summary.txt")
|
||||||
|
}
|
||||||
|
path = mapping.get(uri)
|
||||||
|
if path and path.exists():
|
||||||
|
return path.read_text(encoding="utf-8")
|
||||||
|
return f"未列出的资源或不存在: {uri}"
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
|
await app.run(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
app.create_initialization_options()
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
239
src/mcp/mem.py
Normal file
239
src/mcp/mem.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import Distance, VectorParams, PointStruct
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from src.rag.embeddings import embedding_client
|
||||||
|
|
||||||
|
|
||||||
|
class ConvoStore:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
collection_name: str = "conversations",
|
||||||
|
qdrant_url: str = "http://localhost:6333"
|
||||||
|
):
|
||||||
|
self.client = QdrantClient(url=qdrant_url)
|
||||||
|
self.collection_name = collection_name
|
||||||
|
|
||||||
|
if not collection_name in [c.name for c in self.client.get_collections().collections]:
|
||||||
|
self.client.create_collection(
|
||||||
|
collection_name=collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=1536,
|
||||||
|
distance=Distance.COSINE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.client.get_collection(collection_name)
|
||||||
|
|
||||||
|
def add_conv(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
user_msg: str,
|
||||||
|
ai_msg: str,
|
||||||
|
metadata: Optional[Dict] = None
|
||||||
|
):
|
||||||
|
if not session_id or not user_msg or not ai_msg:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 用户问+ai答 为一条记录(还是分开?)
|
||||||
|
conv_text = f"用户: {user_msg}\nAI: {ai_msg}"
|
||||||
|
vector = embedding_client.embed([conv_text])[0]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_msg": user_msg,
|
||||||
|
"ai_msg": ai_msg,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
if metadata:
|
||||||
|
payload.update(metadata)
|
||||||
|
|
||||||
|
point_id = str(uuid.uuid4())
|
||||||
|
self.client.upsert(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=point_id,
|
||||||
|
vector=vector,
|
||||||
|
payload=payload
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return point_id
|
||||||
|
|
||||||
|
def search_conv(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5
|
||||||
|
) -> List[Dict]:
|
||||||
|
if not session_id or not query:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_vector = embedding_client.embed([query])[0]
|
||||||
|
|
||||||
|
results = self.client.query_points(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query=query_vector,
|
||||||
|
query_filter={
|
||||||
|
"must": [
|
||||||
|
{"key": "session_id", "match": {"value": session_id}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
limit=limit,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
points = results.points if hasattr(results, "points") else []
|
||||||
|
if not points:
|
||||||
|
return []
|
||||||
|
convs = []
|
||||||
|
for point in points:
|
||||||
|
payload = point.payload if hasattr(point, "payload") else {}
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
convs.append({
|
||||||
|
"user": payload.get("user_msg",""),
|
||||||
|
"ai": payload.get("ai_msg",""),
|
||||||
|
"timestamp": payload.get("timestamp",""),
|
||||||
|
"score": point.score if hasattr(point, "score") else 0.0
|
||||||
|
})
|
||||||
|
|
||||||
|
return convs
|
||||||
|
|
||||||
|
def get_all_conv(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[Dict]:
|
||||||
|
if not session_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = self.client.scroll(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
scroll_filter={
|
||||||
|
"must": [
|
||||||
|
{"key": "session_id",
|
||||||
|
"match": {"value": session_id}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
limit=limit,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False
|
||||||
|
)
|
||||||
|
points = results[0] if results and len(results) > 0 else []
|
||||||
|
if not points:
|
||||||
|
return []
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
for point in points:
|
||||||
|
payload = point.payload if hasattr(point, "payload") else {}
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
convs.append({
|
||||||
|
"user_msg": payload.get("user_msg",""),
|
||||||
|
"ai_msg": payload.get("ai_msg",""),
|
||||||
|
"timestamp": payload.get("timestamp","")
|
||||||
|
})
|
||||||
|
convs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||||
|
return convs
|
||||||
|
|
||||||
|
def delete_by_session(self, session_id: str) -> int:
|
||||||
|
if not session_id:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
results = self.client.scroll(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
scroll_filter={
|
||||||
|
"must": [
|
||||||
|
{"key": "session_id",
|
||||||
|
"match": {"value": session_id}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
limit=10000,
|
||||||
|
with_payload=False,
|
||||||
|
with_vectors=False
|
||||||
|
)
|
||||||
|
|
||||||
|
points = results[0] if results and len(results) > 0 else []
|
||||||
|
|
||||||
|
if not points:
|
||||||
|
return 0
|
||||||
|
point_ids = [point.id for point in points]
|
||||||
|
|
||||||
|
self.client.delete(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points_selector=point_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(point_ids)
|
||||||
|
|
||||||
|
# # def qdrant_example():
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# store = ConvoStore()
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_001",
|
||||||
|
# user_msg="原告是谁?",
|
||||||
|
# ai_msg="原告是张三",
|
||||||
|
# metadata={"case_id": "case_001"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_001",
|
||||||
|
# user_msg="借款金额是多少?",
|
||||||
|
# ai_msg="借款金额是50万元",
|
||||||
|
# metadata={"case_id": "case_001"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# user_msg="原告是谁?",
|
||||||
|
# ai_msg="原告是李四",
|
||||||
|
# metadata={"case_id": "case_002"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# user_msg="借款金额是多少?",
|
||||||
|
# ai_msg="借款金额是60万元",
|
||||||
|
# metadata={"case_id": "case_002"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# user_msg="李四借款金额是多少?",
|
||||||
|
# ai_msg="李四借款金额是60万元",
|
||||||
|
# metadata={"case_id": "case_002"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# user_msg="被告是谁?",
|
||||||
|
# ai_msg="李四",
|
||||||
|
# metadata={"case_id": "case_002"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# store.add_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# user_msg="被告要还多少钱?",
|
||||||
|
# ai_msg="李四还60万",
|
||||||
|
# metadata={"case_id": "case_002"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# relevant = store.search_conv(
|
||||||
|
# session_id="case_002",
|
||||||
|
# query="被告需要还多少钱?",
|
||||||
|
# limit=3
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print("相关历史对话:")
|
||||||
|
# for conv in relevant:
|
||||||
|
# print(f" Q: {conv['user']}")
|
||||||
|
# print(f" A: {conv['ai']}")
|
||||||
|
# print(f" 相关度: {conv['score']:.3f}\n")
|
||||||
0
src/rag/__init__.py
Normal file
0
src/rag/__init__.py
Normal file
27
src/rag/chunker.py
Normal file
27
src/rag/chunker.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class TextChunker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
chunk_overlap: int = 50
|
||||||
|
):
|
||||||
|
self.splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separators=["\n\n", "\n", "。", "?", "!", ".", ",", "!", ";", "?", ";", " ", ""]
|
||||||
|
)
|
||||||
|
|
||||||
|
def chunk(self, text: str) -> List[Dict]:
|
||||||
|
chunks = self.splitter.split_text(text)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"text": chunk,
|
||||||
|
"chunk_id": i,
|
||||||
|
"total_chunks": len(chunks)
|
||||||
|
}
|
||||||
|
for i, chunk in enumerate(chunks)
|
||||||
|
]
|
||||||
38
src/rag/embeddings.py
Normal file
38
src/rag/embeddings.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
import httpx
|
||||||
|
from src.config import settings
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingClient:
|
||||||
|
def __init__(self):
|
||||||
|
http_client = httpx.Client(
|
||||||
|
proxy="http://127.0.0.1:10808",
|
||||||
|
verify=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=settings.openai_api_key,
|
||||||
|
base_url=settings.openai_base_url,
|
||||||
|
http_client=http_client
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
model: str = None
|
||||||
|
) -> List[List[float]]:
|
||||||
|
|
||||||
|
response = self.client.embeddings.create(
|
||||||
|
model=model or settings.embedding_model,
|
||||||
|
input=texts
|
||||||
|
)
|
||||||
|
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
|
|
||||||
|
|
||||||
|
embedding_client = EmbeddingClient()
|
||||||
93
src/rag/loader.py
Normal file
93
src/rag/loader.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from typing import List, Dict
|
||||||
|
from src.rag.store import QdrantVecStore
|
||||||
|
from src.rag.chunker import TextChunker
|
||||||
|
from src.rag.embeddings import embedding_client
|
||||||
|
|
||||||
|
|
||||||
|
class DocLoader:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
chunk_overlap: int = 50
|
||||||
|
):
|
||||||
|
self.store = QdrantVecStore()
|
||||||
|
self.chunker = TextChunker(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
def loader_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
case_id: str,
|
||||||
|
source: str = "unknown"
|
||||||
|
) -> List[str]:
|
||||||
|
|
||||||
|
chunks = self.chunker.chunk(text)
|
||||||
|
|
||||||
|
texts = [chunk["text"] for chunk in chunks]
|
||||||
|
vectors = embedding_client.embed(texts)
|
||||||
|
|
||||||
|
metadatas = [
|
||||||
|
{
|
||||||
|
"case_id": case_id,
|
||||||
|
"source": source,
|
||||||
|
"chunk_id": chunk["chunk_id"],
|
||||||
|
"total_chunks": chunk["total_chunks"]
|
||||||
|
}
|
||||||
|
for chunk in chunks
|
||||||
|
]
|
||||||
|
|
||||||
|
ids = self.store.add_documents(
|
||||||
|
texts=texts,
|
||||||
|
vectors=vectors,
|
||||||
|
metadatas=metadatas
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 案件 {case_id} 入库完成: {len(ids)} 个块")
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def loader_file(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
case_id: str
|
||||||
|
) -> List[str]:
|
||||||
|
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
|
||||||
|
return self.loader_text(
|
||||||
|
text=text,
|
||||||
|
case_id=case_id,
|
||||||
|
source=path.name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# loader = DocLoader()
|
||||||
|
|
||||||
|
# test_text = """
|
||||||
|
# 案件编号: case_001
|
||||||
|
# 原告: 张三
|
||||||
|
# 被告: 李四
|
||||||
|
|
||||||
|
# 原告诉讼请求:
|
||||||
|
# 1. 判令被告归还借款本金50万元
|
||||||
|
# 2. 支付利息3万元
|
||||||
|
# 3. 支付违约金2万元
|
||||||
|
# """
|
||||||
|
|
||||||
|
# loader.loader_text(
|
||||||
|
# text=test_text,
|
||||||
|
# case_id="case_001",
|
||||||
|
# source="complaint.txt"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print("已入库")
|
||||||
88
src/rag/retriever.py
Normal file
88
src/rag/retriever.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from src.rag.store import QdrantVecStore
|
||||||
|
from src.rag.embeddings import embedding_client
|
||||||
|
|
||||||
|
|
||||||
|
class Retriever:
|
||||||
|
def __init__(self):
|
||||||
|
self.store = QdrantVecStore()
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
case_id: Optional[str] = None,
|
||||||
|
top_k: int = 3
|
||||||
|
) -> List[Dict]:
|
||||||
|
query_vector = embedding_client.embed([query])[0]
|
||||||
|
|
||||||
|
results = self.store.search(
|
||||||
|
query_vector=query_vector,
|
||||||
|
case_id=case_id,
|
||||||
|
top_k=top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def format_context(self, results: List[Dict]) -> str:
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return "<documents>\n<message>没有找到相关文档。</message>\n</documents>"
|
||||||
|
|
||||||
|
context_parts = ["<documents>"]
|
||||||
|
|
||||||
|
|
||||||
|
max_score = max(r["score"] for r in results)
|
||||||
|
context_parts.append(
|
||||||
|
f"<summary>共检索到 {len(results)} 个相关文档片段,按相关度从高到低排序</summary>\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
source = result["metadata"].get("source", "未知来源")
|
||||||
|
text = result["text"]
|
||||||
|
score = result["score"]
|
||||||
|
|
||||||
|
if score >= 0.85:
|
||||||
|
relevance = "高度相关"
|
||||||
|
elif score >= 0.70:
|
||||||
|
relevance = "相关"
|
||||||
|
else:
|
||||||
|
relevance = "可能相关"
|
||||||
|
|
||||||
|
context_parts.append(
|
||||||
|
f'<document index="{i}" source="{source}" relevance="{relevance}" score="{score:.3f}">\n'
|
||||||
|
f'{text}\n'
|
||||||
|
f'</document>'
|
||||||
|
)
|
||||||
|
|
||||||
|
context_parts.append("</documents>")
|
||||||
|
|
||||||
|
return "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# retriever = Retriever()
|
||||||
|
|
||||||
|
# question = "被告李四的主要辩称理由是什么?"
|
||||||
|
|
||||||
|
# print(f" 检索问题: {question}\n")
|
||||||
|
|
||||||
|
# results = retriever.retrieve(
|
||||||
|
# query=question,
|
||||||
|
# case_id="case001",
|
||||||
|
# top_k=3
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print(f" 检索到 {len(results)} 个相关片段:\n")
|
||||||
|
|
||||||
|
# for i, result in enumerate(results, 1):
|
||||||
|
# print(f"--- 结果 {i} (score: {result['score']:.3f}) ---")
|
||||||
|
# print(f"内容: {result['text'][:100]}...")
|
||||||
|
# print(f"来源: {result['metadata']['source']}")
|
||||||
|
# print()
|
||||||
|
|
||||||
|
# context = retriever.format_context(results)
|
||||||
|
# print(f" 格式化上下文:\n{context}")
|
||||||
160
src/rag/store.py
Normal file
160
src/rag/store.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import (
|
||||||
|
Distance,
|
||||||
|
VectorParams,
|
||||||
|
PointStruct,
|
||||||
|
Filter,
|
||||||
|
FieldCondition,
|
||||||
|
MatchValue
|
||||||
|
)
|
||||||
|
from src.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantVecStore:
|
||||||
|
def __init__(self, collection_name: str = None):
|
||||||
|
self.client = QdrantClient(
|
||||||
|
host=settings.qdrant_host,
|
||||||
|
port=settings.qdrant_port
|
||||||
|
)
|
||||||
|
self.collection_name = collection_name or settings.qdrant_collection
|
||||||
|
self._ensure_collection()
|
||||||
|
|
||||||
|
def _ensure_collection(self):
|
||||||
|
collections = self.client.get_collections().collections
|
||||||
|
exists = any(c.name == self.collection_name for c in collections)
|
||||||
|
|
||||||
|
if not exists:
|
||||||
|
self.client.create_collection(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=1536,
|
||||||
|
distance=Distance.COSINE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f" collection_name: {self.collection_name}")
|
||||||
|
|
||||||
|
def add_documents(
|
||||||
|
self,
|
||||||
|
texts: list[str],
|
||||||
|
vectors: list[list[float]],
|
||||||
|
metadatas: list[dict]
|
||||||
|
) -> list[str]:
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
points = []
|
||||||
|
ids = []
|
||||||
|
|
||||||
|
for txt, vec, meta in zip(texts, vectors, metadatas):
|
||||||
|
doc_id = str(uuid.uuid4())
|
||||||
|
ids.append(doc_id)
|
||||||
|
|
||||||
|
points.append(
|
||||||
|
PointStruct(
|
||||||
|
id=doc_id,
|
||||||
|
vector=vec,
|
||||||
|
payload={
|
||||||
|
"text": txt,
|
||||||
|
"case_id": meta.get("case_id"),
|
||||||
|
"source": meta.get("source", "unknown"),
|
||||||
|
"page": meta.get("page"),
|
||||||
|
"chunk_id": meta.get("chunk_id"),
|
||||||
|
**meta # 其他数据
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client.upsert(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points=points
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 添加 {len(points)} 个文档片")
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query_vector: list[float],
|
||||||
|
case_id: str = None,
|
||||||
|
top_k: int = 3
|
||||||
|
) -> list[dict]:
|
||||||
|
query_filter = None
|
||||||
|
if case_id:
|
||||||
|
query_filter = Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="case_id",
|
||||||
|
match=MatchValue(value=case_id)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self.client.query_points(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query=query_vector,
|
||||||
|
query_filter=query_filter,
|
||||||
|
limit=top_k,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
points = results.points if hasattr(results, 'points') else results
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"text": r.payload["text"],
|
||||||
|
"score": r.score,
|
||||||
|
"metadata": {
|
||||||
|
k: v for k, v in r.payload.items()
|
||||||
|
if k != "text"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for r in points
|
||||||
|
]
|
||||||
|
|
||||||
|
def delete_by_case(self, case_id: str):
|
||||||
|
self.client.delete(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points_selector=Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="case_id",
|
||||||
|
match=MatchValue(value=case_id)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f" 删除案件 {case_id} 的所有文档")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
info = self.client.get_collection(self.collection_name)
|
||||||
|
coll = getattr(info, 'result', info)
|
||||||
|
|
||||||
|
total_docs = getattr(coll, "points_count", 0) or getattr(coll, "vectors_count", 0)
|
||||||
|
|
||||||
|
vector_size = 1536
|
||||||
|
config = getattr(coll, "config", None)
|
||||||
|
if config:
|
||||||
|
params = getattr(config, "params", None)
|
||||||
|
if params:
|
||||||
|
vectors_config = getattr(params, "vectors", None)
|
||||||
|
if vectors_config and hasattr(vectors_config, "size"):
|
||||||
|
vector_size = vectors_config.size
|
||||||
|
|
||||||
|
return {
|
||||||
|
"collection_name": self.collection_name,
|
||||||
|
"total_documents": int(total_docs),
|
||||||
|
"vector_size": int(vector_size),
|
||||||
|
"total_points": getattr(coll, "points_count", None),
|
||||||
|
"vectors_count": getattr(coll, "vectors_count", None),
|
||||||
|
"indexed_vectors_count": getattr(coll, "indexed_vectors_count", None)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# store = QdrantVecStore()
|
||||||
|
# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
|
||||||
|
# print(f" 集合: {store.collection_name}")
|
||||||
|
|
||||||
|
# stats = store.get_stats()
|
||||||
|
# print(f" 统计: {stats}")
|
||||||
270
src/tools/case_tools.py
Normal file
270
src/tools/case_tools.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, Optional, List
|
||||||
|
import re
|
||||||
|
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||||
|
from src.rag.store import QdrantVecStore
|
||||||
|
from src.mcp.mem import ConvoStore
|
||||||
|
|
||||||
|
|
||||||
|
class CaseTools:
|
||||||
|
@staticmethod
|
||||||
|
def get_case_metadata(case_id: str) -> Dict:
|
||||||
|
|
||||||
|
if not case_id or not case_id.strip():
|
||||||
|
return {
|
||||||
|
"error": "未提供案件ID",
|
||||||
|
"description": "请提供有效的案件ID,例如: 'case_001'"
|
||||||
|
}
|
||||||
|
|
||||||
|
store = QdrantVecStore()
|
||||||
|
|
||||||
|
collections = store.client.get_collections()
|
||||||
|
col_names = [col.name for col in collections.collections]
|
||||||
|
if not col_names:
|
||||||
|
return {
|
||||||
|
"error": f"Qdrant链接失败",
|
||||||
|
"description": "请检查Qdrant是否启动"
|
||||||
|
}
|
||||||
|
collection_name = getattr(store, "collection_name", None) or getattr(store, "collection_name", None)
|
||||||
|
if collection_name not in col_names:
|
||||||
|
return {
|
||||||
|
"error": f"未找到集合: {collection_name}",
|
||||||
|
"description": "请确认案件文档是否已入库"
|
||||||
|
}
|
||||||
|
|
||||||
|
results = store.client.scroll(
|
||||||
|
collection_name = collection_name,
|
||||||
|
scroll_filter=Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="case_id",
|
||||||
|
match=MatchValue(value=case_id)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
limit=100,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False
|
||||||
|
)
|
||||||
|
if not results[0]:
|
||||||
|
return {
|
||||||
|
"error": f"未找到案件 {case_id} 的文档",
|
||||||
|
"description": "请确认案件ID是否正确,或该案件是否已录入系统"
|
||||||
|
}
|
||||||
|
|
||||||
|
target = results[0] if results else []
|
||||||
|
if not target:
|
||||||
|
return {
|
||||||
|
"error": f"未找到案件 {case_id} 的文档",
|
||||||
|
"case_id": case_id,
|
||||||
|
"description": "该案件不存在或尚未录入系统"
|
||||||
|
}
|
||||||
|
|
||||||
|
all_source = set()
|
||||||
|
chunk_cnt = len(target)
|
||||||
|
for point in target:
|
||||||
|
payload = point.payload if hasattr(point, 'payload') else {}
|
||||||
|
if payload and "source" in payload:
|
||||||
|
all_source.add(payload["source"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"case_id": case_id,
|
||||||
|
"doc_cnt": len(all_source),
|
||||||
|
"chunk_cnt": chunk_cnt,
|
||||||
|
"sources": sorted(list(all_source)),
|
||||||
|
"status":"已入库",
|
||||||
|
"description": f"案件 {case_id} 共包含 {len(all_source)} 个文档,共 {chunk_cnt} 个块。"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_rag_prompt(
|
||||||
|
query: str,
|
||||||
|
context: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
max_history: int = 5,
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
) -> Dict:
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
sour_his = None
|
||||||
|
conv_his = []
|
||||||
|
|
||||||
|
|
||||||
|
def_sys_prompt = """你是一个专业的法律案件分析 AI 助手。
|
||||||
|
|
||||||
|
<role>
|
||||||
|
你的职责:
|
||||||
|
1. 基于检索到的案件文档片段,准确回答用户的问题
|
||||||
|
2. 结合对话历史理解上下文(如代词指代、前置问题等)
|
||||||
|
3. 提供客观、专业、有依据的法律分析
|
||||||
|
4. 明确标注信息来源,便于用户核实
|
||||||
|
</role>
|
||||||
|
|
||||||
|
<guidelines>
|
||||||
|
核心原则:
|
||||||
|
- 忠于文档内容: 仅基于提供的文档回答,不得编造或推测未在文档中体现的信息
|
||||||
|
- 明确引用来源: 回答时使用 [文档X] 标注信息来源(X 为文档索引号)
|
||||||
|
- 区分确定与不确定: 如果文档信息不足以明确回答,应诚实说明
|
||||||
|
- 结合对话历史: 理解代词指代、关联前置问题、维持话题连贯性
|
||||||
|
- 专业法律表达: 使用准确的法律术语,避免模糊或口语化表述
|
||||||
|
|
||||||
|
回答格式:
|
||||||
|
1. 直接回答问题(引用文档)
|
||||||
|
2. 如需要,提供补充说明或法律解释
|
||||||
|
3. 如果信息不足,明确说明缺失的部分
|
||||||
|
</guidelines>"""
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt or def_sys_prompt
|
||||||
|
})
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
store = ConvoStore()
|
||||||
|
rel_convs = store.search_conv(
|
||||||
|
session_id=session_id,
|
||||||
|
query=query,
|
||||||
|
limit=max_history
|
||||||
|
)
|
||||||
|
if rel_convs and len(rel_convs) > 0:
|
||||||
|
rel_convs.sort(key=lambda x: x.get('timestep',''))
|
||||||
|
for conv in rel_convs:
|
||||||
|
conv_his.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": conv["user"]
|
||||||
|
})
|
||||||
|
conv_his.append({
|
||||||
|
"role": "ai",
|
||||||
|
"content": conv["ai"]
|
||||||
|
})
|
||||||
|
sour_his = "relevant"
|
||||||
|
|
||||||
|
else:
|
||||||
|
rec_convs = store.get_all_conv(
|
||||||
|
session_id=session_id,
|
||||||
|
limit=max_history
|
||||||
|
)
|
||||||
|
if rec_convs and len(rec_convs) > 0:
|
||||||
|
rec_convs = rec_convs[:max_history]
|
||||||
|
rec_convs.reverse()
|
||||||
|
for conv in rec_convs:
|
||||||
|
conv_his.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": conv["user"]
|
||||||
|
})
|
||||||
|
conv_his.append({
|
||||||
|
"role": "ai",
|
||||||
|
"content": conv["ai"]
|
||||||
|
})
|
||||||
|
sour_his = "recent"
|
||||||
|
|
||||||
|
if conv_his and len(conv_his) > 0:
|
||||||
|
messages.extend(conv_his)
|
||||||
|
|
||||||
|
user_msg = []
|
||||||
|
if conv_his and len(conv_his) > 0:
|
||||||
|
user_msg.append("""<instruction>
|
||||||
|
请结合上述对话历史理解当前问题。注意识别代词指代和话题延续。
|
||||||
|
</instruction>
|
||||||
|
""")
|
||||||
|
|
||||||
|
user_msg.append(f"""<context>
|
||||||
|
{context}
|
||||||
|
</context>
|
||||||
|
|
||||||
|
<question>
|
||||||
|
{query}
|
||||||
|
</question>
|
||||||
|
|
||||||
|
请根据上述文档回答问题。记得引用具体的文档编号。""")
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": "\n".join(user_msg)
|
||||||
|
})
|
||||||
|
|
||||||
|
his_len = len(conv_his) // 2 if conv_his else 0
|
||||||
|
return {
|
||||||
|
"messages": messages,
|
||||||
|
"prompt_stats": {
|
||||||
|
"total_messages": len(messages),
|
||||||
|
"has_context": bool(context),
|
||||||
|
"has_history": bool(conv_his),
|
||||||
|
"history_length": his_len,
|
||||||
|
"history_source": sour_his,
|
||||||
|
"session_id": session_id
|
||||||
|
},
|
||||||
|
"info": {
|
||||||
|
"query": query,
|
||||||
|
"session_id": session_id,
|
||||||
|
"history_used": his_len
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gen_timeline(documents: List[Dict]) -> Dict:
|
||||||
|
events = []
|
||||||
|
all_text = "\n".join([doc.get("text", "") for doc in documents])
|
||||||
|
|
||||||
|
date_pattern = r'(\d{4})[年/-](\d{1,2})[月/-](\d{1,2})日?'
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
text = doc.get("text", "")
|
||||||
|
source = doc.get("metadata", {}).get("source", "unknown")
|
||||||
|
|
||||||
|
matches = re.finditer(date_pattern, text)
|
||||||
|
for match in matches:
|
||||||
|
year, month, day = match.groups()
|
||||||
|
date_str = f"{year}-{month.zfill(2)}-{day.zfill(2)}"
|
||||||
|
|
||||||
|
start = max(0, match.start() - 30)
|
||||||
|
end = min(len(text), match.end() + 50)
|
||||||
|
context = text[start:end].strip()
|
||||||
|
|
||||||
|
event_type = "事件"
|
||||||
|
if "借款" in context or "出借" in context:
|
||||||
|
event_type = "借款发生"
|
||||||
|
elif "到期" in context:
|
||||||
|
event_type = "借款到期"
|
||||||
|
elif "起诉" in context or "立案" in context:
|
||||||
|
event_type = "提起诉讼"
|
||||||
|
elif "开庭" in context or "审理" in context:
|
||||||
|
event_type = "开庭审理"
|
||||||
|
elif "判决" in context:
|
||||||
|
event_type = "作出判决"
|
||||||
|
|
||||||
|
events.append({
|
||||||
|
"date": date_str,
|
||||||
|
"event": event_type,
|
||||||
|
"context": context,
|
||||||
|
"source": source
|
||||||
|
})
|
||||||
|
|
||||||
|
events.sort(key=lambda x: x["date"])
|
||||||
|
|
||||||
|
duration_days = 0
|
||||||
|
earliest = None
|
||||||
|
latest = None
|
||||||
|
|
||||||
|
if events:
|
||||||
|
earliest = events[0]["date"]
|
||||||
|
latest = events[-1]["date"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_date = datetime.strptime(earliest, "%Y-%m-%d")
|
||||||
|
end_date = datetime.strptime(latest, "%Y-%m-%d")
|
||||||
|
duration_days = (end_date - start_date).days
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timeline": events[:20],
|
||||||
|
"total_events": len(events),
|
||||||
|
"duration_days": duration_days,
|
||||||
|
"earliest": earliest,
|
||||||
|
"latest": latest,
|
||||||
|
"summary": f"案件时间跨度{duration_days}天,从{earliest}到{latest}" if earliest and latest else "未找到时间信息"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tools = CaseTools()
|
||||||
23
start_claude.ps1
Normal file
23
start_claude.ps1
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
$env:ANTHROPIC_BASE_URL = "https://api.moonshot.cn/anthropic"
|
||||||
|
$env:MOONSHOT_API_KEY = "sk-joeLiNi2DzTWHr06tdFOdx3A8eKg98eowwShkTlusowyP4vS"
|
||||||
|
$env:ANTHROPIC_AUTH_TOKEN = $env:MOONSHOT_API_KEY
|
||||||
|
|
||||||
|
$env:ANTHROPIC_MODEL = "kimi-k2-thinking-turbo"
|
||||||
|
$env:ANTHROPIC_DEFAULT_OPUS_MODEL = "kimi-k2-thinking-turbo"
|
||||||
|
$env:ANTHROPIC_DEFAULT_SONNET_MODEL = "kimi-k2-thinking-turbo"
|
||||||
|
$env:ANTHROPIC_DEFAULT_HAIKU_MODEL = "kimi-k2-thinking-turbo"
|
||||||
|
$env:CLAUDE_CODE_SUBAGENT_MODEL = "kimi-k2-thinking-turbo"
|
||||||
|
|
||||||
|
Write-Host " 环境变量已设置" -ForegroundColor Green
|
||||||
|
Write-Host " Base URL: $env:ANTHROPIC_BASE_URL" -ForegroundColor Cyan
|
||||||
|
Write-Host " Model: $env:ANTHROPIC_MODEL" -ForegroundColor Cyan
|
||||||
|
|
||||||
|
$claudePath = "C:\Users\ADMIN\AppData\Local\AnthropicClaude\claude.exe"
|
||||||
|
|
||||||
|
if (Test-Path $claudePath) {
|
||||||
|
Write-Host " 正在启动 Claude Desktop..." -ForegroundColor Yellow
|
||||||
|
Start-Process $claudePath
|
||||||
|
} else {
|
||||||
|
Write-Host " 找不到 Claude.exe,请检查安装路径" -ForegroundColor Red
|
||||||
|
}
|
||||||
2
tests/__init__.py
Normal file
2
tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""tests package"""
|
||||||
|
|
||||||
202
tests/test_mcp.py
Normal file
202
tests/test_mcp.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
import json
|
||||||
|
from src.mcp.mem import ConvoStore
|
||||||
|
from src.tools.case_tools import CaseTools
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv_store():
|
||||||
|
store = ConvoStore()
|
||||||
|
|
||||||
|
test_session = "CASE_SESSION_001"
|
||||||
|
|
||||||
|
store.add_conv(
|
||||||
|
session_id=test_session,
|
||||||
|
user_msg="这个案件的原告是谁?",
|
||||||
|
ai_msg="根据案件文档,原告是张三。",
|
||||||
|
metadata={"case_id": "CASE_001"}
|
||||||
|
)
|
||||||
|
|
||||||
|
store.add_conv(
|
||||||
|
session_id=test_session,
|
||||||
|
user_msg="借款金额是多少?",
|
||||||
|
ai_msg="借款本金为50万元,约定年利率10%。",
|
||||||
|
metadata={"case_id": "CASE_001"}
|
||||||
|
)
|
||||||
|
|
||||||
|
store.add_conv(
|
||||||
|
session_id=test_session,
|
||||||
|
user_msg="被告的答辩意见是什么?",
|
||||||
|
ai_msg="被告称该笔款项为合伙投资款,非借款,并称借条系伪造。",
|
||||||
|
metadata={"case_id": "CASE_001"}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 对话存储成功")
|
||||||
|
print(f" 会话ID: {test_session}")
|
||||||
|
print(f" 对话轮数: 3")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv_search():
|
||||||
|
|
||||||
|
store = ConvoStore()
|
||||||
|
test_session = "CASE_SESSION_001"
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"原告信息",
|
||||||
|
"借了多少钱",
|
||||||
|
"被告怎么说"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
results = store.search_conv(
|
||||||
|
session_id=test_session,
|
||||||
|
query=query,
|
||||||
|
limit=2
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n 查询: {query}")
|
||||||
|
print(f" 结果数: {len(results)}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
print(f" 最高相关度: {results[0]['score']:.3f}")
|
||||||
|
print(f" 匹配对话: Q: {results[0]['user'][:30]}...")
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
|
||||||
|
print(f"\n 检索成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv_all():
|
||||||
|
|
||||||
|
store = ConvoStore()
|
||||||
|
test_session = "CASE_SESSION_001"
|
||||||
|
|
||||||
|
all_convs = store.get_all_conv(
|
||||||
|
session_id=test_session,
|
||||||
|
limit=100
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 会话获取成功")
|
||||||
|
print(f" 会话ID: {test_session}")
|
||||||
|
print(f" 对话总数: {len(all_convs)}")
|
||||||
|
|
||||||
|
if all_convs:
|
||||||
|
print(f"\n 最近对话:")
|
||||||
|
for i, conv in enumerate(all_convs[:2], 1):
|
||||||
|
print(f" {i}. Q: {conv['user_msg'][:40]}...")
|
||||||
|
print(f" A: {conv['ai_msg'][:40]}...")
|
||||||
|
|
||||||
|
assert len(all_convs) >= 3
|
||||||
|
return True
|
||||||
|
|
||||||
|
def test_case_tools_metadata():
|
||||||
|
|
||||||
|
tools = CaseTools()
|
||||||
|
result = tools.get_case_metadata("CASE_001")
|
||||||
|
|
||||||
|
print(f" 查询完成")
|
||||||
|
print(f" 案件ID: CASE_001")
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
print(f" 状态: {result['error']}")
|
||||||
|
print(f" 说明: {result.get('description', '')}")
|
||||||
|
else:
|
||||||
|
print(f" 文档数: {result.get('doc_cnt', 0)}")
|
||||||
|
print(f" 片段数: {result.get('chunk_cnt', 0)}")
|
||||||
|
print(f" 状态: {result.get('status', '')}")
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_rag_prompt():
|
||||||
|
tools = CaseTools()
|
||||||
|
|
||||||
|
mock_context = """<documents>
|
||||||
|
<document index="1" source="complaint.txt">
|
||||||
|
原告: 张三
|
||||||
|
被告: 李四
|
||||||
|
借款金额: 50万元
|
||||||
|
</document>
|
||||||
|
<document index="2" source="defense.txt">
|
||||||
|
被告答辩: 该笔款项为合伙投资款
|
||||||
|
</document>
|
||||||
|
</documents>"""
|
||||||
|
|
||||||
|
result = tools.build_rag_prompt(
|
||||||
|
query="原告和被告分别是谁?",
|
||||||
|
context=mock_context,
|
||||||
|
session_id="CASE_SESSION_001",
|
||||||
|
max_history=2
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" prompt构建成功")
|
||||||
|
print(f" 消息总数: {result['prompt_stats']['total_messages']}")
|
||||||
|
print(f" 包含上下文: {result['prompt_stats']['has_context']}")
|
||||||
|
print(f" 包含历史: {result['prompt_stats']['has_history']}")
|
||||||
|
print(f" 历史来源: {result['prompt_stats']['history_source']}")
|
||||||
|
print(f" 历史轮数: {result['prompt_stats']['history_length']}")
|
||||||
|
|
||||||
|
assert len(result['messages']) > 0
|
||||||
|
assert result['prompt_stats']['has_context']
|
||||||
|
return True
|
||||||
|
|
||||||
|
def test_cleanup():
|
||||||
|
store = ConvoStore()
|
||||||
|
test_session = "CASE_SESSION_001"
|
||||||
|
|
||||||
|
all_convs = store.get_all_conv(session_id=test_session, limit=1000)
|
||||||
|
print(f" 清理前对话数: {len(all_convs)}")
|
||||||
|
|
||||||
|
count = store.delete_by_session(test_session)
|
||||||
|
|
||||||
|
print(f" 清理完成")
|
||||||
|
print(f" 测试会话: {test_session}")
|
||||||
|
print(f" 删除对话数: {count}")
|
||||||
|
|
||||||
|
remaining = store.get_all_conv(session_id=test_session, limit=1000)
|
||||||
|
print(f" 清理后对话数: {len(remaining)}")
|
||||||
|
|
||||||
|
assert len(remaining) == 0
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("MCP 测试")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
("对话存储", test_conv_store),
|
||||||
|
("对话检索", test_conv_search),
|
||||||
|
("获取全部对话", test_conv_all),
|
||||||
|
("获取案件元数据", test_case_tools_metadata),
|
||||||
|
("构建 RAG Prompt", test_build_rag_prompt),
|
||||||
|
("清理测试数据", test_cleanup)
|
||||||
|
]
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for name, test_func in tests:
|
||||||
|
if test_func():
|
||||||
|
passed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
print(f"\n 测试失败: {name}")
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print(f"结果: {passed} 通过, {failed} 失败")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
return failed == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = run_all_tests()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
182
tests/test_rag.py
Normal file
182
tests/test_rag.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from src.rag.loader import DocLoader
|
||||||
|
from src.rag.retriever import Retriever
|
||||||
|
from src.rag.store import QdrantVecStore
|
||||||
|
from src.rag.embeddings import embedding_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect():
|
||||||
|
store = QdrantVecStore()
|
||||||
|
stats = store.get_stats()
|
||||||
|
|
||||||
|
print(f" Qdrant连接成功")
|
||||||
|
print(f" 集合: {stats['collection_name']}")
|
||||||
|
print(f" 文档数: {stats['total_documents']}")
|
||||||
|
print(f" 向量维度: {stats['vector_size']}")
|
||||||
|
|
||||||
|
assert stats['collection_name'] is not None
|
||||||
|
assert stats['total_documents'] >= 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed():
|
||||||
|
texts = ["测试文本1", "测试文本2"]
|
||||||
|
vectors = embedding_client.embed(texts)
|
||||||
|
|
||||||
|
print(f" txt1,txt2 embed成功")
|
||||||
|
print(f" 输入文本数: {len(texts)}")
|
||||||
|
print(f" 输出向量数: {len(vectors)}")
|
||||||
|
print(f" 向量维度: {len(vectors[0])}")
|
||||||
|
|
||||||
|
assert len(vectors) == len(texts)
|
||||||
|
assert len(vectors[0]) == 1536
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_load():
|
||||||
|
|
||||||
|
loader = DocLoader(chunk_size=200, chunk_overlap=50)
|
||||||
|
|
||||||
|
test_text = """
|
||||||
|
案件编号: CASE_001
|
||||||
|
案件名称: 测试案件 - 借款合同纠纷
|
||||||
|
|
||||||
|
原告: 张三
|
||||||
|
被告: 李四
|
||||||
|
|
||||||
|
案情简介:
|
||||||
|
原告张三于2023年1月向被告李四出借人民币50万元,约定年利率10%,
|
||||||
|
借款期限为1年。到期后被告未按约定归还本金及利息。
|
||||||
|
|
||||||
|
原告诉讼请求:
|
||||||
|
1. 判令被告归还借款本金50万元
|
||||||
|
2. 支付利息5万元
|
||||||
|
3. 承担本案诉讼费用
|
||||||
|
|
||||||
|
被告答辩:
|
||||||
|
该笔款项为合伙投资款,非借款。原告提供的借条系伪造。
|
||||||
|
"""
|
||||||
|
|
||||||
|
ids = loader.loader_text(
|
||||||
|
text=test_text,
|
||||||
|
case_id="CASE_001",
|
||||||
|
source="test_complaint.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" doc入库成功")
|
||||||
|
print(f" 案件ID: CASE_001")
|
||||||
|
print(f" 文档片段数: {len(ids)}")
|
||||||
|
print(f" 片段ID示例: {ids[0][:8]}...")
|
||||||
|
|
||||||
|
assert len(ids) > 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_search():
|
||||||
|
|
||||||
|
retriever = Retriever()
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"原告是谁?",
|
||||||
|
"借款金额是多少?",
|
||||||
|
"被告的答辩意见是什么?"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
results = retriever.retrieve(
|
||||||
|
query=query,
|
||||||
|
case_id="CASE_001",
|
||||||
|
top_k=2
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n 查询: {query}")
|
||||||
|
print(f" 结果数: {len(results)}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
print(f" 最高相似度: {results[0]['score']:.3f}")
|
||||||
|
print(f" 片段预览: {results[0]['text'][:50]}...")
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
assert results[0]['score'] > 0
|
||||||
|
|
||||||
|
print(f"\n 检索成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_format():
|
||||||
|
retriever = Retriever()
|
||||||
|
|
||||||
|
results = retriever.retrieve(
|
||||||
|
query="案件的原被告信息",
|
||||||
|
case_id="CASE_001",
|
||||||
|
top_k=3
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = retriever.format_context(results)
|
||||||
|
|
||||||
|
print(f" 格式化成功")
|
||||||
|
print(f" 原始结果数: {len(results)}")
|
||||||
|
print(f" 格式化长度: {len(formatted)} 字符")
|
||||||
|
print(f"\n 格式化预览:")
|
||||||
|
print(f" {formatted[:200]}...")
|
||||||
|
|
||||||
|
assert "<documents>" in formatted
|
||||||
|
assert "</documents>" in formatted
|
||||||
|
assert len(formatted) > 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup():
|
||||||
|
store = QdrantVecStore()
|
||||||
|
count = store.delete_by_case("CASE_001")
|
||||||
|
|
||||||
|
print(f" 清理完成")
|
||||||
|
if count is not None:
|
||||||
|
print(f" 删除文档数: {count}")
|
||||||
|
assert count >= 0
|
||||||
|
else:
|
||||||
|
print(f" 删除操作已执行(未返回计数)")
|
||||||
|
|
||||||
|
stats = store.get_stats()
|
||||||
|
print(f" 剩余文档数: {stats['total_documents']}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("RAG测试")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
("Qdrant 连接", test_connect),
|
||||||
|
("嵌入客户端", test_embed),
|
||||||
|
("文档入库", test_doc_load),
|
||||||
|
("文档检索", test_doc_search),
|
||||||
|
("上下文清除", test_text_format),
|
||||||
|
("清理数据", test_cleanup)
|
||||||
|
]
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for name, test_func in tests:
|
||||||
|
if test_func():
|
||||||
|
passed += 1
|
||||||
|
else:
|
||||||
|
print(f"\n 测试失败: {name}")
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print(f"结果: {passed} 通过, {failed} 失败")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
return failed == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = run_all_tests()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
Reference in New Issue
Block a user