160 lines
4.9 KiB
Python
160 lines
4.9 KiB
Python
import sys
|
|
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
Distance,
|
|
VectorParams,
|
|
PointStruct,
|
|
Filter,
|
|
FieldCondition,
|
|
MatchValue
|
|
)
|
|
from src.config import settings
|
|
|
|
|
|
class QdrantVecStore:
    """Store and query text chunks with their embedding vectors in a Qdrant collection.

    On construction the store connects to the Qdrant instance described by
    ``settings.qdrant_host`` / ``settings.qdrant_port`` and creates the target
    collection if it is not present yet.
    """

    # Embedding dimension used when the caller does not specify one.
    DEFAULT_VECTOR_SIZE = 1536

    def __init__(self, collection_name: str = None, vector_size: int = DEFAULT_VECTOR_SIZE):
        """Connect to Qdrant and make sure the collection exists.

        Args:
            collection_name: Collection to operate on. Falls back to
                ``settings.qdrant_collection`` when omitted or empty.
            vector_size: Dimension of the vectors used when the collection has
                to be created, and reported by ``get_stats`` when the server
                does not expose a single vector size.
        """
        self.client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port
        )
        self.collection_name = collection_name or settings.qdrant_collection
        self.vector_size = vector_size
        self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection (cosine distance) if no collection has that name."""
        collections = self.client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)

        if not exists:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )
            print(f" collection_name: {self.collection_name}")

    def add_documents(
        self,
        texts: list[str],
        vectors: list[list[float]],
        metadatas: list[dict]
    ) -> list[str]:
        """Upsert one point per (text, vector, metadata) triple.

        Args:
            texts: Chunk texts, stored under the ``"text"`` payload key.
            vectors: One embedding per text.
            metadatas: One metadata dict per text. All of its keys are copied
                into the payload; ``case_id``, ``source`` (default
                ``"unknown"``), ``page`` and ``chunk_id`` are always present.

        Returns:
            The generated UUID string ids, in input order.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        import uuid

        # zip() would silently drop the trailing items of the longer lists.
        if not (len(texts) == len(vectors) == len(metadatas)):
            raise ValueError(
                "texts, vectors and metadatas must have the same length "
                f"(got {len(texts)}, {len(vectors)}, {len(metadatas)})"
            )

        points = []
        ids = []

        for txt, vec, meta in zip(texts, vectors, metadatas):
            meta = meta or {}
            doc_id = str(uuid.uuid4())
            ids.append(doc_id)

            points.append(
                PointStruct(
                    id=doc_id,
                    vector=vec,
                    payload={
                        # Extra metadata goes first so that a stray "text" key
                        # in it can never overwrite the chunk text below.
                        **meta,
                        "text": txt,
                        "case_id": meta.get("case_id"),
                        "source": meta.get("source", "unknown"),
                        "page": meta.get("page"),
                        "chunk_id": meta.get("chunk_id"),
                    }
                )
            )

        # Nothing to send for an empty batch, so skip the request entirely.
        if points:
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )

        print(f" 添加 {len(points)} 个文档片")
        return ids

    def search(
        self,
        query_vector: list[float],
        case_id: str = None,
        top_k: int = 3
    ) -> list[dict]:
        """Return the ``top_k`` nearest chunks to ``query_vector``.

        Args:
            query_vector: Embedding of the query.
            case_id: When truthy, restrict results to points whose ``case_id``
                payload field equals this value.
            top_k: Maximum number of hits to return.

        Returns:
            One dict per hit with keys ``"text"`` (``None`` if the point has
            no text payload), ``"score"`` and ``"metadata"`` (every payload
            field except ``"text"``).
        """
        query_filter = None
        if case_id:
            query_filter = Filter(
                must=[
                    FieldCondition(
                        key="case_id",
                        match=MatchValue(value=case_id)
                    )
                ]
            )

        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            query_filter=query_filter,
            limit=top_k,
            with_payload=True
        )
        # Accept either a response object carrying `.points` or a plain sequence.
        points = results.points if hasattr(results, 'points') else results

        hits = []
        for r in points:
            # A point may have been written without a payload or without "text".
            payload = r.payload or {}
            hits.append(
                {
                    "text": payload.get("text"),
                    "score": r.score,
                    "metadata": {
                        k: v for k, v in payload.items()
                        if k != "text"
                    }
                }
            )
        return hits

    def delete_by_case(self, case_id: str):
        """Delete every point whose ``case_id`` payload field equals ``case_id``."""
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=Filter(
                must=[
                    FieldCondition(
                        key="case_id",
                        match=MatchValue(value=case_id)
                    )
                ]
            )
        )
        print(f" 删除案件 {case_id} 的所有文档")

    def get_stats(self) -> dict:
        """Summarise the collection: name, document count and vector size.

        Returns:
            A dict with ``collection_name``, ``total_documents``,
            ``vector_size``, and the raw ``total_points``, ``vectors_count``
            and ``indexed_vectors_count`` values (``None`` when the server
            response does not carry them).
        """
        info = self.client.get_collection(self.collection_name)
        # Unwrap a `.result` envelope if the response has one.
        coll = getattr(info, 'result', info)

        points_count = getattr(coll, "points_count", None)
        vectors_count = getattr(coll, "vectors_count", None)
        # Either counter may be missing or None; never pass None to int().
        total_docs = points_count or vectors_count or 0

        # Fall back to the configured size when the response has no single
        # vector config exposing `.size`.
        vector_size = self.vector_size
        config = getattr(coll, "config", None)
        params = getattr(config, "params", None) if config else None
        vectors_config = getattr(params, "vectors", None) if params else None
        if vectors_config and hasattr(vectors_config, "size"):
            vector_size = vectors_config.size

        return {
            "collection_name": self.collection_name,
            "total_documents": int(total_docs),
            "vector_size": int(vector_size),
            "total_points": points_count,
            "vectors_count": vectors_count,
            "indexed_vectors_count": getattr(coll, "indexed_vectors_count", None)
        }
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# store = QdrantVecStore()
|
|
# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
|
|
# print(f" 集合: {store.collection_name}")
|
|
|
|
# stats = store.get_stats()
|
|
# print(f" 统计: {stats}") |