fix: fix single data point addition to weaiate

This commit is contained in:
Boris Arzentar 2024-11-11 15:56:09 +01:00 committed by Leon Luithlen
parent b1b6b79ca4
commit eb5f30fcd1
2 changed files with 10 additions and 9 deletions

View file

@ -11,7 +11,6 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter") logger = logging.getLogger("WeaviateAdapter")
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
uuid: str
text: str text: str
_metadata: dict = { _metadata: dict = {
@ -89,8 +88,10 @@ class WeaviateAdapter(VectorDBInterface):
def convert_to_weaviate_data_points(data_point: DataPoint): def convert_to_weaviate_data_points(data_point: DataPoint):
vector = data_vectors[data_points.index(data_point)] vector = data_vectors[data_points.index(data_point)]
properties = data_point.model_dump() properties = data_point.model_dump()
properties["uuid"] = properties["id"]
del properties["id"] if "id" in properties:
properties["uuid"] = str(data_point.id)
del properties["id"]
return DataObject( return DataObject(
uuid = data_point.id, uuid = data_point.id,
@ -114,7 +115,7 @@ class WeaviateAdapter(VectorDBInterface):
) )
else: else:
data_point: DataObject = data_points[0] data_point: DataObject = data_points[0]
return collection.data.update( return collection.data.insert(
uuid = data_point.uuid, uuid = data_point.uuid,
vector = data_point.vector, vector = data_point.vector,
properties = data_point.properties, properties = data_point.properties,
@ -130,8 +131,8 @@ class WeaviateAdapter(VectorDBInterface):
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]): async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [ await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema( IndexSchema(
uuid = str(data_point.id), id = data_point.id,
text = getattr(data_point, data_point._metadata["index_fields"][0]), text = data_point.get_embeddable_data(),
) for data_point in data_points ) for data_point in data_points
]) ])
@ -178,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
return [ return [
ScoredResult( ScoredResult(
id = UUID(result.uuid), id = UUID(str(result.uuid)),
payload = result.properties, payload = result.properties,
score = float(result.metadata.score) score = 1 - float(result.metadata.score)
) for result in search_result.objects ) for result in search_result.objects
] ]

View file

@ -35,7 +35,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "quantum computer"))[0] random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)