Group id fix (#152)
* node distance and group_ids fixed
* get all with no group_id passed
* push
* push
* remove comments
* mypy
* mypy ids
* please mypy
* trust
* last one
This commit is contained in:
parent
cfeb58daba
commit
794b705664
11 changed files with 93 additions and 110 deletions
|
|
@ -63,28 +63,27 @@ async def main(use_bulk: bool = True):
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
||||||
if not use_bulk:
|
if not use_bulk:
|
||||||
for i, message in enumerate(messages[3:4]):
|
for i, message in enumerate(messages[3:14]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
group_id='1',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# build communities
|
# build communities
|
||||||
await client.build_communities()
|
await client.build_communities()
|
||||||
|
|
||||||
# add additional messages to update communities
|
# add additional messages to update communities
|
||||||
# for i, message in enumerate(messages[14:20]):
|
for i, message in enumerate(messages[14:20]):
|
||||||
# await client.add_episode(
|
await client.add_episode(
|
||||||
# name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
# reference_time=message.actual_timestamp,
|
reference_time=message.actual_timestamp,
|
||||||
# source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
# group_id='1',
|
group_id='1',
|
||||||
# update_communities=True,
|
update_communities=True,
|
||||||
# )
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
group_id: str | None = Field(description='partition of the graph')
|
group_id: str = Field(description='partition of the graph')
|
||||||
source_node_uuid: str
|
source_node_uuid: str
|
||||||
target_node_uuid: str
|
target_node_uuid: str
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||||
|
|
@ -270,7 +270,7 @@ class EntityEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
|
|
@ -360,7 +360,7 @@ class CommunityEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||||
|
|
|
||||||
|
|
@ -197,7 +197,7 @@ class Graphiti:
|
||||||
self,
|
self,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -233,7 +233,7 @@ class Graphiti:
|
||||||
source_description: str,
|
source_description: str,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
source: EpisodeType = EpisodeType.message,
|
source: EpisodeType = EpisodeType.message,
|
||||||
group_id: str | None = None,
|
group_id: str = '',
|
||||||
uuid: str | None = None,
|
uuid: str | None = None,
|
||||||
update_communities: bool = False,
|
update_communities: bool = False,
|
||||||
):
|
):
|
||||||
|
|
@ -446,7 +446,7 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
|
||||||
"""
|
"""
|
||||||
Process multiple episodes in bulk and update the graph.
|
Process multiple episodes in bulk and update the graph.
|
||||||
|
|
||||||
|
|
@ -577,7 +577,7 @@ class Graphiti:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
num_results=DEFAULT_SEARCH_LIMIT,
|
num_results=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -633,7 +633,7 @@ class Graphiti:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
return await search(
|
return await search(
|
||||||
|
|
@ -644,7 +644,7 @@ class Graphiti:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_MODEL = 'gpt-4o-2024-08-06'
|
DEFAULT_MODEL = 'gpt-4o-mini'
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(LLMClient):
|
class OpenAIClient(LLMClient):
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ class EpisodeType(Enum):
|
||||||
class Node(BaseModel, ABC):
|
class Node(BaseModel, ABC):
|
||||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
name: str = Field(description='name of the node')
|
name: str = Field(description='name of the node')
|
||||||
group_id: str | None = Field(description='partition of the graph')
|
group_id: str = Field(description='partition of the graph')
|
||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
||||||
|
|
||||||
|
|
@ -186,7 +186,7 @@ class EpisodicNode(Node):
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
||||||
|
|
@ -281,7 +281,7 @@ class EntityNode(Node):
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||||
|
|
@ -374,7 +374,7 @@ class CommunityNode(Node):
|
||||||
return communities
|
return communities
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
@ -56,7 +57,7 @@ async def search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None,
|
group_ids: list[str] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
|
|
@ -103,7 +104,7 @@ async def edge_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None,
|
group_ids: list[str] | None,
|
||||||
config: EdgeSearchConfig,
|
config: EdgeSearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
|
@ -140,14 +141,21 @@ async def edge_search(
|
||||||
if center_node_uuid is None:
|
if center_node_uuid is None:
|
||||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
|
||||||
source_to_edge_uuid_map = {
|
# use rrf as a preliminary sort
|
||||||
edge.source_node_uuid: edge.uuid for result in search_results for edge in result
|
sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
|
||||||
}
|
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
|
||||||
source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
|
|
||||||
|
# node distance reranking
|
||||||
|
source_to_edge_uuid_map = defaultdict(list)
|
||||||
|
for edge in sorted_results:
|
||||||
|
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
|
||||||
|
|
||||||
|
source_uuids = [edge.source_node_uuid for edge in sorted_results]
|
||||||
|
|
||||||
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
||||||
|
|
||||||
reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
|
for node_uuid in reranked_node_uuids:
|
||||||
|
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
|
||||||
|
|
||||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
|
@ -161,7 +169,7 @@ async def node_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None,
|
group_ids: list[str] | None,
|
||||||
config: NodeSearchConfig,
|
config: NodeSearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
|
|
@ -198,7 +206,9 @@ async def node_search(
|
||||||
elif config.reranker == NodeReranker.node_distance:
|
elif config.reranker == NodeReranker.node_distance:
|
||||||
if center_node_uuid is None:
|
if center_node_uuid is None:
|
||||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
|
reranked_uuids = await node_distance_reranker(
|
||||||
|
driver, rrf(search_result_uuids), center_node_uuid
|
||||||
|
)
|
||||||
|
|
||||||
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
|
@ -209,7 +219,7 @@ async def community_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder,
|
embedder,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None,
|
group_ids: list[str] | None,
|
||||||
config: CommunitySearchConfig,
|
config: CommunitySearchConfig,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ async def edge_fulltext_search(
|
||||||
query: str,
|
query: str,
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
|
|
@ -95,10 +95,7 @@ async def edge_fulltext_search(
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN n.group_id IS NULL
|
|
||||||
ELSE n.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -120,10 +117,7 @@ async def edge_fulltext_search(
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -144,10 +138,7 @@ async def edge_fulltext_search(
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -168,10 +159,7 @@ async def edge_fulltext_search(
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -209,7 +197,7 @@ async def edge_similarity_search(
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
|
|
@ -217,10 +205,7 @@ async def edge_similarity_search(
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -242,10 +227,7 @@ async def edge_similarity_search(
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -266,10 +248,7 @@ async def edge_similarity_search(
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -290,10 +269,7 @@ async def edge_similarity_search(
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
||||||
ELSE r.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -327,7 +303,7 @@ async def edge_similarity_search(
|
||||||
async def node_fulltext_search(
|
async def node_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
|
|
@ -336,10 +312,7 @@ async def node_fulltext_search(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE CASE
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
WHEN $group_ids IS NULL THEN n.group_id IS NULL
|
|
||||||
ELSE n.group_id IN $group_ids
|
|
||||||
END
|
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid AS uuid,
|
n.uuid AS uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
|
|
@ -362,17 +335,16 @@ async def node_fulltext_search(
|
||||||
async def node_similarity_search(
|
async def node_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
|
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
MATCH (n WHERE n.group_id IN $group_ids)
|
MATCH (n:Entity)
|
||||||
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
|
|
@ -394,18 +366,17 @@ async def node_similarity_search(
|
||||||
async def community_fulltext_search(
|
async def community_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
|
|
||||||
# BM25 search to get top communities
|
# BM25 search to get top communities
|
||||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
MATCH (comm WHERE comm.group_id in $group_ids)
|
MATCH (comm:Community)
|
||||||
|
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid AS uuid,
|
comm.uuid AS uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
|
|
@ -428,17 +399,16 @@ async def community_fulltext_search(
|
||||||
async def community_similarity_search(
|
async def community_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
group_ids = group_ids if group_ids is not None else [None]
|
|
||||||
|
|
||||||
# vector similarity search over community names
|
# vector similarity search over community names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
MATCH (comm WHERE comm.group_id IN $group_ids)
|
MATCH (comm:Community)
|
||||||
|
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
|
|
@ -461,7 +431,7 @@ async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -503,7 +473,6 @@ async def hybrid_node_search(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
results: list[list[EntityNode]] = list(
|
results: list[list[EntityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
||||||
|
|
@ -625,10 +594,10 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
|
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# use rrf as a preliminary ranker
|
# filter the center node's uuid out of node_uuids
|
||||||
sorted_uuids = rrf(node_uuids)
|
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
|
|
@ -644,21 +613,23 @@ async def node_distance_reranker(
|
||||||
node_uuid=uuid,
|
node_uuid=uuid,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
)
|
)
|
||||||
for uuid in sorted_uuids
|
for uuid in filtered_uuids
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
for uuid, result in zip(sorted_uuids, path_results):
|
for uuid, result in zip(filtered_uuids, path_results):
|
||||||
records = result[0]
|
records = result[0]
|
||||||
record = records[0] if len(records) > 0 else None
|
record = records[0] if len(records) > 0 else None
|
||||||
distance: float = record['score'] if record is not None else float('inf')
|
distance: float = record['score'] if record is not None else float('inf')
|
||||||
distance = 0 if uuid == center_node_uuid else distance
|
|
||||||
scores[uuid] = distance
|
scores[uuid] = distance
|
||||||
|
|
||||||
# rerank on shortest distance
|
# rerank on shortest distance
|
||||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
return sorted_uuids
|
# add back in filtered center uuids
|
||||||
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
||||||
|
|
||||||
|
return filtered_uuids
|
||||||
|
|
||||||
|
|
||||||
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
|
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
|
||||||
|
|
|
||||||
|
|
@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
||||||
|
|
||||||
|
|
||||||
async def build_community(
|
async def build_community(
|
||||||
llm_client: LLMClient, community_cluster: list[EntityNode]
|
llm_client: LLMClient, community_cluster: list[EntityNode]
|
||||||
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
||||||
summaries = [entity.summary for entity in community_cluster]
|
summaries = [entity.summary for entity in community_cluster]
|
||||||
length = len(summaries)
|
length = len(summaries)
|
||||||
|
|
@ -168,7 +168,7 @@ async def build_community(
|
||||||
*[
|
*[
|
||||||
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
||||||
for left_summary, right_summary in zip(
|
for left_summary, right_summary in zip(
|
||||||
summaries[: int(length / 2)], summaries[int(length / 2):]
|
summaries[: int(length / 2)], summaries[int(length / 2) :]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -196,7 +196,7 @@ async def build_community(
|
||||||
|
|
||||||
|
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
driver: AsyncDriver, llm_client: LLMClient
|
driver: AsyncDriver, llm_client: LLMClient
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community_clusters = await get_community_clusters(driver)
|
community_clusters = await get_community_clusters(driver)
|
||||||
|
|
||||||
|
|
@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
|
||||||
|
|
||||||
|
|
||||||
async def determine_entity_community(
|
async def determine_entity_community(
|
||||||
driver: AsyncDriver, entity: EntityNode
|
driver: AsyncDriver, entity: EntityNode
|
||||||
) -> tuple[CommunityNode | None, bool]:
|
) -> tuple[CommunityNode | None, bool]:
|
||||||
# Check if the node is already part of a community
|
# Check if the node is already part of a community
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -288,7 +288,7 @@ async def determine_entity_community(
|
||||||
|
|
||||||
|
|
||||||
async def update_community(
|
async def update_community(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
||||||
):
|
):
|
||||||
community, is_new = await determine_entity_community(driver, entity)
|
community, is_new = await determine_entity_community(driver, entity)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ async def extract_edges(
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
group_id: str | None,
|
group_id: str = '',
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ async def retrieve_episodes(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -119,7 +119,8 @@ async def retrieve_episodes(
|
||||||
"""
|
"""
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
|
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||||
|
AND ($group_ids IS NULL) OR e.group_id in $group_ids
|
||||||
RETURN e.content AS content,
|
RETURN e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
|
|
|
||||||
|
|
@ -76,16 +76,18 @@ async def test_graphiti_init():
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
await graphiti.build_communities()
|
await graphiti.build_communities()
|
||||||
|
|
||||||
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
edges = await graphiti.search(
|
||||||
|
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
|
||||||
|
)
|
||||||
|
|
||||||
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
|
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
|
||||||
|
|
||||||
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
|
edges = await graphiti.search('issues with higher ed', group_ids=None)
|
||||||
|
|
||||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||||
|
|
||||||
results = await graphiti._search(
|
results = await graphiti._search(
|
||||||
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
|
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
|
||||||
)
|
)
|
||||||
pretty_results = {
|
pretty_results = {
|
||||||
'edges': [edge.fact for edge in results.edges],
|
'edges': [edge.fact for edge in results.edges],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue