feat: adds get_distances from collection method to LanceDB and PgVector
This commit is contained in:
parent
f2c0fddeb2
commit
44ac9b68b4
3 changed files with 118 additions and 20 deletions
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.files.storage import LocalStorage
|
||||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
|
from ..utils import normalize_distances
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
class IndexSchema(DataPoint):
|
class IndexSchema(DataPoint):
|
||||||
|
|
@ -141,6 +142,34 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
score = 0,
|
score = 0,
|
||||||
) for result in results.to_dict("index").values()]
|
) for result in results.to_dict("index").values()]
|
||||||
|
|
||||||
|
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """Score every row of a LanceDB collection against a query.

    Unlike `search`, no result limit is applied: all rows of the table
    are returned, with their raw vector distances min-max normalized
    into [0, 1] scores via `normalize_distances`.

    Args:
        collection_name: Name of the LanceDB table to open.
        query_text: Text to embed into a query vector; used only when
            `query_vector` is not supplied.
        query_vector: Precomputed embedding to search with.
        with_vector: Unused here; kept for signature parity with
            `search` — NOTE(review): confirm whether it should attach
            vectors to results.

    Returns:
        One ScoredResult per row, scores normalized to [0, 1].

    Raises:
        ValueError: If neither `query_text` nor `query_vector` is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    # Embed lazily: only when the caller gave text but no vector.
    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    connection = await self.get_connection()
    collection = await connection.open_table(collection_name)

    results = await collection.vector_search(query_vector).to_pandas()

    result_values = list(results.to_dict("index").values())

    # Map raw "_distance" values into comparable [0, 1] scores.
    normalized_values = normalize_distances(result_values)

    return [ScoredResult(
        id = UUID(result["id"]),
        payload = result["payload"],
        score = normalized_values[value_index],
    ) for value_index, result in enumerate(result_values)]
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
|
|
@ -148,6 +177,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
query_vector: List[float] = None,
|
query_vector: List[float] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
|
normalized: bool = True
|
||||||
):
|
):
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise ValueError("One of query_text or query_vector must be provided!")
|
raise ValueError("One of query_text or query_vector must be provided!")
|
||||||
|
|
@ -162,26 +192,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
result_values = list(results.to_dict("index").values())
|
result_values = list(results.to_dict("index").values())
|
||||||
|
|
||||||
min_value = 100
|
normalized_values = normalize_distances(result_values)
|
||||||
max_value = 0
|
|
||||||
|
|
||||||
for result in result_values:
|
|
||||||
value = float(result["_distance"])
|
|
||||||
if value > max_value:
|
|
||||||
max_value = value
|
|
||||||
if value < min_value:
|
|
||||||
min_value = value
|
|
||||||
|
|
||||||
normalized_values = []
|
|
||||||
min_value = min(result["_distance"] for result in result_values)
|
|
||||||
max_value = max(result["_distance"] for result in result_values)
|
|
||||||
|
|
||||||
if max_value == min_value:
|
|
||||||
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
|
|
||||||
normalized_values = [0 for _ in result_values]
|
|
||||||
else:
|
|
||||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
|
|
||||||
result_values]
|
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = UUID(result["id"]),
|
id = UUID(result["id"]),
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from .serialize_data import serialize_data
|
from .serialize_data import serialize_data
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
|
from ..utils import normalize_distances
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
from ...relational.ModelBase import Base
|
from ...relational.ModelBase import Base
|
||||||
|
|
@ -22,6 +23,19 @@ class IndexSchema(DataPoint):
|
||||||
"index_fields": ["text"]
|
"index_fields": ["text"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def singleton(class_):
    """Class decorator that caches and reuses a single instance.

    The first call constructs the instance with the given arguments;
    every later call returns that same instance, ignoring its arguments.

    Note: decorating a class with this replaces it with a factory
    function, so class methods are no longer reachable on the
    decorated name.
    """
    _instances = {}

    def _get_instance(*args, **kwargs):
        # EAFP: the cache hit is the common case after first construction.
        try:
            return _instances[class_]
        except KeyError:
            instance = class_(*args, **kwargs)
            _instances[class_] = instance
            return instance

    return _get_instance
|
||||||
|
|
||||||
|
@singleton
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -162,6 +176,53 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
) for result in results
|
) for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """Score every row of a PGVector collection against a query.

    Rows are ordered by cosine distance to the query vector. Unlike
    `search`, no result limit is applied.

    Args:
        collection_name: Name of the table backing the collection.
        query_text: Text to embed into a query vector; used only when
            `query_vector` is not supplied.
        query_vector: Precomputed embedding to search with.
        with_vector: Unused here; kept for signature parity with
            `search` — NOTE(review): confirm whether it should attach
            vectors to results.

    Returns:
        One ScoredResult per row, ordered nearest-first.

    Raises:
        ValueError: If neither `query_text` nor `query_vector` is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    # Embed lazily: only when the caller gave text but no vector.
    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    # Get PGVectorDataPoint Table from database
    PGVectorDataPoint = await self.get_table(collection_name)

    # Use async session to connect to the database
    async with self.get_async_session() as session:
        # Find closest vectors to query_vector, nearest first.
        closest_items = await session.execute(
            select(
                PGVectorDataPoint,
                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
                    "similarity"
                ),
            )
            .order_by("similarity")
        )

        # Materialize rows while the session is still open.
        rows = list(closest_items)

    # TODO: Add normalization of similarity score
    return [
        ScoredResult(
            id = UUID(str(row.id)),
            payload = row.payload,
            score = row.similarity
        ) for row in rows
    ]
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
|
|
|
||||||
26
cognee/infrastructure/databases/vector/utils.py
Normal file
26
cognee/infrastructure/databases/vector/utils.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_distances(result_values: List[dict]) -> List[float]:
    """Min-max normalize the "_distance" field of search results to [0, 1].

    Args:
        result_values: Result rows as dicts, each with a numeric
            "_distance" key.

    Returns:
        One normalized score per row, in input order. When all distances
        are equal (including the single-row case) every score is 0 to
        avoid division by zero. An empty input yields an empty list.
    """
    # Guard the empty case: min()/max() would raise ValueError on it.
    if not result_values:
        return []

    distances = [float(result["_distance"]) for result in result_values]

    min_value = min(distances)
    max_value = max(distances)

    if max_value == min_value:
        # Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
        return [0 for _ in distances]

    value_range = max_value - min_value
    return [(distance - min_value) / value_range for distance in distances]
|
||||||
Loading…
Add table
Reference in a new issue