commit 0bb8b4e13b
2025-11-28 15:06:54 +08:00
21 changed files with 1744 additions and 0 deletions

25
.gitignore vendored Normal file

@@ -0,0 +1,25 @@
# Python-generated files
__pycache__/
*.py[oc]
*.pyc
*.log
build/
dist/
wheels/
*.egg-info/
data/

# Environment files
.env
.env.local

# Virtual environments
venv/
.venv/
aiqs/

# Editor / assistant config
.vscode/
.claude/

1
.python-version Normal file

@@ -0,0 +1 @@
3.13

0
README.md Normal file

6
main.py Normal file

@@ -0,0 +1,6 @@
def main():
    print("Hello from aiqs!")

if __name__ == "__main__":
    main()

12
pyproject.toml Normal file

@@ -0,0 +1,12 @@
[project]
name = "aiqs"
version = "0.1.0"
description = "Legal-case RAG pipeline with an MCP tool server"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "openai==2.7.1",
    "python-dotenv>=1.2.1",
    "pydantic>=2.12.4",
    "pydantic-settings>=2.11.0",
    # The modules under src/ also import these; they were missing from the original list.
    "httpx>=0.24.0",
    "qdrant-client>=1.7.0",
    "mcp",
    "langchain-text-splitters",
]

9
requirement.txt Normal file

@@ -0,0 +1,9 @@
openai>=1.0.0
pydantic>=2.0.0
pydantic-settings>=2.0.0
httpx>=0.24.0
qdrant-client>=1.7.0
tiktoken>=0.5.0
pillow>=10.0.0

0
src/__init__.py Normal file

33
src/config.py Normal file

@@ -0,0 +1,33 @@
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
PROJECT_ROOT = Path(__file__).parent.parent
class Settings(BaseSettings):
    openai_api_key: str = Field(default="", description="OpenAI API key")
    openai_base_url: str = Field(
        default="https://api.openai.com/v1",
        description="OpenAI-compatible API base URL",
    )
    chat_model: str = Field(default="gpt-4o-mini")
    embedding_model: str = Field(default="text-embedding-3-small")
    qdrant_host: str = Field(default="localhost")
    qdrant_port: int = Field(default=6333)
    qdrant_collection: str = Field(default="legal_documents")
model_config = SettingsConfigDict(
env_file=PROJECT_ROOT / ".env",
env_file_encoding="utf-8",
extra="ignore",
)
settings = Settings()
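For reference, a minimal .env at the project root might look like the sketch below; pydantic-settings maps the variable names case-insensitively onto the Settings fields above, and every value shown here is a placeholder.

# .env (placeholders only; never commit real keys)
OPENAI_API_KEY=sk-your-key-here
OPENAI_BASE_URL=https://api.openai.com/v1
CHAT_MODEL=gpt-4o-mini
EMBEDDING_MODEL=text-embedding-3-small
QDRANT_HOST=localhost
QDRANT_PORT=6333
QDRANT_COLLECTION=legal_documents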

334
src/mcp/case_server.py Normal file

@@ -0,0 +1,334 @@
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp import types
import json
from src.rag.retriever import Retriever
from src.tools.case_tools import tools
from src.mcp.mem import ConvoStore
app = Server("legal-case-server")
retriever = Retriever()
conv = ConvoStore()
@app.list_tools()
async def list_tools() -> list[types.Tool]:
return [
types.Tool(
name="search_case_documents",
description="在法律案件文档中检索相关信息。用于回答关于案件事实、当事人信息、诉讼请求、判决结果等问题。",
inputSchema={
"type": "object",
"properties": {
"case_id": {
"type": "string",
"description": "案件ID'case_001'"
},
"query": {
"type": "string",
"description": "检索问题,如 '原告的诉讼请求是什么?'"
},
"top_k": {
"type": "integer",
"description": "返回结果数量默认3",
"default": 3
}
},
"required": ["case_id", "query"]
}
),
types.Tool(
name="get_case_metadata",
description="获取案件的基本信息",
inputSchema={
"type": "object",
"properties": {
"case_id": {
"type": "string",
"description": "案件ID"
}
},
"required": ["case_id"]
}
),
types.Tool(
name="build_rag_prompt",
description="构建标准化的 RAG Prompt整合检索上下文和对话历史。返回可直接用于 LLM 的 messages 数组。",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "用户问题"
},
"context": {
"type": "string",
"description": "检索到的文档上下文XML 格式)"
},
"conversation_history": {
"type": "array",
"description": "对话历史(可选)",
"items": {
"type": "object",
"properties": {
"role": {"type": "string"},
"content": {"type": "string"}
}
}
},
"system_prompt": {
"type": "string",
"description": "自定义系统提示词(可选)"
}
},
"required": ["query", "context"]
}
),
types.Tool(
name="gen_timeline",
description="从文档中生成案件时间线,自动提取并排序时间信息",
inputSchema={
"type": "object",
"properties": {
"documents": {
"type": "array",
"description": "检索结果文档列表",
"items": {
"type": "object",
"properties": {
"text": {"type": "string"},
"metadata": {"type": "object"}
}
}
}
},
"required": ["documents"]
}
),
types.Tool(
name="store_conservation",
description="存储对话到会话记录中",
inputSchema={
"type": "object",
"properties": {
"session_id": {"type": "string", "description": "会话ID"},
"user_message": {"type": "string", "description": "用户消息"},
"ai_message": {"type": "string", "description": "AI回复"},
"metadata": {"type": "object", "description": "额外元数据"}
},
"required": ["session_id", "user_message", "ai_message"]
}
),
types.Tool(
name="search_conversation",
description="检索会话中的相关对话记录",
inputSchema={
"type": "object",
"properties": {
"session_id": {"type": "string"},
"query": {"type": "string", "description": "搜索查询"},
"limit": {"type": "integer","default": 5}
},
"required": ["session_id", "query"]
}
),
types.Tool(
name="get_all_conversation",
description="获取所有会话记录",
inputSchema={
"type": "object",
"properties": {
"session_id": {"type": "string", "description": "会话ID"},
"limit": {"type": "integer","default": 100}
},
"required": ["session_id"]
}
)
]
@app.list_resources()
async def list_resources() -> list[types.Resource]:
return []
@app.list_prompts()
async def list_prompts() -> list[types.Prompt]:
return [
types.Prompt(
name="analyze_case",
description="分析法律案件的标准流程",
arguments=[
types.PromptArgument(
name="case_id",
description="案件ID",
required=True
)
]
)
]
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
if name == "search_case_documents":
case_id = arguments["case_id"]
query = arguments["query"]
top_k = arguments.get("top_k", 3)
results = retriever.retrieve(
query=query,
case_id=case_id,
top_k=top_k
)
if not results:
response = {
"status": "no_results",
"message": f"在案件 {case_id} 中未找到与 '{query}' 相关的文档"
}
else:
context = retriever.format_context(results)
response = {
"status": "success",
"case_id": case_id,
"query": query,
"results_count": len(results),
"context": context,
"documents": [
{
"index": i + 1,
"text": r["text"][:300],
"score": round(r["score"], 3),
"source": r["metadata"].get("source", "unknown")
}
for i, r in enumerate(results)
]
}
return [types.TextContent(
type="text",
text=json.dumps(response, ensure_ascii=False, indent=2)
)]
elif name == "get_case_metadata":
result = tools.get_case_metadata(**arguments)
return [types.TextContent(
type="text",
text=json.dumps(result, ensure_ascii=False, indent=2)
)]
elif name == "build_rag_prompt":
result = tools.build_rag_prompt(**arguments)
return [types.TextContent(
type="text",
text=json.dumps(result, ensure_ascii=False, indent=2)
)]
elif name == "gen_timeline":
result = tools.gen_timeline(**arguments)
return [types.TextContent(
type="text",
text=json.dumps(result, ensure_ascii=False, indent=2)
)]
elif name == "store_conservation":
session_id = arguments.get("session_id")
user_msg = arguments.get("user_msg")
ai_msg = arguments.get("ai_msg")
metadata = arguments.get("metadata", {})
conv.add_conv(
session_id=session_id,
user_msg=user_msg,
ai_msg=ai_msg,
metadata=metadata
)
return [types.TextContent(
type="text",
text=f"对话以存储到会话{session_id}中。"
)]
elif name == "search_conversation":
session_id = arguments.get("session_id")
query = arguments.get("query")
limit = arguments.get("limit", 5)
results = conv.search_conv(
session_id=session_id,
query=query,
limit=limit
)
        # search_conv returns "user"/"ai" keys, not "assistant".
        formatted = "\n\n".join([
            f"**Q**: {r['user']}\n**A**: {r['ai']}\n(相关度: {r['score']:.2f})"
            for r in results
        ])
return [types.TextContent(
type="text",
text=formatted or "未找到相关对话记录。"
)]
elif name == "get_all_conversation":
session_id = arguments.get("session_id")
limit = arguments.get("limit", 100)
results = conv.get_all_conv(
session_id=session_id,
limit=limit
)
        # get_all_conv returns "user_msg"/"ai_msg" keys.
        formatted = "\n\n".join([
            f"**Q**: {r['user_msg']}\n**A**: {r['ai_msg']}"
            for r in results
        ])
return [types.TextContent(
type="text",
text=formatted or "该会话暂无对话记录。"
)]
raise ValueError(f"未知工具: {name}")
@app.get_prompt()
async def get_prompt(name: str, arguments: dict) -> types.GetPromptResult:
if name == "analyze_case":
case_id = arguments["case_id"]
return types.GetPromptResult(
description=f"分析案件 {case_id}",
messages=[
types.PromptMessage(
role="user",
content=types.TextContent(
type="text",
text=f"请分析案件 {case_id}包括1. 案情摘要 2. 争议焦点 3. 法律依据"
)
)
            ]
        )
    raise ValueError(f"未知提示词: {name}")
@app.read_resource()
async def read_resource(uri: str) -> str:
    # Resolve bundled resources relative to the project root rather than a machine-specific path.
    data_dir = Path(__file__).parent.parent.parent / "data"
    mapping = {
        "case://case_001/summary": data_dir / "case_001_summary.txt"
    }
    path = mapping.get(str(uri))
    if path and path.exists():
        return path.read_text(encoding="utf-8")
    return f"资源未注册或文件不存在: {uri}"
async def main():
async with stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
app.create_initialization_options()
)
if __name__ == "__main__":
asyncio.run(main())
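For a quick end-to-end check, the server can be exercised over stdio with the mcp SDK's client helpers. The sketch below is illustrative only: it assumes it is run from the project root with Qdrant up and a valid API key in .env, and the tool arguments mirror the search_case_documents inputSchema above.

# client_demo.py (hypothetical helper, not part of this commit)
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def demo():
    params = StdioServerParameters(command="python", args=["src/mcp/case_server.py"])
    async with stdio_client(params) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools = await session.list_tools()
            print("tools:", [t.name for t in tools.tools])
            result = await session.call_tool(
                "search_case_documents",
                {"case_id": "case_001", "query": "原告的诉讼请求是什么?", "top_k": 3},
            )
            print(result.content[0].text)

asyncio.run(demo())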

239
src/mcp/mem.py Normal file

@@ -0,0 +1,239 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from typing import List, Dict, Optional
import uuid
from datetime import datetime
from src.rag.embeddings import embedding_client
class ConvoStore:
def __init__(
self,
collection_name: str = "conversations",
qdrant_url: str = "http://localhost:6333"
):
self.client = QdrantClient(url=qdrant_url)
self.collection_name = collection_name
        if collection_name not in [c.name for c in self.client.get_collections().collections]:
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=1536,
                    distance=Distance.COSINE
                )
            )
def add_conv(
self,
session_id: str,
user_msg: str,
ai_msg: str,
metadata: Optional[Dict] = None
):
if not session_id or not user_msg or not ai_msg:
return None
        # Store each user question + AI answer pair as one record (alternative: store them separately)
conv_text = f"用户: {user_msg}\nAI: {ai_msg}"
vector = embedding_client.embed([conv_text])[0]
payload = {
"session_id": session_id,
"user_msg": user_msg,
"ai_msg": ai_msg,
"timestamp": datetime.now().isoformat()
}
if metadata:
payload.update(metadata)
point_id = str(uuid.uuid4())
self.client.upsert(
collection_name=self.collection_name,
points=[
PointStruct(
id=point_id,
vector=vector,
payload=payload
)
]
)
return point_id
def search_conv(
self,
session_id: str,
query: str,
limit: int = 5
) -> List[Dict]:
if not session_id or not query:
return []
query_vector = embedding_client.embed([query])[0]
results = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
query_filter={
"must": [
{"key": "session_id", "match": {"value": session_id}}
]
},
limit=limit,
with_payload=True
)
points = results.points if hasattr(results, "points") else []
if not points:
return []
convs = []
for point in points:
payload = point.payload if hasattr(point, "payload") else {}
if not payload:
continue
convs.append({
"user": payload.get("user_msg",""),
"ai": payload.get("ai_msg",""),
"timestamp": payload.get("timestamp",""),
"score": point.score if hasattr(point, "score") else 0.0
})
return convs
def get_all_conv(
self,
session_id: str,
limit: int = 100
) -> List[Dict]:
if not session_id:
return []
results = self.client.scroll(
collection_name=self.collection_name,
scroll_filter={
"must": [
{"key": "session_id",
"match": {"value": session_id}}
]
},
limit=limit,
with_payload=True,
with_vectors=False
)
points = results[0] if results and len(results) > 0 else []
if not points:
return []
convs = []
for point in points:
payload = point.payload if hasattr(point, "payload") else {}
if not payload:
continue
convs.append({
"user_msg": payload.get("user_msg",""),
"ai_msg": payload.get("ai_msg",""),
"timestamp": payload.get("timestamp","")
})
convs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
return convs
def delete_by_session(self, session_id: str) -> int:
if not session_id:
return 0
results = self.client.scroll(
collection_name=self.collection_name,
scroll_filter={
"must": [
{"key": "session_id",
"match": {"value": session_id}}
]
},
limit=10000,
with_payload=False,
with_vectors=False
)
points = results[0] if results and len(results) > 0 else []
if not points:
return 0
point_ids = [point.id for point in points]
self.client.delete(
collection_name=self.collection_name,
points_selector=point_ids
)
return len(point_ids)
# if __name__ == "__main__":
# store = ConvoStore()
# store.add_conv(
# session_id="case_001",
# user_msg="原告是谁?",
# ai_msg="原告是张三",
# metadata={"case_id": "case_001"}
# )
# store.add_conv(
# session_id="case_001",
# user_msg="借款金额是多少?",
# ai_msg="借款金额是50万元",
# metadata={"case_id": "case_001"}
# )
# store.add_conv(
# session_id="case_002",
# user_msg="原告是谁?",
# ai_msg="原告是李四",
# metadata={"case_id": "case_002"}
# )
# store.add_conv(
# session_id="case_002",
# user_msg="借款金额是多少?",
# ai_msg="借款金额是60万元",
# metadata={"case_id": "case_002"}
# )
# store.add_conv(
# session_id="case_002",
# user_msg="李四借款金额是多少?",
# ai_msg="李四借款金额是60万元",
# metadata={"case_id": "case_002"}
# )
# store.add_conv(
# session_id="case_002",
# user_msg="被告是谁?",
# ai_msg="李四",
# metadata={"case_id": "case_002"}
# )
# store.add_conv(
# session_id="case_002",
# user_msg="被告要还多少钱?",
# ai_msg="李四还60万",
# metadata={"case_id": "case_002"}
# )
# relevant = store.search_conv(
# session_id="case_002",
# query="被告需要还多少钱?",
# limit=3
# )
# print("相关历史对话:")
# for conv in relevant:
# print(f" Q: {conv['user']}")
# print(f" A: {conv['ai']}")
# print(f" 相关度: {conv['score']:.3f}\n")

0
src/rag/__init__.py Normal file

27
src/rag/chunker.py Normal file

@@ -0,0 +1,27 @@
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import List, Dict
class TextChunker:
def __init__(
self,
chunk_size: int = 500,
chunk_overlap: int = 50
):
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", "。", "?", "!", ".", ",", "、", ";", "?", "!", " ", ""]
)
def chunk(self, text: str) -> List[Dict]:
chunks = self.splitter.split_text(text)
return [
{
"text": chunk,
"chunk_id": i,
"total_chunks": len(chunks)
}
for i, chunk in enumerate(chunks)
]
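A minimal usage sketch for the chunker (the sizes here are arbitrary; real callers use the 500/50 defaults):

# Illustrative only.
from src.rag.chunker import TextChunker

chunker = TextChunker(chunk_size=100, chunk_overlap=20)
for piece in chunker.chunk("第一段。\n\n第二段,内容较长,会被按句读递归切分。"):
    print(piece["chunk_id"], "/", piece["total_chunks"], piece["text"][:30])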

38
src/rag/embeddings.py Normal file

@@ -0,0 +1,38 @@
import sys
import os
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from typing import List, Optional
import httpx
from openai import OpenAI
from src.config import settings
class EmbeddingClient:
    def __init__(self):
        # The proxy was previously hardcoded with TLS verification disabled;
        # make it opt-in via an environment variable and keep verification on.
        proxy = os.getenv("OPENAI_PROXY")  # e.g. "http://127.0.0.1:10808"
        http_client = httpx.Client(proxy=proxy) if proxy else None
        self.client = OpenAI(
            api_key=settings.openai_api_key,
            base_url=settings.openai_base_url,
            http_client=http_client
        )
    def embed(
        self,
        texts: List[str],
        model: Optional[str] = None
    ) -> List[List[float]]:
        response = self.client.embeddings.create(
            model=model or settings.embedding_model,
            input=texts
        )
        return [item.embedding for item in response.data]
embedding_client = EmbeddingClient()
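A quick smoke test for the client, assuming a valid key in .env and network access; the expected dimension matches the 1536-wide collections created elsewhere in this commit:

# Illustrative only.
from src.rag.embeddings import embedding_client

vectors = embedding_client.embed(["原告是谁?", "借款金额是多少?"])
print(len(vectors), len(vectors[0]))  # expected: 2 1536 for text-embedding-3-small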

93
src/rag/loader.py Normal file

@@ -0,0 +1,93 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from typing import List, Dict
from src.rag.store import QdrantVecStore
from src.rag.chunker import TextChunker
from src.rag.embeddings import embedding_client
class DocLoader:
def __init__(
self,
chunk_size: int = 500,
chunk_overlap: int = 50
):
self.store = QdrantVecStore()
self.chunker = TextChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
def loader_text(
self,
text: str,
case_id: str,
source: str = "unknown"
) -> List[str]:
chunks = self.chunker.chunk(text)
texts = [chunk["text"] for chunk in chunks]
vectors = embedding_client.embed(texts)
metadatas = [
{
"case_id": case_id,
"source": source,
"chunk_id": chunk["chunk_id"],
"total_chunks": chunk["total_chunks"]
}
for chunk in chunks
]
ids = self.store.add_documents(
texts=texts,
vectors=vectors,
metadatas=metadatas
)
print(f" 案件 {case_id} 入库完成: {len(ids)} 个块")
return ids
def loader_file(
self,
file_path: str,
case_id: str
) -> List[str]:
path = Path(file_path)
with open(path, "r", encoding="utf-8") as f:
text = f.read()
return self.loader_text(
text=text,
case_id=case_id,
source=path.name
)
# if __name__ == "__main__":
# loader = DocLoader()
# test_text = """
# 案件编号: case_001
# 原告: 张三
# 被告: 李四
# 原告诉讼请求:
# 1. 判令被告归还借款本金50万元
# 2. 支付利息3万元
# 3. 支付违约金2万元
# """
# loader.loader_text(
# text=test_text,
# case_id="case_001",
# source="complaint.txt"
# )
# print("已入库")

88
src/rag/retriever.py Normal file

@@ -0,0 +1,88 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from typing import List, Dict, Optional
from src.rag.store import QdrantVecStore
from src.rag.embeddings import embedding_client
class Retriever:
def __init__(self):
self.store = QdrantVecStore()
def retrieve(
self,
query: str,
case_id: Optional[str] = None,
top_k: int = 3
) -> List[Dict]:
query_vector = embedding_client.embed([query])[0]
results = self.store.search(
query_vector=query_vector,
case_id=case_id,
top_k=top_k
)
return results
def format_context(self, results: List[Dict]) -> str:
if not results:
return "<documents>\n<message>没有找到相关文档。</message>\n</documents>"
context_parts = ["<documents>"]
context_parts.append(
f"<summary>共检索到 {len(results)} 个相关文档片段,按相关度从高到低排序</summary>\n"
)
for i, result in enumerate(results, 1):
source = result["metadata"].get("source", "未知来源")
text = result["text"]
score = result["score"]
if score >= 0.85:
relevance = "高度相关"
elif score >= 0.70:
relevance = "相关"
else:
relevance = "可能相关"
context_parts.append(
f'<document index="{i}" source="{source}" relevance="{relevance}" score="{score:.3f}">\n'
f'{text}\n'
f'</document>'
)
context_parts.append("</documents>")
return "\n\n".join(context_parts)
# if __name__ == "__main__":
# retriever = Retriever()
# question = "被告李四的主要辩称理由是什么?"
# print(f" 检索问题: {question}\n")
# results = retriever.retrieve(
# query=question,
# case_id="case001",
# top_k=3
# )
# print(f" 检索到 {len(results)} 个相关片段:\n")
# for i, result in enumerate(results, 1):
# print(f"--- 结果 {i} (score: {result['score']:.3f}) ---")
# print(f"内容: {result['text'][:100]}...")
# print(f"来源: {result['metadata']['source']}")
# print()
# context = retriever.format_context(results)
# print(f" 格式化上下文:\n{context}")

160
src/rag/store.py Normal file

@@ -0,0 +1,160 @@
import sys
import uuid
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
Filter,
FieldCondition,
MatchValue
)
from src.config import settings
class QdrantVecStore:
    def __init__(self, collection_name: str | None = None):
self.client = QdrantClient(
host=settings.qdrant_host,
port=settings.qdrant_port
)
self.collection_name = collection_name or settings.qdrant_collection
self._ensure_collection()
def _ensure_collection(self):
collections = self.client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if not exists:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=1536,
distance=Distance.COSINE
)
)
print(f" collection_name: {self.collection_name}")
def add_documents(
self,
texts: list[str],
vectors: list[list[float]],
metadatas: list[dict]
) -> list[str]:
        points = []
        ids = []
for txt, vec, meta in zip(texts, vectors, metadatas):
doc_id = str(uuid.uuid4())
ids.append(doc_id)
points.append(
PointStruct(
id=doc_id,
vector=vec,
payload={
"text": txt,
"case_id": meta.get("case_id"),
"source": meta.get("source", "unknown"),
"page": meta.get("page"),
"chunk_id": meta.get("chunk_id"),
**meta # 其他数据
}
)
)
self.client.upsert(
collection_name=self.collection_name,
points=points
)
print(f" 添加 {len(points)} 个文档片")
return ids
def search(
self,
query_vector: list[float],
case_id: str = None,
top_k: int = 3
) -> list[dict]:
query_filter = None
if case_id:
query_filter = Filter(
must=[
FieldCondition(
key="case_id",
match=MatchValue(value=case_id)
)
]
)
results = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
query_filter=query_filter,
limit=top_k,
with_payload=True
)
points = results.points if hasattr(results, 'points') else results
return [
{
"text": r.payload["text"],
"score": r.score,
"metadata": {
k: v for k, v in r.payload.items()
if k != "text"
}
}
for r in points
]
def delete_by_case(self, case_id: str):
self.client.delete(
collection_name=self.collection_name,
points_selector=Filter(
must=[
FieldCondition(
key="case_id",
match=MatchValue(value=case_id)
)
]
)
)
print(f" 删除案件 {case_id} 的所有文档")
def get_stats(self) -> dict:
info = self.client.get_collection(self.collection_name)
coll = getattr(info, 'result', info)
total_docs = getattr(coll, "points_count", 0) or getattr(coll, "vectors_count", 0)
vector_size = 1536
config = getattr(coll, "config", None)
if config:
params = getattr(config, "params", None)
if params:
vectors_config = getattr(params, "vectors", None)
if vectors_config and hasattr(vectors_config, "size"):
vector_size = vectors_config.size
return {
"collection_name": self.collection_name,
"total_documents": int(total_docs),
"vector_size": int(vector_size),
"total_points": getattr(coll, "points_count", None),
"vectors_count": getattr(coll, "vectors_count", None),
"indexed_vectors_count": getattr(coll, "indexed_vectors_count", None)
}
# if __name__ == "__main__":
# store = QdrantVecStore()
# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
# print(f" 集合: {store.collection_name}")
# stats = store.get_stats()
# print(f" 统计: {stats}")

270
src/tools/case_tools.py Normal file

@@ -0,0 +1,270 @@
from datetime import datetime
from typing import Dict, Optional, List
import re
from qdrant_client.models import Filter, FieldCondition, MatchValue
from src.rag.store import QdrantVecStore
from src.mcp.mem import ConvoStore
class CaseTools:
@staticmethod
def get_case_metadata(case_id: str) -> Dict:
if not case_id or not case_id.strip():
return {
"error": "未提供案件ID",
"description": "请提供有效的案件ID例如: 'case_001'"
}
store = QdrantVecStore()
collections = store.client.get_collections()
col_names = [col.name for col in collections.collections]
        if not col_names:
            return {
                "error": "Qdrant 中不存在任何集合",
                "description": "请检查 Qdrant 是否已启动、文档是否已入库"
            }
        collection_name = store.collection_name
if collection_name not in col_names:
return {
"error": f"未找到集合: {collection_name}",
"description": "请确认案件文档是否已入库"
}
results = store.client.scroll(
collection_name = collection_name,
scroll_filter=Filter(
must=[
FieldCondition(
key="case_id",
match=MatchValue(value=case_id)
)
]
),
limit=100,
with_payload=True,
with_vectors=False
)
        points = results[0] if results else []
        if not points:
            return {
                "error": f"未找到案件 {case_id} 的文档",
                "case_id": case_id,
                "description": "请确认案件ID是否正确,或该案件是否已录入系统"
            }
        all_source = set()
        chunk_cnt = len(points)
        for point in points:
            payload = point.payload if hasattr(point, 'payload') else {}
            if payload and "source" in payload:
                all_source.add(payload["source"])
return {
"case_id": case_id,
"doc_cnt": len(all_source),
"chunk_cnt": chunk_cnt,
"sources": sorted(list(all_source)),
"status":"已入库",
"description": f"案件 {case_id} 共包含 {len(all_source)} 个文档,共 {chunk_cnt} 个块。"
}
@staticmethod
def build_rag_prompt(
query: str,
context: str,
session_id: Optional[str] = None,
max_history: int = 5,
system_prompt: Optional[str] = None
) -> Dict:
messages = []
sour_his = None
conv_his = []
def_sys_prompt = """你是一个专业的法律案件分析 AI 助手。
<role>
你的职责:
1. 基于检索到的案件文档片段,准确回答用户的问题
2. 结合对话历史理解上下文(如代词指代、前置问题等)
3. 提供客观、专业、有依据的法律分析
4. 明确标注信息来源,便于用户核实
</role>
<guidelines>
核心原则:
- 忠于文档内容: 仅基于提供的文档回答,不得编造或推测未在文档中体现的信息
- 明确引用来源: 回答时使用 [文档X] 标注信息来源X 为文档索引号)
- 区分确定与不确定: 如果文档信息不足以明确回答,应诚实说明
- 结合对话历史: 理解代词指代、关联前置问题、维持话题连贯性
- 专业法律表达: 使用准确的法律术语,避免模糊或口语化表述
回答格式:
1. 直接回答问题(引用文档)
2. 如需要,提供补充说明或法律解释
3. 如果信息不足,明确说明缺失的部分
</guidelines>"""
messages.append({
"role": "system",
"content": system_prompt or def_sys_prompt
})
if session_id:
store = ConvoStore()
rel_convs = store.search_conv(
session_id=session_id,
query=query,
limit=max_history
)
            if rel_convs:
                rel_convs.sort(key=lambda x: x.get('timestamp', ''))
                for conv in rel_convs:
                    conv_his.append({
                        "role": "user",
                        "content": conv["user"]
                    })
                    conv_his.append({
                        "role": "assistant",
                        "content": conv["ai"]
                    })
                sour_his = "relevant"
            else:
                rec_convs = store.get_all_conv(
                    session_id=session_id,
                    limit=max_history
                )
                if rec_convs:
                    rec_convs = rec_convs[:max_history]
                    rec_convs.reverse()  # oldest first so the history reads chronologically
                    # get_all_conv returns "user_msg"/"ai_msg" keys (search_conv returns "user"/"ai")
                    for conv in rec_convs:
                        conv_his.append({
                            "role": "user",
                            "content": conv["user_msg"]
                        })
                        conv_his.append({
                            "role": "assistant",
                            "content": conv["ai_msg"]
                        })
                    sour_his = "recent"
        if conv_his:
            messages.extend(conv_his)
user_msg = []
        if conv_his:
user_msg.append("""<instruction>
请结合上述对话历史理解当前问题。注意识别代词指代和话题延续。
</instruction>
""")
user_msg.append(f"""<context>
{context}
</context>
<question>
{query}
</question>
请根据上述文档回答问题。记得引用具体的文档编号。""")
messages.append({
"role": "user",
"content": "\n".join(user_msg)
})
his_len = len(conv_his) // 2 if conv_his else 0
return {
"messages": messages,
"prompt_stats": {
"total_messages": len(messages),
"has_context": bool(context),
"has_history": bool(conv_his),
"history_length": his_len,
"history_source": sour_his,
"session_id": session_id
},
"info": {
"query": query,
"session_id": session_id,
"history_used": his_len
}
}
@staticmethod
def gen_timeline(documents: List[Dict]) -> Dict:
events = []
date_pattern = r'(\d{4})[年/-](\d{1,2})[月/-](\d{1,2})日?'
for doc in documents:
text = doc.get("text", "")
source = doc.get("metadata", {}).get("source", "unknown")
matches = re.finditer(date_pattern, text)
for match in matches:
year, month, day = match.groups()
date_str = f"{year}-{month.zfill(2)}-{day.zfill(2)}"
start = max(0, match.start() - 30)
end = min(len(text), match.end() + 50)
context = text[start:end].strip()
event_type = "事件"
if "借款" in context or "出借" in context:
event_type = "借款发生"
elif "到期" in context:
event_type = "借款到期"
elif "起诉" in context or "立案" in context:
event_type = "提起诉讼"
elif "开庭" in context or "审理" in context:
event_type = "开庭审理"
elif "判决" in context:
event_type = "作出判决"
events.append({
"date": date_str,
"event": event_type,
"context": context,
"source": source
})
events.sort(key=lambda x: x["date"])
duration_days = 0
earliest = None
latest = None
if events:
earliest = events[0]["date"]
latest = events[-1]["date"]
            try:
                start_date = datetime.strptime(earliest, "%Y-%m-%d")
                end_date = datetime.strptime(latest, "%Y-%m-%d")
                duration_days = (end_date - start_date).days
            except ValueError:
                pass
return {
"timeline": events[:20],
"total_events": len(events),
"duration_days": duration_days,
"earliest": earliest,
"latest": latest,
"summary": f"案件时间跨度{duration_days}天,从{earliest}{latest}" if earliest and latest else "未找到时间信息"
}
tools = CaseTools()
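Putting the pieces together, a hedged end-to-end sketch (assumes documents are already ingested, Qdrant is running, and the credentials in .env are valid; the model and settings come from src.config):

# Illustrative only.
from openai import OpenAI
from src.config import settings
from src.rag.retriever import Retriever
from src.tools.case_tools import tools

question = "原告的诉讼请求是什么?"
retriever = Retriever()
results = retriever.retrieve(query=question, case_id="case_001", top_k=3)
context = retriever.format_context(results)
prompt = tools.build_rag_prompt(query=question, context=context)

client = OpenAI(api_key=settings.openai_api_key, base_url=settings.openai_base_url)
reply = client.chat.completions.create(
    model=settings.chat_model,
    messages=prompt["messages"],
)
print(reply.choices[0].message.content)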

23
start_claude.ps1 Normal file

@@ -0,0 +1,23 @@
$env:ANTHROPIC_BASE_URL = "https://api.moonshot.cn/anthropic"
# NOTE: never commit a real API key; the original line leaked one, replaced here with a placeholder.
$env:MOONSHOT_API_KEY = "<your-moonshot-api-key>"
$env:ANTHROPIC_AUTH_TOKEN = $env:MOONSHOT_API_KEY
$env:ANTHROPIC_MODEL = "kimi-k2-thinking-turbo"
$env:ANTHROPIC_DEFAULT_OPUS_MODEL = "kimi-k2-thinking-turbo"
$env:ANTHROPIC_DEFAULT_SONNET_MODEL = "kimi-k2-thinking-turbo"
$env:ANTHROPIC_DEFAULT_HAIKU_MODEL = "kimi-k2-thinking-turbo"
$env:CLAUDE_CODE_SUBAGENT_MODEL = "kimi-k2-thinking-turbo"
Write-Host " 环境变量已设置" -ForegroundColor Green
Write-Host " Base URL: $env:ANTHROPIC_BASE_URL" -ForegroundColor Cyan
Write-Host " Model: $env:ANTHROPIC_MODEL" -ForegroundColor Cyan
$claudePath = "C:\Users\ADMIN\AppData\Local\AnthropicClaude\claude.exe"
if (Test-Path $claudePath) {
Write-Host " 正在启动 Claude Desktop..." -ForegroundColor Yellow
Start-Process $claudePath
} else {
Write-Host " 找不到 Claude.exe请检查安装路径" -ForegroundColor Red
}

2
tests/__init__.py Normal file

@@ -0,0 +1,2 @@
"""tests package"""

202
tests/test_mcp.py Normal file

@@ -0,0 +1,202 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import json
from src.mcp.mem import ConvoStore
from src.tools.case_tools import CaseTools
def test_conv_store():
store = ConvoStore()
test_session = "CASE_SESSION_001"
store.add_conv(
session_id=test_session,
user_msg="这个案件的原告是谁?",
ai_msg="根据案件文档,原告是张三。",
metadata={"case_id": "CASE_001"}
)
store.add_conv(
session_id=test_session,
user_msg="借款金额是多少?",
ai_msg="借款本金为50万元约定年利率10%",
metadata={"case_id": "CASE_001"}
)
store.add_conv(
session_id=test_session,
user_msg="被告的答辩意见是什么?",
ai_msg="被告称该笔款项为合伙投资款,非借款,并称借条系伪造。",
metadata={"case_id": "CASE_001"}
)
print(f" 对话存储成功")
print(f" 会话ID: {test_session}")
print(f" 对话轮数: 3")
return True
def test_conv_search():
store = ConvoStore()
test_session = "CASE_SESSION_001"
queries = [
"原告信息",
"借了多少钱",
"被告怎么说"
]
for query in queries:
results = store.search_conv(
session_id=test_session,
query=query,
limit=2
)
print(f"\n 查询: {query}")
print(f" 结果数: {len(results)}")
if results:
print(f" 最高相关度: {results[0]['score']:.3f}")
print(f" 匹配对话: Q: {results[0]['user'][:30]}...")
assert len(results) > 0
print(f"\n 检索成功")
return True
def test_conv_all():
store = ConvoStore()
test_session = "CASE_SESSION_001"
all_convs = store.get_all_conv(
session_id=test_session,
limit=100
)
print(f" 会话获取成功")
print(f" 会话ID: {test_session}")
print(f" 对话总数: {len(all_convs)}")
if all_convs:
print(f"\n 最近对话:")
for i, conv in enumerate(all_convs[:2], 1):
print(f" {i}. Q: {conv['user_msg'][:40]}...")
print(f" A: {conv['ai_msg'][:40]}...")
assert len(all_convs) >= 3
return True
def test_case_tools_metadata():
tools = CaseTools()
result = tools.get_case_metadata("CASE_001")
print(f" 查询完成")
print(f" 案件ID: CASE_001")
if "error" in result:
print(f" 状态: {result['error']}")
print(f" 说明: {result.get('description', '')}")
else:
print(f" 文档数: {result.get('doc_cnt', 0)}")
print(f" 片段数: {result.get('chunk_cnt', 0)}")
print(f" 状态: {result.get('status', '')}")
assert isinstance(result, dict)
return True
def test_build_rag_prompt():
tools = CaseTools()
mock_context = """<documents>
<document index="1" source="complaint.txt">
原告: 张三
被告: 李四
借款金额: 50万元
</document>
<document index="2" source="defense.txt">
被告答辩: 该笔款项为合伙投资款
</document>
</documents>"""
result = tools.build_rag_prompt(
query="原告和被告分别是谁?",
context=mock_context,
session_id="CASE_SESSION_001",
max_history=2
)
print(f" prompt构建成功")
print(f" 消息总数: {result['prompt_stats']['total_messages']}")
print(f" 包含上下文: {result['prompt_stats']['has_context']}")
print(f" 包含历史: {result['prompt_stats']['has_history']}")
print(f" 历史来源: {result['prompt_stats']['history_source']}")
print(f" 历史轮数: {result['prompt_stats']['history_length']}")
assert len(result['messages']) > 0
assert result['prompt_stats']['has_context']
return True
def test_cleanup():
store = ConvoStore()
test_session = "CASE_SESSION_001"
all_convs = store.get_all_conv(session_id=test_session, limit=1000)
print(f" 清理前对话数: {len(all_convs)}")
count = store.delete_by_session(test_session)
print(f" 清理完成")
print(f" 测试会话: {test_session}")
print(f" 删除对话数: {count}")
remaining = store.get_all_conv(session_id=test_session, limit=1000)
print(f" 清理后对话数: {len(remaining)}")
assert len(remaining) == 0
return True
def run_all_tests():
print("\n" + "="*60)
print("MCP 测试")
print("="*60)
tests = [
("对话存储", test_conv_store),
("对话检索", test_conv_search),
("获取全部对话", test_conv_all),
("获取案件元数据", test_case_tools_metadata),
("构建 RAG Prompt", test_build_rag_prompt),
("清理测试数据", test_cleanup)
]
passed = 0
failed = 0
    for name, test_func in tests:
        # Catch assertion errors so one failing test does not abort the whole run.
        try:
            if test_func():
                passed += 1
            else:
                failed += 1
                print(f"\n 测试失败: {name}")
        except Exception as e:
            failed += 1
            print(f"\n 测试失败: {name} ({e})")
print("\n" + "="*60)
print(f"结果: {passed} 通过, {failed} 失败")
print("="*60)
return failed == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)

182
tests/test_rag.py Normal file

@@ -0,0 +1,182 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.rag.loader import DocLoader
from src.rag.retriever import Retriever
from src.rag.store import QdrantVecStore
from src.rag.embeddings import embedding_client
def test_connect():
store = QdrantVecStore()
stats = store.get_stats()
print(f" Qdrant连接成功")
print(f" 集合: {stats['collection_name']}")
print(f" 文档数: {stats['total_documents']}")
print(f" 向量维度: {stats['vector_size']}")
assert stats['collection_name'] is not None
assert stats['total_documents'] >= 0
return True
def test_embed():
texts = ["测试文本1", "测试文本2"]
vectors = embedding_client.embed(texts)
print(f" txt1,txt2 embed成功")
print(f" 输入文本数: {len(texts)}")
print(f" 输出向量数: {len(vectors)}")
print(f" 向量维度: {len(vectors[0])}")
assert len(vectors) == len(texts)
assert len(vectors[0]) == 1536
return True
def test_doc_load():
loader = DocLoader(chunk_size=200, chunk_overlap=50)
test_text = """
案件编号: CASE_001
案件名称: 测试案件 - 借款合同纠纷
原告: 张三
被告: 李四
案情简介:
原告张三于2023年1月向被告李四出借人民币50万元约定年利率10%
借款期限为1年。到期后被告未按约定归还本金及利息。
原告诉讼请求:
1. 判令被告归还借款本金50万元
2. 支付利息5万元
3. 承担本案诉讼费用
被告答辩:
该笔款项为合伙投资款,非借款。原告提供的借条系伪造。
"""
ids = loader.loader_text(
text=test_text,
case_id="CASE_001",
source="test_complaint.txt"
)
print(f" doc入库成功")
print(f" 案件ID: CASE_001")
print(f" 文档片段数: {len(ids)}")
print(f" 片段ID示例: {ids[0][:8]}...")
assert len(ids) > 0
return True
def test_doc_search():
retriever = Retriever()
queries = [
"原告是谁?",
"借款金额是多少?",
"被告的答辩意见是什么?"
]
for query in queries:
results = retriever.retrieve(
query=query,
case_id="CASE_001",
top_k=2
)
print(f"\n 查询: {query}")
print(f" 结果数: {len(results)}")
if results:
print(f" 最高相似度: {results[0]['score']:.3f}")
print(f" 片段预览: {results[0]['text'][:50]}...")
assert len(results) > 0
assert results[0]['score'] > 0
print(f"\n 检索成功")
return True
def test_text_format():
retriever = Retriever()
results = retriever.retrieve(
query="案件的原被告信息",
case_id="CASE_001",
top_k=3
)
formatted = retriever.format_context(results)
print(f" 格式化成功")
print(f" 原始结果数: {len(results)}")
print(f" 格式化长度: {len(formatted)} 字符")
print(f"\n 格式化预览:")
print(f" {formatted[:200]}...")
assert "<documents>" in formatted
assert "</documents>" in formatted
assert len(formatted) > 0
return True
def test_cleanup():
store = QdrantVecStore()
count = store.delete_by_case("CASE_001")
print(f" 清理完成")
if count is not None:
print(f" 删除文档数: {count}")
assert count >= 0
else:
print(f" 删除操作已执行(未返回计数)")
stats = store.get_stats()
print(f" 剩余文档数: {stats['total_documents']}")
return True
def run_all_tests():
print("\n" + "="*60)
print("RAG测试")
print("="*60)
tests = [
("Qdrant 连接", test_connect),
("嵌入客户端", test_embed),
("文档入库", test_doc_load),
("文档检索", test_doc_search),
("上下文清除", test_text_format),
("清理数据", test_cleanup)
]
passed = 0
failed = 0
    for name, test_func in tests:
        # Catch assertion errors so one failing test does not abort the whole run.
        try:
            if test_func():
                passed += 1
            else:
                failed += 1
                print(f"\n 测试失败: {name}")
        except Exception as e:
            failed += 1
            print(f"\n 测试失败: {name} ({e})")
print("\n" + "="*60)
print(f"结果: {passed} 通过, {failed} 失败")
print("="*60)
return failed == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)