Merge pull request #459 from topoteretes/pgvector-add-normalization
feat: Add normalization to PGVector search
This commit is contained in:
commit
d8bde5461a
1 changed file with 14 additions and 4 deletions
|
|
@ -14,9 +14,9 @@ from ...relational.ModelBase import Base
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..utils import normalize_distances
|
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from .serialize_data import serialize_data
|
from .serialize_data import serialize_data
|
||||||
|
from ..utils import normalize_distances
|
||||||
|
|
||||||
|
|
||||||
class IndexSchema(DataPoint):
|
class IndexSchema(DataPoint):
|
||||||
|
|
@ -247,12 +247,22 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Extract distances and find min/max for normalization
|
# Extract distances and find min/max for normalization
|
||||||
for vector in closest_items:
|
for vector in closest_items:
|
||||||
# TODO: Add normalization of similarity score
|
vector_list.append(
|
||||||
vector_list.append(vector)
|
{
|
||||||
|
"id": UUID(str(vector.id)),
|
||||||
|
"payload": vector.payload,
|
||||||
|
"_distance": vector.similarity,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize vector distance and add this as score information to vector_list
|
||||||
|
normalized_values = normalize_distances(vector_list)
|
||||||
|
for i in range(0, len(normalized_values)):
|
||||||
|
vector_list[i]["score"] = normalized_values[i]
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=UUID(str(row.id)), payload=row.payload, score=row.similarity)
|
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue