search update (#81)

* search update

* update string literals
This commit is contained in:
Preston Rasmussen 2024-09-04 10:05:45 -04:00 committed by GitHub
parent 2b6adb5279
commit e56a599a72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 212 additions and 88 deletions

View file

@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
await client.add_episode_bulk(episodes) await client.add_episode_bulk(episodes)
asyncio.run(main(False)) asyncio.run(main(True))

View file

@ -321,11 +321,11 @@ class Graphiti:
await asyncio.gather( await asyncio.gather(
*[ *[
get_relevant_edges( get_relevant_edges(
[edge],
self.driver, self.driver,
RELEVANT_SCHEMA_LIMIT, [edge],
edge.source_node_uuid, edge.source_node_uuid,
edge.target_node_uuid, edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
) )
for edge in extracted_edges for edge in extracted_edges
] ]

View file

@ -83,7 +83,7 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods: if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges) text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
search_results.append(text_search) search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods: if SearchMethod.cosine_similarity in config.search_methods:
@ -95,7 +95,7 @@ async def hybrid_search(
) )
similarity_search = await edge_similarity_search( similarity_search = await edge_similarity_search(
driver, search_vector, 2 * config.num_edges driver, search_vector, None, None, 2 * config.num_edges
) )
search_results.append(similarity_search) search_results.append(similarity_search)

View file

@ -1,11 +1,11 @@
import asyncio import asyncio
import logging import logging
import re import re
import typing
from collections import defaultdict from collections import defaultdict
from time import time from time import time
from typing import Any
from neo4j import AsyncDriver from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.helpers import parse_db_date from graphiti_core.helpers import parse_db_date
@ -71,7 +71,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
node_ids=node_ids, node_ids=node_ids,
) )
context: dict[str, typing.Any] = {} context: dict[str, Any] = {}
for record in records: for record in records:
n_uuid = record['source_node_uuid'] n_uuid = record['source_node_uuid']
@ -98,13 +98,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
async def edge_similarity_search( async def edge_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
records, _, _ = await driver.execute_query( query = Query("""
"""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
@ -121,7 +120,68 @@ async def edge_similarity_search(
r.valid_at AS valid_at, r.valid_at AS valid_at,
r.invalid_at AS invalid_at r.invalid_at AS invalid_at
ORDER BY score DESC ORDER BY score DESC
""", """)
if source_node_uuid is None and target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif source_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector, search_vector=search_vector,
source_uuid=source_node_uuid, source_uuid=source_node_uuid,
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
@ -161,6 +221,7 @@ async def entity_similarity_search(
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.name AS name, n.name AS name,
n.name_embeddings AS name_embedding,
n.created_at AS created_at, n.created_at AS created_at,
n.summary AS summary n.summary AS summary
ORDER BY score DESC ORDER BY score DESC
@ -175,6 +236,7 @@ async def entity_similarity_search(
EntityNode( EntityNode(
uuid=record['uuid'], uuid=record['uuid'],
name=record['name'], name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'], labels=['Entity'],
created_at=record['created_at'].to_native(), created_at=record['created_at'].to_native(),
summary=record['summary'], summary=record['summary'],
@ -193,8 +255,9 @@ async def entity_fulltext_search(
""" """
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN RETURN
node.uuid As uuid, node.uuid AS uuid,
node.name AS name, node.name AS name,
node.name_embeddings AS name_embedding,
node.created_at AS created_at, node.created_at AS created_at,
node.summary AS summary node.summary AS summary
ORDER BY score DESC ORDER BY score DESC
@ -210,6 +273,7 @@ async def entity_fulltext_search(
EntityNode( EntityNode(
uuid=record['uuid'], uuid=record['uuid'],
name=record['name'], name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'], labels=['Entity'],
created_at=record['created_at'].to_native(), created_at=record['created_at'].to_native(),
summary=record['summary'], summary=record['summary'],
@ -222,15 +286,12 @@ async def entity_fulltext_search(
async def edge_fulltext_search( async def edge_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
source_node_uuid: str | None,
target_node_uuid: str | None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# fulltext search over facts # fulltext search over facts
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' cypher_query = Query("""
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
@ -247,7 +308,70 @@ async def edge_fulltext_search(
r.valid_at AS valid_at, r.valid_at AS valid_at,
r.invalid_at AS invalid_at r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit ORDER BY score DESC LIMIT $limit
""", """)
if source_node_uuid is None and target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif source_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
elif target_node_uuid is None:
cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""")
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query, query=fuzzy_query,
source_uuid=source_node_uuid, source_uuid=source_node_uuid,
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
@ -286,7 +410,7 @@ async def hybrid_node_search(
Perform a hybrid search for nodes using both text queries and embeddings. Perform a hybrid search for nodes using both text queries and embeddings.
This method combines fulltext search and vector similarity search to find This method combines fulltext search and vector similarity search to find
relevant nodes in the graph database. It uses an rrf reranker. relevant nodes in the graph database. It uses a rrf reranker.
Parameters Parameters
---------- ----------
@ -379,11 +503,11 @@ async def get_relevant_nodes(
async def get_relevant_edges( async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver, driver: AsyncDriver,
edges: list[EntityEdge],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()
relevant_edges: list[EntityEdge] = [] relevant_edges: list[EntityEdge] = []
@ -392,13 +516,13 @@ async def get_relevant_edges(
results = await asyncio.gather( results = await asyncio.gather(
*[ *[
edge_similarity_search( edge_similarity_search(
driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
) )
for edge in edges for edge in edges
if edge.fact_embedding is not None if edge.fact_embedding is not None
], ],
*[ *[
edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid) edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
for edge in edges for edge in edges
], ],
) )
@ -440,7 +564,7 @@ async def node_distance_reranker(
scores: dict[str, float] = {} scores: dict[str, float] = {}
for uuid in sorted_uuids: for uuid in sorted_uuids:
# Find shortest path to center node # Find the shortest path to center node
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)

View file

@ -158,7 +158,7 @@ async def dedupe_edges_bulk(
relevant_edges_chunks: list[list[EntityEdge]] = list( relevant_edges_chunks: list[list[EntityEdge]] = list(
await asyncio.gather( await asyncio.gather(
*[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks] *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
) )
) )

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.2.0" version = "0.2.1"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",