Skip to content

Commit

Permalink
feat: add label filtering to qdrant client.
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb committed Jan 9, 2025
1 parent 2f95418 commit 3294018
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
1 change: 1 addition & 0 deletions vectordb_bench/backend/clients/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
index.delete(delete_all=True, namespace=namespace)

self._metadata_key = "meta"
self._scalar_id_field = "id"
self._scalar_label_field = "label"

@classmethod
Expand Down
43 changes: 39 additions & 4 deletions vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import logging
import time
from contextlib import contextmanager
from vectordb_bench.backend.filter import Filter, FilterType

from ..api import VectorDB, DBCaseConfig
from qdrant_client.http.models import (
CollectionStatus,
VectorParams,
PayloadSchemaType,
Batch,
Filter,
QdrantFilter,
FieldCondition,
Range,
)
Expand All @@ -22,13 +23,19 @@


class QdrantCloud(VectorDB):
supported_filter_types: list[FilterType] = [
FilterType.NonFilter,
FilterType.Int,
FilterType.Label,
]
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
collection_name: str = "QdrantCloudCollection",
drop_old: bool = False,
with_scalar_labels: bool = False,
**kwargs,
):
"""Initialize wrapper around the QdrantCloud vector database."""
Expand All @@ -40,11 +47,14 @@ def __init__(
self._vector_field = "vector"

tmp_client = QdrantClient(**self.db_config)
self.with_scalar_labels = with_scalar_labels
if drop_old:
log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
tmp_client.delete_collection(self.collection_name)
self._create_collection(dim, tmp_client)
tmp_client = None
self._scalar_id_field = "id"
self._scalar_label_field = "label"

@contextmanager
def init(self) -> None:
Expand Down Expand Up @@ -105,6 +115,7 @@ def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
labels_data: list[str] = None,
**kwargs,
) -> (int, Exception):
"""Insert embeddings into Milvus. should call self.init() first"""
Expand Down Expand Up @@ -138,10 +149,10 @@ def search_embedding(
Should call self.init() first.
"""
assert self.qdrant_client is not None

f = None
condition = self.condition
f = self.condition
if filters:
f = Filter(
f = QdrantFilter(
must=[FieldCondition(
key = self._primary_field,
range = Range(
Expand All @@ -160,3 +171,27 @@ def search_embedding(

ret = [result.id for result in res[0]]
return ret

def prepare_filter(self, filter: Filter):
    """Translate a benchmark filter into a Qdrant filter stored on ``self.condition``.

    Args:
        filter: Benchmark filter description. Its ``type`` selects the
            branch; ``int_value`` is read for ``FilterType.Int`` and
            ``label_value`` for ``FilterType.Label``.

    Raises:
        ValueError: If ``filter.type`` is not one of ``NonFilter``, ``Int``
            or ``Label``.
    """
    if filter.type == FilterType.NonFilter:
        # No filtering requested: clear any previously prepared condition.
        self.condition = None
    elif filter.type == FilterType.Int:
        # Keep only points whose scalar id is >= the requested threshold.
        self.condition = QdrantFilter(
            must=[
                FieldCondition(
                    key=self._scalar_id_field,
                    range=Range(gte=filter.int_value),
                ),
            ]
        )
    elif filter.type == FilterType.Label:
        # Keep only points whose label payload equals the requested label.
        self.condition = QdrantFilter(
            must=[
                FieldCondition(
                    key=self._scalar_label_field,
                    match={"value": filter.label_value},
                ),
            ]
        )
    else:
        # The message previously named Pinecone (copy-paste slip); report
        # the backend that actually rejected the filter.
        raise ValueError(f"Not support Filter for QdrantCloud - {filter}")

0 comments on commit 3294018

Please sign in to comment.