aiqs
This commit is contained in:
160
src/rag/store.py
Normal file
160
src/rag/store.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
VectorParams,
|
||||
PointStruct,
|
||||
Filter,
|
||||
FieldCondition,
|
||||
MatchValue
|
||||
)
|
||||
from src.config import settings
|
||||
|
||||
|
||||
class QdrantVecStore:
|
||||
def __init__(self, collection_name: str = None):
|
||||
self.client = QdrantClient(
|
||||
host=settings.qdrant_host,
|
||||
port=settings.qdrant_port
|
||||
)
|
||||
self.collection_name = collection_name or settings.qdrant_collection
|
||||
self._ensure_collection()
|
||||
|
||||
def _ensure_collection(self):
|
||||
collections = self.client.get_collections().collections
|
||||
exists = any(c.name == self.collection_name for c in collections)
|
||||
|
||||
if not exists:
|
||||
self.client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=1536,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f" collection_name: {self.collection_name}")
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
texts: list[str],
|
||||
vectors: list[list[float]],
|
||||
metadatas: list[dict]
|
||||
) -> list[str]:
|
||||
import uuid
|
||||
|
||||
points = []
|
||||
ids = []
|
||||
|
||||
for txt, vec, meta in zip(texts, vectors, metadatas):
|
||||
doc_id = str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=doc_id,
|
||||
vector=vec,
|
||||
payload={
|
||||
"text": txt,
|
||||
"case_id": meta.get("case_id"),
|
||||
"source": meta.get("source", "unknown"),
|
||||
"page": meta.get("page"),
|
||||
"chunk_id": meta.get("chunk_id"),
|
||||
**meta # 其他数据
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=points
|
||||
)
|
||||
|
||||
print(f" 添加 {len(points)} 个文档片")
|
||||
return ids
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_vector: list[float],
|
||||
case_id: str = None,
|
||||
top_k: int = 3
|
||||
) -> list[dict]:
|
||||
query_filter = None
|
||||
if case_id:
|
||||
query_filter = Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="case_id",
|
||||
match=MatchValue(value=case_id)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
results = self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
query=query_vector,
|
||||
query_filter=query_filter,
|
||||
limit=top_k,
|
||||
with_payload=True
|
||||
)
|
||||
points = results.points if hasattr(results, 'points') else results
|
||||
return [
|
||||
{
|
||||
"text": r.payload["text"],
|
||||
"score": r.score,
|
||||
"metadata": {
|
||||
k: v for k, v in r.payload.items()
|
||||
if k != "text"
|
||||
}
|
||||
}
|
||||
for r in points
|
||||
]
|
||||
|
||||
def delete_by_case(self, case_id: str):
|
||||
self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="case_id",
|
||||
match=MatchValue(value=case_id)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
print(f" 删除案件 {case_id} 的所有文档")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
info = self.client.get_collection(self.collection_name)
|
||||
coll = getattr(info, 'result', info)
|
||||
|
||||
total_docs = getattr(coll, "points_count", 0) or getattr(coll, "vectors_count", 0)
|
||||
|
||||
vector_size = 1536
|
||||
config = getattr(coll, "config", None)
|
||||
if config:
|
||||
params = getattr(config, "params", None)
|
||||
if params:
|
||||
vectors_config = getattr(params, "vectors", None)
|
||||
if vectors_config and hasattr(vectors_config, "size"):
|
||||
vector_size = vectors_config.size
|
||||
|
||||
return {
|
||||
"collection_name": self.collection_name,
|
||||
"total_documents": int(total_docs),
|
||||
"vector_size": int(vector_size),
|
||||
"total_points": getattr(coll, "points_count", None),
|
||||
"vectors_count": getattr(coll, "vectors_count", None),
|
||||
"indexed_vectors_count": getattr(coll, "indexed_vectors_count", None)
|
||||
}
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# store = QdrantVecStore()
|
||||
# print(f" 连接到 Qdrant: {settings.qdrant_host}:{settings.qdrant_port}")
|
||||
# print(f" 集合: {store.collection_name}")
|
||||
|
||||
# stats = store.get_stats()
|
||||
# print(f" 统计: {stats}")
|
||||
Reference in New Issue
Block a user