"""Wrapper around a single Qdrant collection used to store and search text chunks."""
import sys
import uuid
from pathlib import Path
from typing import Optional

# Put the directory three levels above this file on sys.path so that the
# `from src.config import settings` import below can resolve.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    Filter,
    FieldCondition,
    MatchValue,
)

from src.config import settings

# Dimension used when creating a new collection and as the fallback value
# reported by get_stats().
# NOTE(review): presumably matches the embedding model's output size - confirm.
DEFAULT_VECTOR_SIZE = 1536


class QdrantVecStore:
    """Store, search and delete text chunks (with metadata) in one Qdrant collection."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        *,
        vector_size: int = DEFAULT_VECTOR_SIZE,
    ):
        """Connect to Qdrant and make sure the target collection exists.

        Args:
            collection_name: Collection to use; falls back to
                ``settings.qdrant_collection`` when omitted or empty.
            vector_size: Dimension used if the collection has to be created.
        """
        self.client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port,
        )
        self.collection_name = collection_name or settings.qdrant_collection
        self.vector_size = vector_size
        self._ensure_collection()

    @staticmethod
    def _case_filter(case_id: str) -> Filter:
        """Build a filter matching points whose ``case_id`` payload equals *case_id*."""
        return Filter(
            must=[
                FieldCondition(
                    key="case_id",
                    match=MatchValue(value=case_id),
                )
            ]
        )

    @staticmethod
    def _hit_to_dict(hit) -> dict:
        """Convert one scored point into the ``text`` / ``score`` / ``metadata`` dict."""
        # A point may carry no payload at all, or one without a "text" key;
        # fall back to empty values instead of raising on such a hit.
        payload = hit.payload or {}
        return {
            "text": payload.get("text", ""),
            "score": hit.score,
            "metadata": {k: v for k, v in payload.items() if k != "text"},
        }

    def _ensure_collection(self):
        """Create the collection (cosine distance) if it does not exist yet."""
        if self.client.collection_exists(self.collection_name):
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE,
            ),
        )
        print(f" collection_name: {self.collection_name}")

    def add_documents(
        self,
        texts: list[str],
        vectors: list[list[float]],
        metadatas: list[dict],
    ) -> list[str]:
        """Upsert one point per (text, vector, metadata) triple.

        Args:
            texts: Chunk texts, stored under the ``"text"`` payload key.
            vectors: One vector per text.
            metadatas: One metadata dict per text; all of its keys are copied
                into the payload.

        Returns:
            The generated UUID string ids, in input order.

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        if not (len(texts) == len(vectors) == len(metadatas)):
            # zip() would silently drop the surplus items otherwise.
            raise ValueError(
                "texts, vectors and metadatas must have the same length "
                f"(got {len(texts)}, {len(vectors)}, {len(metadatas)})"
            )
        if not texts:
            # Nothing to send; skip the request entirely.
            return []

        ids = [str(uuid.uuid4()) for _ in texts]
        points = [
            PointStruct(
                id=doc_id,
                vector=vec,
                payload={
                    # Other metadata fields go first so the explicit keys
                    # below (notably "text") cannot be overwritten by them.
                    **meta,
                    "text": txt,
                    "case_id": meta.get("case_id"),
                    "source": meta.get("source", "unknown"),
                    "page": meta.get("page"),
                    "chunk_id": meta.get("chunk_id"),
                },
            )
            for doc_id, txt, vec, meta in zip(ids, texts, vectors, metadatas)
        ]

        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        print(f" 添加 {len(points)} 个文档片")
        return ids

    def search(
        self,
        query_vector: list[float],
        case_id: Optional[str] = None,
        top_k: int = 3,
    ) -> list[dict]:
        """Return the *top_k* nearest chunks to *query_vector*.

        Args:
            query_vector: Vector to search with.
            case_id: When truthy, restrict results to points whose ``case_id``
                payload equals this value; otherwise no filter is applied.
            top_k: Maximum number of results.

        Returns:
            A list of dicts with keys ``"text"``, ``"score"`` and
            ``"metadata"`` (the payload without ``"text"``).
        """
        query_filter = self._case_filter(case_id) if case_id else None

        response = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            query_filter=query_filter,
            limit=top_k,
            with_payload=True,
        )
        # Accept both a response object exposing `.points` and a plain iterable.
        hits = response.points if hasattr(response, "points") else response
        return [self._hit_to_dict(hit) for hit in hits]

    def delete_by_case(self, case_id: str):
        """Delete every point whose ``case_id`` payload equals *case_id*."""
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=self._case_filter(case_id),
        )
        print(f" 删除案件 {case_id} 的所有文档")

    def get_stats(self) -> dict:
        """Return counters and the vector size of the collection.

        Returns:
            A dict with ``collection_name``, ``total_documents``,
            ``vector_size``, ``total_points``, ``vectors_count`` and
            ``indexed_vectors_count``. The last three are ``None`` when the
            collection info does not expose them.
        """
        info = self.client.get_collection(self.collection_name)
        # Unwrap a `.result` envelope when present, else use the object as is.
        coll = getattr(info, "result", info)

        points_count = getattr(coll, "points_count", None)
        vectors_count = getattr(coll, "vectors_count", None)
        # Trailing `or 0` keeps int() below from receiving None when both
        # counters are missing or None.
        total_docs = points_count or vectors_count or 0

        # Fall back to the configured size when the collection config does not
        # expose a single `size` attribute.
        vector_size = self.vector_size
        config = getattr(coll, "config", None)
        params = getattr(config, "params", None) if config else None
        vectors_config = getattr(params, "vectors", None) if params else None
        if vectors_config and hasattr(vectors_config, "size"):
            vector_size = vectors_config.size

        return {
            "collection_name": self.collection_name,
            "total_documents": int(total_docs),
            "vector_size": int(vector_size),
            "total_points": points_count,
            "vectors_count": vectors_count,
            "indexed_vectors_count": getattr(coll, "indexed_vectors_count", None),
        }


# Manual smoke test, intentionally left disabled:
# if __name__ == "__main__":
#     store = QdrantVecStore()
#     print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
#     print(f" 集合: {store.collection_name}")
#     stats = store.get_stats()
#     print(f" 统计: {stats}")