feat: add get_distances_of_collection method to the LanceDB and PGVector adapters
This commit is contained in:
parent
f2c0fddeb2
commit
44ac9b68b4
3 changed files with 118 additions and 20 deletions
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.files.storage import LocalStorage
|
|||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..utils import normalize_distances
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
|
|
@ -141,6 +142,34 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
score = 0,
|
||||
) for result in results.to_dict("index").values()]
|
||||
|
||||
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
):
    """Score every entry of a LanceDB collection against a query.

    Either ``query_text`` (embedded via the adapter's embedding engine) or a
    pre-computed ``query_vector`` must be supplied. Results are returned as
    ``ScoredResult`` objects whose scores are the min-max normalized vector
    distances produced by ``normalize_distances``.

    Note: ``with_vector`` is accepted for interface parity with the other
    adapters but is not used here — TODO confirm intended semantics.

    Raises:
        ValueError: If neither ``query_text`` nor ``query_vector`` is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        # Embed the query text; embed_text works on batches, take the first.
        embeddings = await self.embedding_engine.embed_text([query_text])
        query_vector = embeddings[0]

    connection = await self.get_connection()
    table = await connection.open_table(collection_name)

    # Run the vector search and materialize the result rows as dicts.
    search_frame = await table.vector_search(query_vector).to_pandas()
    rows = list(search_frame.to_dict("index").values())

    # Replace raw distances with values normalized into [0, 1].
    scores = normalize_distances(rows)

    scored_results = []
    for row_index, row in enumerate(rows):
        scored_results.append(ScoredResult(
            id=UUID(row["id"]),
            payload=row["payload"],
            score=scores[row_index],
        ))

    return scored_results
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
|
|
@ -148,6 +177,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
query_vector: List[float] = None,
|
||||
limit: int = 5,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
|
@ -162,26 +192,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
result_values = list(results.to_dict("index").values())
|
||||
|
||||
min_value = 100
|
||||
max_value = 0
|
||||
|
||||
for result in result_values:
|
||||
value = float(result["_distance"])
|
||||
if value > max_value:
|
||||
max_value = value
|
||||
if value < min_value:
|
||||
min_value = value
|
||||
|
||||
normalized_values = []
|
||||
min_value = min(result["_distance"] for result in result_values)
|
||||
max_value = max(result["_distance"] for result in result_values)
|
||||
|
||||
if max_value == min_value:
|
||||
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
|
||||
normalized_values = [0 for _ in result_values]
|
||||
else:
|
||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
|
||||
result_values]
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
return [ScoredResult(
|
||||
id = UUID(result["id"]),
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from cognee.infrastructure.engine import DataPoint
|
|||
from .serialize_data import serialize_data
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..utils import normalize_distances
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ...relational.ModelBase import Base
|
||||
|
|
@ -22,6 +23,19 @@ class IndexSchema(DataPoint):
|
|||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
def singleton(class_):
    """Class decorator that makes ``class_`` yield one shared instance.

    The first call builds the instance with the arguments supplied; every
    later call returns that same instance and ignores its arguments.

    Note: Using this singleton as a decorator to a class removes the option
    to use class methods for that class — the decorated name is bound to a
    factory function rather than the class itself.
    """
    cache = {}

    def get_instance(*args, **kwargs):
        # Construct lazily on first use; reuse the cached instance after.
        if class_ in cache:
            return cache[class_]
        instance = class_(*args, **kwargs)
        cache[class_] = instance
        return instance

    return get_instance
|
||||
|
||||
@singleton
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -162,6 +176,53 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
) for result in results
|
||||
]
|
||||
|
||||
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
)-> List[ScoredResult]:
    """Score every entry of a PGVector collection against a query.

    Args:
        collection_name: Name of the table to search.
        query_text: Text to embed as the query; used only when
            ``query_vector`` is not supplied.
        query_vector: Pre-computed query embedding.
        with_vector: Currently unused; accepted for interface parity with
            the other adapters — TODO confirm intended semantics.

    Returns:
        List[ScoredResult]: One result per row, ordered by ascending cosine
        distance. Scores are the raw cosine distances.

    Raises:
        ValueError: If neither ``query_text`` nor ``query_vector`` is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    # Get PGVectorDataPoint Table from database
    PGVectorDataPoint = await self.get_table(collection_name)

    # Use async session to connect to the database
    async with self.get_async_session() as session:
        # Find closest vectors to query_vector, nearest first.
        closest_items = await session.execute(
            select(
                PGVectorDataPoint,
                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
                    "similarity"
                ),
            )
            .order_by("similarity")
        )

    # TODO: Add normalization of similarity score (see normalize_distances).
    # Build ScoredResult objects directly from the result rows; the previous
    # intermediate copy into a vector_list added nothing.
    return [
        ScoredResult(
            id = UUID(str(row.id)),
            payload = row.payload,
            score = row.similarity
        ) for row in closest_items
    ]
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
|
|
|
|||
26
cognee/infrastructure/databases/vector/utils.py
Normal file
26
cognee/infrastructure/databases/vector/utils.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from typing import List
|
||||
|
||||
|
||||
def normalize_distances(result_values: List[dict]) -> List[float]:
    """Min-max normalize the ``_distance`` field of a list of search results.

    Args:
        result_values: Search result rows, each a dict with a numeric
            ``_distance`` entry.

    Returns:
        A list of values in [0, 1], one per input row: the smallest distance
        maps to 0 and the largest to 1. An empty input yields an empty list;
        when all distances are equal, every value maps to 0.
    """
    # Guard the empty case: min()/max() on an empty sequence would raise.
    if not result_values:
        return []

    distances = [float(result["_distance"]) for result in result_values]

    min_value = min(distances)
    max_value = max(distances)

    if max_value == min_value:
        # Avoid division by zero: assign all normalized values to 0.
        return [0 for _ in distances]

    value_range = max_value - min_value
    return [(distance - min_value) / value_range for distance in distances]
|
||||
Loading…
Add table
Reference in a new issue