commit 0bb8b4e13b364d98d5eb6795ecbddd0acfab1278 Author: jianghaiying Date: Fri Nov 28 15:06:54 2025 +0800 aiqs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6c668a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +venv/ +.venv/ +__pycache__/ +*.pyc +.env +.env.local +*.log +dist/ +build/ +*.egg-info/ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info +data/ +.vscode/ +.claude/ + +# Virtual environments +.venv/ +aiqs/ + diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..24ee5b1 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py new file mode 100644 index 0000000..b8af630 --- /dev/null +++ b/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from aiqs!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8820b7f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "aiqs" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "openai==2.7.1", + "python-dotenv>=1.2.1", + "pydantic>=2.12.4", + "pydantic-settings>=2.11.0", +] diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..e5abe56 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,9 @@ +openai>=1.0.0 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 +httpx>=0.24.0 + +qdrant-client>=1.7.0 +tiktoken>=0.5.0 + +pillow>=10.0.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..f7173b1 --- /dev/null +++ b/src/config.py @@ -0,0 +1,33 @@ +from pathlib import Path +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +PROJECT_ROOT = Path(__file__).parent.parent + + +class Settings(BaseSettings): + openai_api_key: str = Field(default="", description="key") + openai_base_url: str = Field( + default="https://api.openai.com/v1", + description="url", + ) + chat_model: str = Field( + default="gpt-4o-mini" + ) + embedding_model: str = Field( + default="text-embedding-3-small" + ) + qdrant_host: str = Field(default="localhost") + qdrant_port: int = Field(default=6333) + qdrant_collection: str = Field( + default="legal_documents" + ) + model_config = SettingsConfigDict( + env_file=PROJECT_ROOT / ".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + +settings = Settings() diff --git a/src/mcp/case_server.py b/src/mcp/case_server.py new file mode 100644 index 0000000..3224fd1 --- /dev/null +++ b/src/mcp/case_server.py @@ -0,0 +1,334 @@ +import asyncio +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp import types +import json + +from src.rag.retriever import Retriever +from src.tools.case_tools import tools +from src.mcp.mem import ConvoStore + +app = Server("legal-case-server") +retriever = Retriever() +conv = ConvoStore() + +@app.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="search_case_documents", + description="在法律案件文档中检索相关信息。用于回答关于案件事实、当事人信息、诉讼请求、判决结果等问题。", + inputSchema={ + "type": "object", + "properties": { + "case_id": { + "type": "string", + "description": "案件ID,如 'case_001'" + }, + "query": { + "type": "string", + "description": "检索问题,如 '原告的诉讼请求是什么?'" + }, + "top_k": { + "type": "integer", + "description": "返回结果数量,默认3", + "default": 3 + } + }, + "required": ["case_id", "query"] + } + ), + types.Tool( + name="get_case_metadata", + description="获取案件的基本信息", + inputSchema={ + "type": "object", + "properties": { + "case_id": { + "type": "string", + "description": "案件ID" + } + }, + "required": ["case_id"] + } + ), + types.Tool( + name="build_rag_prompt", + description="构建标准化的 RAG Prompt,整合检索上下文和对话历史。返回可直接用于 LLM 的 messages 数组。", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "用户问题" + }, + "context": { + "type": "string", + "description": "检索到的文档上下文(XML 格式)" + }, + "conversation_history": { + "type": "array", + "description": "对话历史(可选)", + "items": { + "type": "object", + "properties": { + "role": {"type": "string"}, + "content": {"type": "string"} + } + } + }, + "system_prompt": { + "type": "string", + "description": "自定义系统提示词(可选)" + } + }, + "required": ["query", "context"] + } + ), + types.Tool( + name="gen_timeline", + description="从文档中生成案件时间线,自动提取并排序时间信息", + inputSchema={ + "type": "object", + "properties": { + "documents": { + "type": "array", + "description": "检索结果文档列表", + "items": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "metadata": {"type": "object"} + } + } + } + }, + "required": ["documents"] + } + ), + types.Tool( + name="store_conservation", + description="存储对话到会话记录中", + inputSchema={ + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "会话ID"}, + "user_message": {"type": "string", "description": "用户消息"}, + "ai_message": {"type": "string", "description": "AI回复"}, + "metadata": {"type": "object", "description": "额外元数据"} + }, + "required": ["session_id", "user_message", "ai_message"] + } + ), + types.Tool( + name="search_conversation", + description="检索会话中的相关对话记录", + inputSchema={ + "type": "object", + "properties": { + "session_id": {"type": "string"}, + "query": {"type": "string", "description": "搜索查询"}, + "limit": {"type": "integer","default": 5} + }, + "required": ["session_id", "query"] + } + ), + types.Tool( + name="get_all_conversation", + description="获取所有会话记录", + inputSchema={ + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "会话ID"}, + "limit": {"type": "integer","default": 100} + }, + "required": ["session_id"] + } + ) + ] + +@app.list_resources() +async def list_resources() -> list[types.Resource]: + return [] + +@app.list_prompts() +async def list_prompts() -> list[types.Prompt]: + return [ + types.Prompt( + name="analyze_case", + description="分析法律案件的标准流程", + arguments=[ + types.PromptArgument( + name="case_id", + description="案件ID", + required=True + ) + ] + ) + ] + + +@app.call_tool() +async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: + if name == "search_case_documents": + case_id = arguments["case_id"] + query = arguments["query"] + top_k = arguments.get("top_k", 3) + + results = retriever.retrieve( + query=query, + case_id=case_id, + top_k=top_k + ) + if not results: + response = { + "status": "no_results", + "message": f"在案件 {case_id} 中未找到与 '{query}' 相关的文档" + } + else: + context = retriever.format_context(results) + + response = { + "status": "success", + "case_id": case_id, + "query": query, + "results_count": len(results), + "context": context, + "documents": [ + { + "index": i + 1, + "text": r["text"][:300], + "score": round(r["score"], 3), + "source": r["metadata"].get("source", "unknown") + } + for i, r in enumerate(results) + ] + } + + return [types.TextContent( + type="text", + text=json.dumps(response, ensure_ascii=False, indent=2) + )] + + elif name == "get_case_metadata": + result = tools.get_case_metadata(**arguments) + return [types.TextContent( + type="text", + text=json.dumps(result, ensure_ascii=False, indent=2) + )] + + elif name == "build_rag_prompt": + result = tools.build_rag_prompt(**arguments) + return [types.TextContent( + type="text", + text=json.dumps(result, ensure_ascii=False, indent=2) + )] + elif name == "gen_timeline": + result = tools.gen_timeline(**arguments) + return [types.TextContent( + type="text", + text=json.dumps(result, ensure_ascii=False, indent=2) + )] + elif name == "store_conservation": + session_id = arguments.get("session_id") + user_msg = arguments.get("user_msg") + ai_msg = arguments.get("ai_msg") + metadata = arguments.get("metadata", {}) + + conv.add_conv( + session_id=session_id, + user_msg=user_msg, + ai_msg=ai_msg, + metadata=metadata + ) + + return [types.TextContent( + type="text", + text=f"对话以存储到会话{session_id}中。" + )] + + elif name == "search_conversation": + session_id = arguments.get("session_id") + query = arguments.get("query") + limit = arguments.get("limit", 5) + + results = conv.search_conv( + session_id=session_id, + query=query, + limit=limit + ) + + formatted = "\n\n".join([ + f"**Q**: {r['user']}\n**A**: {r['assistant']}\n(相关度: {r['score']:.2f})" + for r in results + ]) + + return [types.TextContent( + type="text", + text=formatted or "未找到相关对话记录。" + )] + + elif name == "get_all_conversation": + session_id = arguments.get("session_id") + limit = arguments.get("limit", 100) + + results = conv.get_all_conv( + session_id=session_id, + limit=limit + ) + + formatted = "\n\n".join([ + f"**Q**: {r['user']}\n**A**: {r['assistant']}" + for r in results + ]) + + return [types.TextContent( + type="text", + text=formatted or "该会话暂无对话记录。" + )] + + + raise ValueError(f"未知工具: {name}") + +@app.get_prompt() +async def get_prompt(name: str, arguments: dict) -> types.GetPromptResult: + if name == "analyze_case": + case_id = arguments["case_id"] + + return types.GetPromptResult( + description=f"分析案件 {case_id}", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent( + type="text", + text=f"请分析案件 {case_id},包括:1. 案情摘要 2. 争议焦点 3. 法律依据" + ) + ) + ] + ) + +@app.read_resource() +async def read_resource(uri: str) -> str: + mapping = { + "case://case_001/summary": Path("d:/workspace/aiqs/data/case_001_summary.txt") + } + path = mapping.get(uri) + if path and path.exists(): + return path.read_text(encoding="utf-8") + return f"未列出的资源或不存在: {uri}" + +async def main(): + async with stdio_server() as (read_stream, write_stream): + await app.run( + read_stream, + write_stream, + app.create_initialization_options() + ) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/mcp/mem.py b/src/mcp/mem.py new file mode 100644 index 0000000..521e1d2 --- /dev/null +++ b/src/mcp/mem.py @@ -0,0 +1,239 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from qdrant_client import QdrantClient +from qdrant_client.models import Distance, VectorParams, PointStruct +from typing import List, Dict, Optional +import uuid +from datetime import datetime +from src.rag.embeddings import embedding_client + + +class ConvoStore: + + def __init__( + self, + collection_name: str = "conversations", + qdrant_url: str = "http://localhost:6333" + ): + self.client = QdrantClient(url=qdrant_url) + self.collection_name = collection_name + + if not collection_name in [c.name for c in self.client.get_collections().collections]: + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=1536, + distance=Distance.COSINE + ) + ) + else: + self.client.get_collection(collection_name) + + def add_conv( + self, + session_id: str, + user_msg: str, + ai_msg: str, + metadata: Optional[Dict] = None + ): + if not session_id or not user_msg or not ai_msg: + return None + + # 用户问+ai答 为一条记录(还是分开?) + conv_text = f"用户: {user_msg}\nAI: {ai_msg}" + vector = embedding_client.embed([conv_text])[0] + + payload = { + "session_id": session_id, + "user_msg": user_msg, + "ai_msg": ai_msg, + "timestamp": datetime.now().isoformat() + } + if metadata: + payload.update(metadata) + + point_id = str(uuid.uuid4()) + self.client.upsert( + collection_name=self.collection_name, + points=[ + PointStruct( + id=point_id, + vector=vector, + payload=payload + ) + ] + ) + return point_id + + def search_conv( + self, + session_id: str, + query: str, + limit: int = 5 + ) -> List[Dict]: + if not session_id or not query: + return [] + + query_vector = embedding_client.embed([query])[0] + + results = self.client.query_points( + collection_name=self.collection_name, + query=query_vector, + query_filter={ + "must": [ + {"key": "session_id", "match": {"value": session_id}} + ] + }, + limit=limit, + with_payload=True + ) + points = results.points if hasattr(results, "points") else [] + if not points: + return [] + convs = [] + for point in points: + payload = point.payload if hasattr(point, "payload") else {} + if not payload: + continue + convs.append({ + "user": payload.get("user_msg",""), + "ai": payload.get("ai_msg",""), + "timestamp": payload.get("timestamp",""), + "score": point.score if hasattr(point, "score") else 0.0 + }) + + return convs + + def get_all_conv( + self, + session_id: str, + limit: int = 100 + ) -> List[Dict]: + if not session_id: + return [] + + results = self.client.scroll( + collection_name=self.collection_name, + scroll_filter={ + "must": [ + {"key": "session_id", + "match": {"value": session_id}} + ] + }, + limit=limit, + with_payload=True, + with_vectors=False + ) + points = results[0] if results and len(results) > 0 else [] + if not points: + return [] + + convs = [] + for point in points: + payload = point.payload if hasattr(point, "payload") else {} + if not payload: + continue + convs.append({ + "user_msg": payload.get("user_msg",""), + "ai_msg": payload.get("ai_msg",""), + "timestamp": payload.get("timestamp","") + }) + convs.sort(key=lambda x: x.get("timestamp", ""), reverse=True) + return convs + + def delete_by_session(self, session_id: str) -> int: + if not session_id: + return 0 + + results = self.client.scroll( + collection_name=self.collection_name, + scroll_filter={ + "must": [ + {"key": "session_id", + "match": {"value": session_id}} + ] + }, + limit=10000, + with_payload=False, + with_vectors=False + ) + + points = results[0] if results and len(results) > 0 else [] + + if not points: + return 0 + point_ids = [point.id for point in points] + + self.client.delete( + collection_name=self.collection_name, + points_selector=point_ids + ) + + return len(point_ids) + +# # def qdrant_example(): +# if __name__ == "__main__": +# store = ConvoStore() + +# store.add_conv( +# session_id="case_001", +# user_msg="原告是谁?", +# ai_msg="原告是张三", +# metadata={"case_id": "case_001"} +# ) + +# store.add_conv( +# session_id="case_001", +# user_msg="借款金额是多少?", +# ai_msg="借款金额是50万元", +# metadata={"case_id": "case_001"} +# ) + +# store.add_conv( +# session_id="case_002", +# user_msg="原告是谁?", +# ai_msg="原告是李四", +# metadata={"case_id": "case_002"} +# ) + +# store.add_conv( +# session_id="case_002", +# user_msg="借款金额是多少?", +# ai_msg="借款金额是60万元", +# metadata={"case_id": "case_002"} +# ) + +# store.add_conv( +# session_id="case_002", +# user_msg="李四借款金额是多少?", +# ai_msg="李四借款金额是60万元", +# metadata={"case_id": "case_002"} +# ) + +# store.add_conv( +# session_id="case_002", +# user_msg="被告是谁?", +# ai_msg="李四", +# metadata={"case_id": "case_002"} +# ) + +# store.add_conv( +# session_id="case_002", +# user_msg="被告要还多少钱?", +# ai_msg="李四还60万", +# metadata={"case_id": "case_002"} +# ) + +# relevant = store.search_conv( +# session_id="case_002", +# query="被告需要还多少钱?", +# limit=3 +# ) + +# print("相关历史对话:") +# for conv in relevant: +# print(f" Q: {conv['user']}") +# print(f" A: {conv['ai']}") +# print(f" 相关度: {conv['score']:.3f}\n") \ No newline at end of file diff --git a/src/rag/__init__.py b/src/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rag/chunker.py b/src/rag/chunker.py new file mode 100644 index 0000000..7803ab5 --- /dev/null +++ b/src/rag/chunker.py @@ -0,0 +1,27 @@ +from langchain_text_splitters import RecursiveCharacterTextSplitter +from typing import List, Dict + + +class TextChunker: + def __init__( + self, + chunk_size: int = 500, + chunk_overlap: int = 50 + ): + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separators=["\n\n", "\n", "。", "?", "!", ".", ",", "!", ";", "?", ";", " ", ""] + ) + + def chunk(self, text: str) -> List[Dict]: + chunks = self.splitter.split_text(text) + + return [ + { + "text": chunk, + "chunk_id": i, + "total_chunks": len(chunks) + } + for i, chunk in enumerate(chunks) + ] \ No newline at end of file diff --git a/src/rag/embeddings.py b/src/rag/embeddings.py new file mode 100644 index 0000000..45e4433 --- /dev/null +++ b/src/rag/embeddings.py @@ -0,0 +1,38 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from openai import OpenAI +import httpx +from src.config import settings +from typing import List + + +class EmbeddingClient: + def __init__(self): + http_client = httpx.Client( + proxy="http://127.0.0.1:10808", + verify=False + ) + + self.client = OpenAI( + api_key=settings.openai_api_key, + base_url=settings.openai_base_url, + http_client=http_client + ) + + def embed( + self, + texts: List[str], + model: str = None + ) -> List[List[float]]: + + response = self.client.embeddings.create( + model=model or settings.embedding_model, + input=texts + ) + + return [item.embedding for item in response.data] + + +embedding_client = EmbeddingClient() diff --git a/src/rag/loader.py b/src/rag/loader.py new file mode 100644 index 0000000..704de63 --- /dev/null +++ b/src/rag/loader.py @@ -0,0 +1,93 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from typing import List, Dict +from src.rag.store import QdrantVecStore +from src.rag.chunker import TextChunker +from src.rag.embeddings import embedding_client + + +class DocLoader: + + def __init__( + self, + chunk_size: int = 500, + chunk_overlap: int = 50 + ): + self.store = QdrantVecStore() + self.chunker = TextChunker( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + + def loader_text( + self, + text: str, + case_id: str, + source: str = "unknown" + ) -> List[str]: + + chunks = self.chunker.chunk(text) + + texts = [chunk["text"] for chunk in chunks] + vectors = embedding_client.embed(texts) + + metadatas = [ + { + "case_id": case_id, + "source": source, + "chunk_id": chunk["chunk_id"], + "total_chunks": chunk["total_chunks"] + } + for chunk in chunks + ] + + ids = self.store.add_documents( + texts=texts, + vectors=vectors, + metadatas=metadatas + ) + + print(f" 案件 {case_id} 入库完成: {len(ids)} 个块") + return ids + + def loader_file( + self, + file_path: str, + case_id: str + ) -> List[str]: + + path = Path(file_path) + + with open(path, "r", encoding="utf-8") as f: + text = f.read() + + return self.loader_text( + text=text, + case_id=case_id, + source=path.name + ) + + +# if __name__ == "__main__": +# loader = DocLoader() + +# test_text = """ +# 案件编号: case_001 +# 原告: 张三 +# 被告: 李四 + +# 原告诉讼请求: +# 1. 判令被告归还借款本金50万元 +# 2. 支付利息3万元 +# 3. 支付违约金2万元 +# """ + +# loader.loader_text( +# text=test_text, +# case_id="case_001", +# source="complaint.txt" +# ) + +# print("已入库") diff --git a/src/rag/retriever.py b/src/rag/retriever.py new file mode 100644 index 0000000..2ffe75a --- /dev/null +++ b/src/rag/retriever.py @@ -0,0 +1,88 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from typing import List, Dict, Optional +from src.rag.store import QdrantVecStore +from src.rag.embeddings import embedding_client + + +class Retriever: + def __init__(self): + self.store = QdrantVecStore() + + def retrieve( + self, + query: str, + case_id: Optional[str] = None, + top_k: int = 3 + ) -> List[Dict]: + query_vector = embedding_client.embed([query])[0] + + results = self.store.search( + query_vector=query_vector, + case_id=case_id, + top_k=top_k + ) + + return results + + def format_context(self, results: List[Dict]) -> str: + + if not results: + return "\n没有找到相关文档。\n" + + context_parts = [""] + + + max_score = max(r["score"] for r in results) + context_parts.append( + f"共检索到 {len(results)} 个相关文档片段,按相关度从高到低排序\n" + ) + + for i, result in enumerate(results, 1): + source = result["metadata"].get("source", "未知来源") + text = result["text"] + score = result["score"] + + if score >= 0.85: + relevance = "高度相关" + elif score >= 0.70: + relevance = "相关" + else: + relevance = "可能相关" + + context_parts.append( + f'\n' + f'{text}\n' + f'' + ) + + context_parts.append("") + + return "\n\n".join(context_parts) + + +# if __name__ == "__main__": +# retriever = Retriever() + +# question = "被告李四的主要辩称理由是什么?" + +# print(f" 检索问题: {question}\n") + +# results = retriever.retrieve( +# query=question, +# case_id="case001", +# top_k=3 +# ) + +# print(f" 检索到 {len(results)} 个相关片段:\n") + +# for i, result in enumerate(results, 1): +# print(f"--- 结果 {i} (score: {result['score']:.3f}) ---") +# print(f"内容: {result['text'][:100]}...") +# print(f"来源: {result['metadata']['source']}") +# print() + +# context = retriever.format_context(results) +# print(f" 格式化上下文:\n{context}") \ No newline at end of file diff --git a/src/rag/store.py b/src/rag/store.py new file mode 100644 index 0000000..53c87ea --- /dev/null +++ b/src/rag/store.py @@ -0,0 +1,160 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, + VectorParams, + PointStruct, + Filter, + FieldCondition, + MatchValue +) +from src.config import settings + + +class QdrantVecStore: + def __init__(self, collection_name: str = None): + self.client = QdrantClient( + host=settings.qdrant_host, + port=settings.qdrant_port + ) + self.collection_name = collection_name or settings.qdrant_collection + self._ensure_collection() + + def _ensure_collection(self): + collections = self.client.get_collections().collections + exists = any(c.name == self.collection_name for c in collections) + + if not exists: + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams( + size=1536, + distance=Distance.COSINE + ) + ) + print(f" collection_name: {self.collection_name}") + + def add_documents( + self, + texts: list[str], + vectors: list[list[float]], + metadatas: list[dict] + ) -> list[str]: + import uuid + + points = [] + ids = [] + + for txt, vec, meta in zip(texts, vectors, metadatas): + doc_id = str(uuid.uuid4()) + ids.append(doc_id) + + points.append( + PointStruct( + id=doc_id, + vector=vec, + payload={ + "text": txt, + "case_id": meta.get("case_id"), + "source": meta.get("source", "unknown"), + "page": meta.get("page"), + "chunk_id": meta.get("chunk_id"), + **meta # 其他数据 + } + ) + ) + + self.client.upsert( + collection_name=self.collection_name, + points=points + ) + + print(f" 添加 {len(points)} 个文档片") + return ids + + def search( + self, + query_vector: list[float], + case_id: str = None, + top_k: int = 3 + ) -> list[dict]: + query_filter = None + if case_id: + query_filter = Filter( + must=[ + FieldCondition( + key="case_id", + match=MatchValue(value=case_id) + ) + ] + ) + + results = self.client.query_points( + collection_name=self.collection_name, + query=query_vector, + query_filter=query_filter, + limit=top_k, + with_payload=True + ) + points = results.points if hasattr(results, 'points') else results + return [ + { + "text": r.payload["text"], + "score": r.score, + "metadata": { + k: v for k, v in r.payload.items() + if k != "text" + } + } + for r in points + ] + + def delete_by_case(self, case_id: str): + self.client.delete( + collection_name=self.collection_name, + points_selector=Filter( + must=[ + FieldCondition( + key="case_id", + match=MatchValue(value=case_id) + ) + ] + ) + ) + print(f" 删除案件 {case_id} 的所有文档") + + def get_stats(self) -> dict: + info = self.client.get_collection(self.collection_name) + coll = getattr(info, 'result', info) + + total_docs = getattr(coll, "points_count", 0) or getattr(coll, "vectors_count", 0) + + vector_size = 1536 + config = getattr(coll, "config", None) + if config: + params = getattr(config, "params", None) + if params: + vectors_config = getattr(params, "vectors", None) + if vectors_config and hasattr(vectors_config, "size"): + vector_size = vectors_config.size + + return { + "collection_name": self.collection_name, + "total_documents": int(total_docs), + "vector_size": int(vector_size), + "total_points": getattr(coll, "points_count", None), + "vectors_count": getattr(coll, "vectors_count", None), + "indexed_vectors_count": getattr(coll, "indexed_vectors_count", None) + } + + +# if __name__ == "__main__": +# store = QdrantVecStore() +# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}") +# print(f" 集合: {store.collection_name}") + +# stats = store.get_stats() +# print(f" 统计: {stats}") \ No newline at end of file diff --git a/src/tools/case_tools.py b/src/tools/case_tools.py new file mode 100644 index 0000000..1319a12 --- /dev/null +++ b/src/tools/case_tools.py @@ -0,0 +1,270 @@ +from datetime import datetime, timedelta +from typing import Dict, Optional, List +import re +from qdrant_client.models import Filter, FieldCondition, MatchValue +from src.rag.store import QdrantVecStore +from src.mcp.mem import ConvoStore + + +class CaseTools: + @staticmethod + def get_case_metadata(case_id: str) -> Dict: + + if not case_id or not case_id.strip(): + return { + "error": "未提供案件ID", + "description": "请提供有效的案件ID,例如: 'case_001'" + } + + store = QdrantVecStore() + + collections = store.client.get_collections() + col_names = [col.name for col in collections.collections] + if not col_names: + return { + "error": f"Qdrant链接失败", + "description": "请检查Qdrant是否启动" + } + collection_name = getattr(store, "collection_name", None) or getattr(store, "collection_name", None) + if collection_name not in col_names: + return { + "error": f"未找到集合: {collection_name}", + "description": "请确认案件文档是否已入库" + } + + results = store.client.scroll( + collection_name = collection_name, + scroll_filter=Filter( + must=[ + FieldCondition( + key="case_id", + match=MatchValue(value=case_id) + ) + ] + ), + limit=100, + with_payload=True, + with_vectors=False + ) + if not results[0]: + return { + "error": f"未找到案件 {case_id} 的文档", + "description": "请确认案件ID是否正确,或该案件是否已录入系统" + } + + target = results[0] if results else [] + if not target: + return { + "error": f"未找到案件 {case_id} 的文档", + "case_id": case_id, + "description": "该案件不存在或尚未录入系统" + } + + all_source = set() + chunk_cnt = len(target) + for point in target: + payload = point.payload if hasattr(point, 'payload') else {} + if payload and "source" in payload: + all_source.add(payload["source"]) + + return { + "case_id": case_id, + "doc_cnt": len(all_source), + "chunk_cnt": chunk_cnt, + "sources": sorted(list(all_source)), + "status":"已入库", + "description": f"案件 {case_id} 共包含 {len(all_source)} 个文档,共 {chunk_cnt} 个块。" + } + + + @staticmethod + def build_rag_prompt( + query: str, + context: str, + session_id: Optional[str] = None, + max_history: int = 5, + system_prompt: Optional[str] = None + ) -> Dict: + + messages = [] + sour_his = None + conv_his = [] + + + def_sys_prompt = """你是一个专业的法律案件分析 AI 助手。 + + +你的职责: +1. 基于检索到的案件文档片段,准确回答用户的问题 +2. 结合对话历史理解上下文(如代词指代、前置问题等) +3. 提供客观、专业、有依据的法律分析 +4. 明确标注信息来源,便于用户核实 + + + +核心原则: +- 忠于文档内容: 仅基于提供的文档回答,不得编造或推测未在文档中体现的信息 +- 明确引用来源: 回答时使用 [文档X] 标注信息来源(X 为文档索引号) +- 区分确定与不确定: 如果文档信息不足以明确回答,应诚实说明 +- 结合对话历史: 理解代词指代、关联前置问题、维持话题连贯性 +- 专业法律表达: 使用准确的法律术语,避免模糊或口语化表述 + +回答格式: +1. 直接回答问题(引用文档) +2. 如需要,提供补充说明或法律解释 +3. 如果信息不足,明确说明缺失的部分 +""" + + messages.append({ + "role": "system", + "content": system_prompt or def_sys_prompt + }) + + if session_id: + store = ConvoStore() + rel_convs = store.search_conv( + session_id=session_id, + query=query, + limit=max_history + ) + if rel_convs and len(rel_convs) > 0: + rel_convs.sort(key=lambda x: x.get('timestep','')) + for conv in rel_convs: + conv_his.append({ + "role": "user", + "content": conv["user"] + }) + conv_his.append({ + "role": "ai", + "content": conv["ai"] + }) + sour_his = "relevant" + + else: + rec_convs = store.get_all_conv( + session_id=session_id, + limit=max_history + ) + if rec_convs and len(rec_convs) > 0: + rec_convs = rec_convs[:max_history] + rec_convs.reverse() + for conv in rec_convs: + conv_his.append({ + "role": "user", + "content": conv["user"] + }) + conv_his.append({ + "role": "ai", + "content": conv["ai"] + }) + sour_his = "recent" + + if conv_his and len(conv_his) > 0: + messages.extend(conv_his) + + user_msg = [] + if conv_his and len(conv_his) > 0: + user_msg.append(""" +请结合上述对话历史理解当前问题。注意识别代词指代和话题延续。 + +""") + + user_msg.append(f""" +{context} + + + +{query} + + +请根据上述文档回答问题。记得引用具体的文档编号。""") + messages.append({ + "role": "user", + "content": "\n".join(user_msg) + }) + + his_len = len(conv_his) // 2 if conv_his else 0 + return { + "messages": messages, + "prompt_stats": { + "total_messages": len(messages), + "has_context": bool(context), + "has_history": bool(conv_his), + "history_length": his_len, + "history_source": sour_his, + "session_id": session_id + }, + "info": { + "query": query, + "session_id": session_id, + "history_used": his_len + } + } + + @staticmethod + def gen_timeline(documents: List[Dict]) -> Dict: + events = [] + all_text = "\n".join([doc.get("text", "") for doc in documents]) + + date_pattern = r'(\d{4})[年/-](\d{1,2})[月/-](\d{1,2})日?' + + for doc in documents: + text = doc.get("text", "") + source = doc.get("metadata", {}).get("source", "unknown") + + matches = re.finditer(date_pattern, text) + for match in matches: + year, month, day = match.groups() + date_str = f"{year}-{month.zfill(2)}-{day.zfill(2)}" + + start = max(0, match.start() - 30) + end = min(len(text), match.end() + 50) + context = text[start:end].strip() + + event_type = "事件" + if "借款" in context or "出借" in context: + event_type = "借款发生" + elif "到期" in context: + event_type = "借款到期" + elif "起诉" in context or "立案" in context: + event_type = "提起诉讼" + elif "开庭" in context or "审理" in context: + event_type = "开庭审理" + elif "判决" in context: + event_type = "作出判决" + + events.append({ + "date": date_str, + "event": event_type, + "context": context, + "source": source + }) + + events.sort(key=lambda x: x["date"]) + + duration_days = 0 + earliest = None + latest = None + + if events: + earliest = events[0]["date"] + latest = events[-1]["date"] + + try: + start_date = datetime.strptime(earliest, "%Y-%m-%d") + end_date = datetime.strptime(latest, "%Y-%m-%d") + duration_days = (end_date - start_date).days + except: + pass + + return { + "timeline": events[:20], + "total_events": len(events), + "duration_days": duration_days, + "earliest": earliest, + "latest": latest, + "summary": f"案件时间跨度{duration_days}天,从{earliest}到{latest}" if earliest and latest else "未找到时间信息" + } + + +tools = CaseTools() \ No newline at end of file diff --git a/start_claude.ps1 b/start_claude.ps1 new file mode 100644 index 0000000..0387bc0 --- /dev/null +++ b/start_claude.ps1 @@ -0,0 +1,23 @@ + +$env:ANTHROPIC_BASE_URL = "https://api.moonshot.cn/anthropic" +$env:MOONSHOT_API_KEY = "sk-joeLiNi2DzTWHr06tdFOdx3A8eKg98eowwShkTlusowyP4vS" +$env:ANTHROPIC_AUTH_TOKEN = $env:MOONSHOT_API_KEY + +$env:ANTHROPIC_MODEL = "kimi-k2-thinking-turbo" +$env:ANTHROPIC_DEFAULT_OPUS_MODEL = "kimi-k2-thinking-turbo" +$env:ANTHROPIC_DEFAULT_SONNET_MODEL = "kimi-k2-thinking-turbo" +$env:ANTHROPIC_DEFAULT_HAIKU_MODEL = "kimi-k2-thinking-turbo" +$env:CLAUDE_CODE_SUBAGENT_MODEL = "kimi-k2-thinking-turbo" + +Write-Host " 环境变量已设置" -ForegroundColor Green +Write-Host " Base URL: $env:ANTHROPIC_BASE_URL" -ForegroundColor Cyan +Write-Host " Model: $env:ANTHROPIC_MODEL" -ForegroundColor Cyan + +$claudePath = "C:\Users\ADMIN\AppData\Local\AnthropicClaude\claude.exe" + +if (Test-Path $claudePath) { + Write-Host " 正在启动 Claude Desktop..." -ForegroundColor Yellow + Start-Process $claudePath +} else { + Write-Host " 找不到 Claude.exe,请检查安装路径" -ForegroundColor Red +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2e0630d --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +"""tests package""" + diff --git a/tests/test_mcp.py b/tests/test_mcp.py new file mode 100644 index 0000000..d81039c --- /dev/null +++ b/tests/test_mcp.py @@ -0,0 +1,202 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import json +from src.mcp.mem import ConvoStore +from src.tools.case_tools import CaseTools + + +def test_conv_store(): + store = ConvoStore() + + test_session = "CASE_SESSION_001" + + store.add_conv( + session_id=test_session, + user_msg="这个案件的原告是谁?", + ai_msg="根据案件文档,原告是张三。", + metadata={"case_id": "CASE_001"} + ) + + store.add_conv( + session_id=test_session, + user_msg="借款金额是多少?", + ai_msg="借款本金为50万元,约定年利率10%。", + metadata={"case_id": "CASE_001"} + ) + + store.add_conv( + session_id=test_session, + user_msg="被告的答辩意见是什么?", + ai_msg="被告称该笔款项为合伙投资款,非借款,并称借条系伪造。", + metadata={"case_id": "CASE_001"} + ) + + print(f" 对话存储成功") + print(f" 会话ID: {test_session}") + print(f" 对话轮数: 3") + return True + + +def test_conv_search(): + + store = ConvoStore() + test_session = "CASE_SESSION_001" + + queries = [ + "原告信息", + "借了多少钱", + "被告怎么说" + ] + + for query in queries: + results = store.search_conv( + session_id=test_session, + query=query, + limit=2 + ) + + print(f"\n 查询: {query}") + print(f" 结果数: {len(results)}") + + if results: + print(f" 最高相关度: {results[0]['score']:.3f}") + print(f" 匹配对话: Q: {results[0]['user'][:30]}...") + + assert len(results) > 0 + + print(f"\n 检索成功") + return True + + +def test_conv_all(): + + store = ConvoStore() + test_session = "CASE_SESSION_001" + + all_convs = store.get_all_conv( + session_id=test_session, + limit=100 + ) + + print(f" 会话获取成功") + print(f" 会话ID: {test_session}") + print(f" 对话总数: {len(all_convs)}") + + if all_convs: + print(f"\n 最近对话:") + for i, conv in enumerate(all_convs[:2], 1): + print(f" {i}. Q: {conv['user_msg'][:40]}...") + print(f" A: {conv['ai_msg'][:40]}...") + + assert len(all_convs) >= 3 + return True + +def test_case_tools_metadata(): + + tools = CaseTools() + result = tools.get_case_metadata("CASE_001") + + print(f" 查询完成") + print(f" 案件ID: CASE_001") + + if "error" in result: + print(f" 状态: {result['error']}") + print(f" 说明: {result.get('description', '')}") + else: + print(f" 文档数: {result.get('doc_cnt', 0)}") + print(f" 片段数: {result.get('chunk_cnt', 0)}") + print(f" 状态: {result.get('status', '')}") + + assert isinstance(result, dict) + return True + + +def test_build_rag_prompt(): + tools = CaseTools() + + mock_context = """ + +原告: 张三 +被告: 李四 +借款金额: 50万元 + + +被告答辩: 该笔款项为合伙投资款 + +""" + + result = tools.build_rag_prompt( + query="原告和被告分别是谁?", + context=mock_context, + session_id="CASE_SESSION_001", + max_history=2 + ) + + print(f" prompt构建成功") + print(f" 消息总数: {result['prompt_stats']['total_messages']}") + print(f" 包含上下文: {result['prompt_stats']['has_context']}") + print(f" 包含历史: {result['prompt_stats']['has_history']}") + print(f" 历史来源: {result['prompt_stats']['history_source']}") + print(f" 历史轮数: {result['prompt_stats']['history_length']}") + + assert len(result['messages']) > 0 + assert result['prompt_stats']['has_context'] + return True + +def test_cleanup(): + store = ConvoStore() + test_session = "CASE_SESSION_001" + + all_convs = store.get_all_conv(session_id=test_session, limit=1000) + print(f" 清理前对话数: {len(all_convs)}") + + count = store.delete_by_session(test_session) + + print(f" 清理完成") + print(f" 测试会话: {test_session}") + print(f" 删除对话数: {count}") + + remaining = store.get_all_conv(session_id=test_session, limit=1000) + print(f" 清理后对话数: {len(remaining)}") + + assert len(remaining) == 0 + + return True + + +def run_all_tests(): + print("\n" + "="*60) + print("MCP 测试") + print("="*60) + + tests = [ + ("对话存储", test_conv_store), + ("对话检索", test_conv_search), + ("获取全部对话", test_conv_all), + ("获取案件元数据", test_case_tools_metadata), + ("构建 RAG Prompt", test_build_rag_prompt), + ("清理测试数据", test_cleanup) + ] + + passed = 0 + failed = 0 + + for name, test_func in tests: + if test_func(): + passed += 1 + else: + failed += 1 + print(f"\n 测试失败: {name}") + + print("\n" + "="*60) + print(f"结果: {passed} 通过, {failed} 失败") + print("="*60) + + return failed == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) diff --git a/tests/test_rag.py b/tests/test_rag.py new file mode 100644 index 0000000..7606427 --- /dev/null +++ b/tests/test_rag.py @@ -0,0 +1,182 @@ +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.rag.loader import DocLoader +from src.rag.retriever import Retriever +from src.rag.store import QdrantVecStore +from src.rag.embeddings import embedding_client + + +def test_connect(): + store = QdrantVecStore() + stats = store.get_stats() + + print(f" Qdrant连接成功") + print(f" 集合: {stats['collection_name']}") + print(f" 文档数: {stats['total_documents']}") + print(f" 向量维度: {stats['vector_size']}") + + assert stats['collection_name'] is not None + assert stats['total_documents'] >= 0 + return True + + +def test_embed(): + texts = ["测试文本1", "测试文本2"] + vectors = embedding_client.embed(texts) + + print(f" txt1,txt2 embed成功") + print(f" 输入文本数: {len(texts)}") + print(f" 输出向量数: {len(vectors)}") + print(f" 向量维度: {len(vectors[0])}") + + assert len(vectors) == len(texts) + assert len(vectors[0]) == 1536 + return True + + +def test_doc_load(): + + loader = DocLoader(chunk_size=200, chunk_overlap=50) + + test_text = """ + 案件编号: CASE_001 + 案件名称: 测试案件 - 借款合同纠纷 + + 原告: 张三 + 被告: 李四 + + 案情简介: + 原告张三于2023年1月向被告李四出借人民币50万元,约定年利率10%, + 借款期限为1年。到期后被告未按约定归还本金及利息。 + + 原告诉讼请求: + 1. 判令被告归还借款本金50万元 + 2. 支付利息5万元 + 3. 承担本案诉讼费用 + + 被告答辩: + 该笔款项为合伙投资款,非借款。原告提供的借条系伪造。 + """ + + ids = loader.loader_text( + text=test_text, + case_id="CASE_001", + source="test_complaint.txt" + ) + + print(f" doc入库成功") + print(f" 案件ID: CASE_001") + print(f" 文档片段数: {len(ids)}") + print(f" 片段ID示例: {ids[0][:8]}...") + + assert len(ids) > 0 + return True + + +def test_doc_search(): + + retriever = Retriever() + + queries = [ + "原告是谁?", + "借款金额是多少?", + "被告的答辩意见是什么?" + ] + + for query in queries: + results = retriever.retrieve( + query=query, + case_id="CASE_001", + top_k=2 + ) + + print(f"\n 查询: {query}") + print(f" 结果数: {len(results)}") + + if results: + print(f" 最高相似度: {results[0]['score']:.3f}") + print(f" 片段预览: {results[0]['text'][:50]}...") + + assert len(results) > 0 + assert results[0]['score'] > 0 + + print(f"\n 检索成功") + return True + + +def test_text_format(): + retriever = Retriever() + + results = retriever.retrieve( + query="案件的原被告信息", + case_id="CASE_001", + top_k=3 + ) + + formatted = retriever.format_context(results) + + print(f" 格式化成功") + print(f" 原始结果数: {len(results)}") + print(f" 格式化长度: {len(formatted)} 字符") + print(f"\n 格式化预览:") + print(f" {formatted[:200]}...") + + assert "" in formatted + assert "" in formatted + assert len(formatted) > 0 + return True + + +def test_cleanup(): + store = QdrantVecStore() + count = store.delete_by_case("CASE_001") + + print(f" 清理完成") + if count is not None: + print(f" 删除文档数: {count}") + assert count >= 0 + else: + print(f" 删除操作已执行(未返回计数)") + + stats = store.get_stats() + print(f" 剩余文档数: {stats['total_documents']}") + + return True + + +def run_all_tests(): + print("\n" + "="*60) + print("RAG测试") + print("="*60) + + tests = [ + ("Qdrant 连接", test_connect), + ("嵌入客户端", test_embed), + ("文档入库", test_doc_load), + ("文档检索", test_doc_search), + ("上下文清除", test_text_format), + ("清理数据", test_cleanup) + ] + + passed = 0 + failed = 0 + + for name, test_func in tests: + if test_func(): + passed += 1 + else: + print(f"\n 测试失败: {name}") + failed += 1 + + print("\n" + "="*60) + print(f"结果: {passed} 通过, {failed} 失败") + print("="*60) + + return failed == 0 + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1)