This commit is contained in:
2025-11-28 15:06:54 +08:00
commit 0bb8b4e13b
21 changed files with 1744 additions and 0 deletions

160
src/rag/store.py Normal file
View File

@@ -0,0 +1,160 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
Filter,
FieldCondition,
MatchValue
)
from src.config import settings
class QdrantVecStore:
def __init__(self, collection_name: str = None):
self.client = QdrantClient(
host=settings.qdrant_host,
port=settings.qdrant_port
)
self.collection_name = collection_name or settings.qdrant_collection
self._ensure_collection()
def _ensure_collection(self):
collections = self.client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if not exists:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=1536,
distance=Distance.COSINE
)
)
print(f" collection_name: {self.collection_name}")
def add_documents(
self,
texts: list[str],
vectors: list[list[float]],
metadatas: list[dict]
) -> list[str]:
import uuid
points = []
ids = []
for txt, vec, meta in zip(texts, vectors, metadatas):
doc_id = str(uuid.uuid4())
ids.append(doc_id)
points.append(
PointStruct(
id=doc_id,
vector=vec,
payload={
"text": txt,
"case_id": meta.get("case_id"),
"source": meta.get("source", "unknown"),
"page": meta.get("page"),
"chunk_id": meta.get("chunk_id"),
**meta # 其他数据
}
)
)
self.client.upsert(
collection_name=self.collection_name,
points=points
)
print(f" 添加 {len(points)} 个文档片")
return ids
def search(
self,
query_vector: list[float],
case_id: str = None,
top_k: int = 3
) -> list[dict]:
query_filter = None
if case_id:
query_filter = Filter(
must=[
FieldCondition(
key="case_id",
match=MatchValue(value=case_id)
)
]
)
results = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
query_filter=query_filter,
limit=top_k,
with_payload=True
)
points = results.points if hasattr(results, 'points') else results
return [
{
"text": r.payload["text"],
"score": r.score,
"metadata": {
k: v for k, v in r.payload.items()
if k != "text"
}
}
for r in points
]
def delete_by_case(self, case_id: str):
self.client.delete(
collection_name=self.collection_name,
points_selector=Filter(
must=[
FieldCondition(
key="case_id",
match=MatchValue(value=case_id)
)
]
)
)
print(f" 删除案件 {case_id} 的所有文档")
def get_stats(self) -> dict:
info = self.client.get_collection(self.collection_name)
coll = getattr(info, 'result', info)
total_docs = getattr(coll, "points_count", 0) or getattr(coll, "vectors_count", 0)
vector_size = 1536
config = getattr(coll, "config", None)
if config:
params = getattr(config, "params", None)
if params:
vectors_config = getattr(params, "vectors", None)
if vectors_config and hasattr(vectors_config, "size"):
vector_size = vectors_config.size
return {
"collection_name": self.collection_name,
"total_documents": int(total_docs),
"vector_size": int(vector_size),
"total_points": getattr(coll, "points_count", None),
"vectors_count": getattr(coll, "vectors_count", None),
"indexed_vectors_count": getattr(coll, "indexed_vectors_count", None)
}
# if __name__ == "__main__":
# store = QdrantVecStore()
# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
# print(f" 集合: {store.collection_name}")
# stats = store.get_stats()
# print(f" 统计: {stats}")