Add episode refactor (#399)
* partial refactor
* get relevant nodes refactor
* load edges updates
* refactor triplets
* not there yet
* node search update
* working refactor
* updates
* mypy
* mypy
This commit is contained in:
parent
15efa37da1
commit
a26b25dc06
13 changed files with 380 additions and 302 deletions
|
|
@ -1,4 +1,10 @@
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
NEO4J_URI=
|
NEO4J_URI=
|
||||||
|
NEO4J_PORT=
|
||||||
NEO4J_USER=
|
NEO4J_USER=
|
||||||
NEO4J_PASSWORD=
|
NEO4J_PASSWORD=
|
||||||
|
DEFAULT_DATABASE=
|
||||||
|
USE_PARALLEL_RUNTIME=
|
||||||
|
SEMAPHORE_LIMIT=
|
||||||
|
GITHUB_SHA=
|
||||||
|
MAX_REFLEXION_ITERATIONS=
|
||||||
|
|
@ -37,6 +37,21 @@ from graphiti_core.nodes import Node
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ENTITY_EDGE_RETURN: LiteralString = """
|
||||||
|
RETURN
|
||||||
|
e.uuid AS uuid,
|
||||||
|
startNode(e).uuid AS source_node_uuid,
|
||||||
|
endNode(e).uuid AS target_node_uuid,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.name AS name,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
e.fact AS fact,
|
||||||
|
e.fact_embedding AS fact_embedding,
|
||||||
|
e.episodes AS episodes,
|
||||||
|
e.expired_at AS expired_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.invalid_at AS invalid_at"""
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
|
@ -234,20 +249,8 @@ class EntityEdge(Edge):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN
|
"""
|
||||||
e.uuid AS uuid,
|
+ ENTITY_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.fact AS fact,
|
|
||||||
e.fact_embedding AS fact_embedding,
|
|
||||||
e.episodes AS episodes,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.invalid_at AS invalid_at
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -268,20 +271,8 @@ class EntityEdge(Edge):
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
RETURN
|
"""
|
||||||
e.uuid AS uuid,
|
+ ENTITY_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.fact AS fact,
|
|
||||||
e.fact_embedding AS fact_embedding,
|
|
||||||
e.episodes AS episodes,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.invalid_at AS invalid_at
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -308,20 +299,8 @@ class EntityEdge(Edge):
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
|
+ ENTITY_EDGE_RETURN
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
|
||||||
e.uuid AS uuid,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.fact AS fact,
|
|
||||||
e.fact_embedding AS fact_embedding,
|
|
||||||
e.episodes AS episodes,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.invalid_at AS invalid_at
|
|
||||||
ORDER BY e.uuid DESC
|
ORDER BY e.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
|
|
@ -340,22 +319,12 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||||
query: LiteralString = """
|
query: LiteralString = (
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
"""
|
||||||
RETURN DISTINCT
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
e.uuid AS uuid,
|
"""
|
||||||
n.uuid AS source_node_uuid,
|
+ ENTITY_EDGE_RETURN
|
||||||
m.uuid AS target_node_uuid,
|
)
|
||||||
e.created_at AS created_at,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.fact AS fact,
|
|
||||||
e.fact_embedding AS fact_embedding,
|
|
||||||
e.episodes AS episodes,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.invalid_at AS invalid_at
|
|
||||||
"""
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
|
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
|
|
@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
get_mentioned_nodes,
|
get_mentioned_nodes,
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
get_relevant_nodes,
|
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.bulk_utils import (
|
from graphiti_core.utils.bulk_utils import (
|
||||||
RawEpisode,
|
RawEpisode,
|
||||||
|
|
@ -150,6 +150,13 @@ class Graphiti:
|
||||||
else:
|
else:
|
||||||
self.cross_encoder = OpenAIRerankerClient()
|
self.cross_encoder = OpenAIRerankerClient()
|
||||||
|
|
||||||
|
self.clients = GraphitiClients(
|
||||||
|
driver=self.driver,
|
||||||
|
llm_client=self.llm_client,
|
||||||
|
embedder=self.embedder,
|
||||||
|
cross_encoder=self.cross_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
Close the connection to the Neo4j database.
|
Close the connection to the Neo4j database.
|
||||||
|
|
@ -222,6 +229,7 @@ class Graphiti:
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
source: EpisodeType | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -248,7 +256,7 @@ class Graphiti:
|
||||||
The actual retrieval is performed by the `retrieve_episodes` function
|
The actual retrieval is performed by the `retrieve_episodes` function
|
||||||
from the `graphiti_core.utils` module.
|
from the `graphiti_core.utils` module.
|
||||||
"""
|
"""
|
||||||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
|
|
@ -314,15 +322,16 @@ class Graphiti:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
entity_edges: list[EntityEdge] = []
|
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
validate_entity_types(entity_types)
|
validate_entity_types(entity_types)
|
||||||
|
|
||||||
previous_episodes = (
|
previous_episodes = (
|
||||||
await self.retrieve_episodes(
|
await self.retrieve_episodes(
|
||||||
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
reference_time,
|
||||||
|
last_n=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
group_ids=[group_id],
|
||||||
|
source=source,
|
||||||
)
|
)
|
||||||
if previous_episode_uuids is None
|
if previous_episode_uuids is None
|
||||||
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
||||||
|
|
@ -346,132 +355,35 @@ class Graphiti:
|
||||||
# Extract entities as nodes
|
# Extract entities as nodes
|
||||||
|
|
||||||
extracted_nodes = await extract_nodes(
|
extracted_nodes = await extract_nodes(
|
||||||
self.llm_client, episode, previous_episodes, entity_types
|
self.clients, episode, previous_episodes, entity_types
|
||||||
)
|
|
||||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
||||||
|
|
||||||
# Calculate Embeddings
|
|
||||||
|
|
||||||
await semaphore_gather(
|
|
||||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find relevant nodes already in the graph
|
# Extract edges and resolve nodes
|
||||||
existing_nodes_lists: list[list[EntityNode]] = list(
|
(nodes, uuid_map), extracted_edges = await semaphore_gather(
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
get_relevant_nodes(self.driver, SearchFilters(), [node])
|
|
||||||
for node in extracted_nodes
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve extracted nodes with nodes already in the graph and extract facts
|
|
||||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
||||||
|
|
||||||
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
||||||
resolve_extracted_nodes(
|
resolve_extracted_nodes(
|
||||||
self.llm_client,
|
self.clients,
|
||||||
extracted_nodes,
|
extracted_nodes,
|
||||||
existing_nodes_lists,
|
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
entity_types,
|
entity_types,
|
||||||
),
|
),
|
||||||
extract_edges(
|
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
|
||||||
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
||||||
nodes = mentioned_nodes
|
|
||||||
|
|
||||||
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
||||||
extracted_edges, uuid_map
|
extracted_edges, uuid_map
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate embeddings
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
edge.generate_embedding(self.embedder)
|
|
||||||
for edge in extracted_edges_with_resolved_pointers
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve extracted edges with related edges already in the graph
|
|
||||||
related_edges_list: list[list[EntityEdge]] = list(
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
get_relevant_edges(
|
|
||||||
self.driver,
|
|
||||||
[edge],
|
|
||||||
edge.source_node_uuid,
|
|
||||||
edge.target_node_uuid,
|
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
|
||||||
)
|
|
||||||
for edge in extracted_edges_with_resolved_pointers
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_source_edges_list: list[list[EntityEdge]] = list(
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
get_relevant_edges(
|
|
||||||
self.driver,
|
|
||||||
[edge],
|
|
||||||
edge.source_node_uuid,
|
|
||||||
None,
|
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
|
||||||
)
|
|
||||||
for edge in extracted_edges_with_resolved_pointers
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_target_edges_list: list[list[EntityEdge]] = list(
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
get_relevant_edges(
|
|
||||||
self.driver,
|
|
||||||
[edge],
|
|
||||||
None,
|
|
||||||
edge.target_node_uuid,
|
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
|
||||||
)
|
|
||||||
for edge in extracted_edges_with_resolved_pointers
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_edges_list: list[list[EntityEdge]] = [
|
|
||||||
source_lst + target_lst
|
|
||||||
for source_lst, target_lst in zip(
|
|
||||||
existing_source_edges_list, existing_target_edges_list, strict=False
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||||
self.llm_client,
|
self.clients,
|
||||||
extracted_edges_with_resolved_pointers,
|
extracted_edges_with_resolved_pointers,
|
||||||
related_edges_list,
|
|
||||||
existing_edges_list,
|
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_edges.extend(resolved_edges + invalidated_edges)
|
entity_edges = resolved_edges + invalidated_edges
|
||||||
|
|
||||||
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
episodic_edges = build_episodic_edges(nodes, episode, now)
|
||||||
|
|
||||||
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
||||||
|
|
||||||
logger.debug(f'Built episodic edges: {episodic_edges}')
|
|
||||||
|
|
||||||
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
||||||
|
|
||||||
|
|
@ -565,7 +477,7 @@ class Graphiti:
|
||||||
extracted_nodes,
|
extracted_nodes,
|
||||||
extracted_edges,
|
extracted_edges,
|
||||||
episodic_edges,
|
episodic_edges,
|
||||||
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
|
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
|
||||||
|
|
||||||
# Generate embeddings
|
# Generate embeddings
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -684,9 +596,7 @@ class Graphiti:
|
||||||
|
|
||||||
edges = (
|
edges = (
|
||||||
await search(
|
await search(
|
||||||
self.driver,
|
self.clients,
|
||||||
self.embedder,
|
|
||||||
self.cross_encoder,
|
|
||||||
query,
|
query,
|
||||||
group_ids,
|
group_ids,
|
||||||
search_config,
|
search_config,
|
||||||
|
|
@ -728,9 +638,7 @@ class Graphiti:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await search(
|
return await search(
|
||||||
self.driver,
|
self.clients,
|
||||||
self.embedder,
|
|
||||||
self.cross_encoder,
|
|
||||||
query,
|
query,
|
||||||
group_ids,
|
group_ids,
|
||||||
config,
|
config,
|
||||||
|
|
@ -761,26 +669,17 @@ class Graphiti:
|
||||||
await edge.generate_embedding(self.embedder)
|
await edge.generate_embedding(self.embedder)
|
||||||
|
|
||||||
resolved_nodes, uuid_map = await resolve_extracted_nodes(
|
resolved_nodes, uuid_map = await resolve_extracted_nodes(
|
||||||
self.llm_client,
|
self.clients,
|
||||||
[source_node, target_node],
|
[source_node, target_node],
|
||||||
[
|
|
||||||
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
|
|
||||||
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
||||||
|
|
||||||
related_edges = await get_relevant_edges(
|
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
|
||||||
self.driver,
|
|
||||||
[updated_edge],
|
|
||||||
source_node_uuid=resolved_nodes[0].uuid,
|
|
||||||
target_node_uuid=resolved_nodes[1].uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
|
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
|
||||||
|
|
||||||
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
|
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
|
||||||
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
31
graphiti_core/graphiti_types.py
Normal file
31
graphiti_core/graphiti_types.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from neo4j import AsyncDriver
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from graphiti_core.cross_encoder import CrossEncoderClient
|
||||||
|
from graphiti_core.embedder import EmbedderClient
|
||||||
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiClients(BaseModel):
|
||||||
|
driver: AsyncDriver
|
||||||
|
llm_client: LLMClient
|
||||||
|
embedder: EmbedderClient
|
||||||
|
cross_encoder: CrossEncoderClient
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
@ -22,15 +22,20 @@ from datetime import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from neo4j import time as neo4j_time
|
from neo4j import time as neo4j_time
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||||
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2))
|
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1))
|
||||||
DEFAULT_PAGE_LIMIT = 20
|
DEFAULT_PAGE_LIMIT = 20
|
||||||
|
|
||||||
|
RUNTIME_QUERY: LiteralString = (
|
||||||
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||||
return neo_date.to_native() if neo_date else None
|
return neo_date.to_native() if neo_date else None
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
||||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
||||||
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
||||||
RETURN r.uuid AS uuid
|
RETURN edge.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
COMMUNITY_EDGE_SAVE = """
|
COMMUNITY_EDGE_SAVE = """
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,8 @@ from neo4j import AsyncDriver
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.embedder import EmbedderClient
|
|
||||||
from graphiti_core.errors import SearchRerankerError
|
from graphiti_core.errors import SearchRerankerError
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import semaphore_gather
|
from graphiti_core.helpers import semaphore_gather
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.search.search_config import (
|
from graphiti_core.search.search_config import (
|
||||||
|
|
@ -62,17 +62,21 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
driver: AsyncDriver,
|
clients: GraphitiClients,
|
||||||
embedder: EmbedderClient,
|
|
||||||
cross_encoder: CrossEncoderClient,
|
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
bfs_origin_node_uuids: list[str] | None = None,
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
|
query_vector: list[float] | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
driver = clients.driver
|
||||||
|
embedder = clients.embedder
|
||||||
|
cross_encoder = clients.cross_encoder
|
||||||
|
|
||||||
if query.strip() == '':
|
if query.strip() == '':
|
||||||
return SearchResults(
|
return SearchResults(
|
||||||
edges=[],
|
edges=[],
|
||||||
|
|
@ -80,7 +84,11 @@ async def search(
|
||||||
episodes=[],
|
episodes=[],
|
||||||
communities=[],
|
communities=[],
|
||||||
)
|
)
|
||||||
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
|
query_vector = (
|
||||||
|
query_vector
|
||||||
|
if query_vector is not None
|
||||||
|
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||||
|
)
|
||||||
|
|
||||||
# if group_ids is empty, set it to None
|
# if group_ids is empty, set it to None
|
||||||
group_ids = group_ids if group_ids else None
|
group_ids = group_ids if group_ids else None
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from typing_extensions import LiteralString
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.helpers import (
|
from graphiti_core.helpers import (
|
||||||
DEFAULT_DATABASE,
|
DEFAULT_DATABASE,
|
||||||
USE_PARALLEL_RUNTIME,
|
RUNTIME_QUERY,
|
||||||
lucene_sanitize,
|
lucene_sanitize,
|
||||||
normalize_l2,
|
normalize_l2,
|
||||||
semaphore_gather,
|
semaphore_gather,
|
||||||
|
|
@ -207,10 +207,6 @@ async def edge_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
runtime_query: LiteralString = (
|
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
||||||
)
|
|
||||||
|
|
||||||
query_params: dict[str, Any] = {}
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
|
@ -230,9 +226,10 @@ async def edge_similarity_search(
|
||||||
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
RUNTIME_QUERY
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
+ """
|
||||||
"""
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
|
|
@ -256,7 +253,7 @@ async def edge_similarity_search(
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query + query,
|
query,
|
||||||
query_params,
|
query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
source_uuid=source_node_uuid,
|
source_uuid=source_node_uuid,
|
||||||
|
|
@ -344,10 +341,10 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE n:Entity
|
WHERE n:Entity
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -378,10 +375,6 @@ async def node_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
runtime_query: LiteralString = (
|
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
||||||
)
|
|
||||||
|
|
||||||
query_params: dict[str, Any] = {}
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
group_filter_query: LiteralString = ''
|
group_filter_query: LiteralString = ''
|
||||||
|
|
@ -393,7 +386,7 @@ async def node_similarity_search(
|
||||||
query_params.update(filter_params)
|
query_params.update(filter_params)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
|
|
@ -542,10 +535,6 @@ async def community_similarity_search(
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over communities
|
# vector similarity search over communities
|
||||||
runtime_query: LiteralString = (
|
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
||||||
)
|
|
||||||
|
|
||||||
query_params: dict[str, Any] = {}
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
group_filter_query: LiteralString = ''
|
group_filter_query: LiteralString = ''
|
||||||
|
|
@ -554,7 +543,7 @@ async def community_similarity_search(
|
||||||
query_params['group_ids'] = group_ids
|
query_params['group_ids'] = group_ids
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
"""
|
"""
|
||||||
|
|
@ -660,86 +649,204 @@ async def hybrid_node_search(
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
async def get_relevant_nodes(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_filter: SearchFilters,
|
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
) -> list[EntityNode]:
|
search_filter: SearchFilters,
|
||||||
"""
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
Retrieve relevant nodes based on the provided list of EntityNodes.
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
) -> list[list[EntityNode]]:
|
||||||
|
if len(nodes) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
This method performs a hybrid search using both the names and embeddings
|
group_id = nodes[0].group_id
|
||||||
of the input nodes to find relevant nodes in the graph database.
|
|
||||||
|
|
||||||
Parameters
|
# vector similarity search over entity names
|
||||||
----------
|
query_params: dict[str, Any] = {}
|
||||||
nodes : list[EntityNode]
|
|
||||||
A list of EntityNode objects to use as the basis for the search.
|
|
||||||
driver : AsyncDriver
|
|
||||||
The Neo4j driver instance for database operations.
|
|
||||||
|
|
||||||
Returns
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
-------
|
query_params.update(filter_params)
|
||||||
list[EntityNode]
|
|
||||||
A list of EntityNode objects that are deemed relevant based on the input nodes.
|
|
||||||
|
|
||||||
Notes
|
query = (
|
||||||
-----
|
RUNTIME_QUERY
|
||||||
This method uses the hybrid_node_search function to perform the search,
|
+ """UNWIND $nodes AS node
|
||||||
which combines fulltext search and vector similarity search.
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
It extracts the names and name embeddings (if available) from the input nodes
|
"""
|
||||||
to use as search criteria.
|
+ filter_query
|
||||||
"""
|
+ """
|
||||||
relevant_nodes = await hybrid_node_search(
|
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
|
||||||
[node.name for node in nodes],
|
WHERE score > $min_score
|
||||||
[node.name_embedding for node in nodes if node.name_embedding is not None],
|
WITH node, n, score
|
||||||
driver,
|
ORDER BY score DESC
|
||||||
search_filter,
|
RETURN node.uuid AS search_node_uuid,
|
||||||
[node.group_id for node in nodes],
|
collect({
|
||||||
|
uuid: n.uuid,
|
||||||
|
name: n.name,
|
||||||
|
name_embedding: n.name_embedding,
|
||||||
|
group_id: n.group_id,
|
||||||
|
created_at: n.created_at,
|
||||||
|
summary: n.summary,
|
||||||
|
labels: labels(n),
|
||||||
|
attributes: properties(n)
|
||||||
|
})[..$limit] AS matches
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
results, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
query_params,
|
||||||
|
nodes=[
|
||||||
|
{'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding}
|
||||||
|
for node in nodes
|
||||||
|
],
|
||||||
|
group_id=group_id,
|
||||||
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
||||||
|
result['search_node_uuid']: [
|
||||||
|
get_entity_node_from_record(record) for record in result['matches']
|
||||||
|
]
|
||||||
|
for result in results
|
||||||
|
}
|
||||||
|
|
||||||
|
relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]
|
||||||
|
|
||||||
return relevant_nodes
|
return relevant_nodes
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
source_node_uuid: str | None,
|
search_filter: SearchFilters,
|
||||||
target_node_uuid: str | None,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[list[EntityEdge]]:
|
||||||
start = time()
|
if len(edges) == 0:
|
||||||
relevant_edges: list[EntityEdge] = []
|
return []
|
||||||
relevant_edge_uuids = set()
|
|
||||||
|
|
||||||
results = await semaphore_gather(
|
query_params: dict[str, Any] = {}
|
||||||
*[
|
|
||||||
edge_similarity_search(
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
driver,
|
query_params.update(filter_params)
|
||||||
edge.fact_embedding,
|
|
||||||
source_node_uuid,
|
query = (
|
||||||
target_node_uuid,
|
RUNTIME_QUERY
|
||||||
SearchFilters(),
|
+ """UNWIND $edges AS edge
|
||||||
[edge.group_id],
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
limit,
|
"""
|
||||||
)
|
+ filter_query
|
||||||
for edge in edges
|
+ """
|
||||||
if edge.fact_embedding is not None
|
WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
|
||||||
]
|
WHERE score > $min_score
|
||||||
|
WITH edge, e, score
|
||||||
|
ORDER BY score DESC
|
||||||
|
RETURN edge.uuid AS search_edge_uuid,
|
||||||
|
collect({
|
||||||
|
uuid: e.uuid,
|
||||||
|
source_node_uuid: startNode(e).uuid,
|
||||||
|
target_node_uuid: endNode(e).uuid,
|
||||||
|
created_at: e.created_at,
|
||||||
|
name: e.name,
|
||||||
|
group_id: e.group_id,
|
||||||
|
fact: e.fact,
|
||||||
|
fact_embedding: e.fact_embedding,
|
||||||
|
episodes: e.episodes,
|
||||||
|
expired_at: e.expired_at,
|
||||||
|
valid_at: e.valid_at,
|
||||||
|
invalid_at: e.invalid_at
|
||||||
|
})[..$limit] AS matches
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in results:
|
results, _, _ = await driver.execute_query(
|
||||||
for edge in result:
|
query,
|
||||||
if edge.uuid in relevant_edge_uuids:
|
query_params,
|
||||||
continue
|
edges=[edge.model_dump() for edge in edges],
|
||||||
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
|
result['search_edge_uuid']: [
|
||||||
|
get_entity_edge_from_record(record) for record in result['matches']
|
||||||
|
]
|
||||||
|
for result in results
|
||||||
|
}
|
||||||
|
|
||||||
relevant_edge_uuids.add(edge.uuid)
|
relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]
|
||||||
relevant_edges.append(edge)
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
|
|
||||||
|
|
||||||
return relevant_edges
|
return relevant_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def get_edge_invalidation_candidates(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
edges: list[EntityEdge],
|
||||||
|
search_filter: SearchFilters,
|
||||||
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
) -> list[list[EntityEdge]]:
|
||||||
|
if len(edges) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
query_params.update(filter_params)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
RUNTIME_QUERY
|
||||||
|
+ """UNWIND $edges AS edge
|
||||||
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
|
"""
|
||||||
|
+ filter_query
|
||||||
|
+ """
|
||||||
|
WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
|
||||||
|
WHERE score > $min_score
|
||||||
|
WITH edge, e, score
|
||||||
|
ORDER BY score DESC
|
||||||
|
RETURN edge.uuid AS search_edge_uuid,
|
||||||
|
collect({
|
||||||
|
uuid: e.uuid,
|
||||||
|
source_node_uuid: startNode(e).uuid,
|
||||||
|
target_node_uuid: endNode(e).uuid,
|
||||||
|
created_at: e.created_at,
|
||||||
|
name: e.name,
|
||||||
|
group_id: e.group_id,
|
||||||
|
fact: e.fact,
|
||||||
|
fact_embedding: e.fact_embedding,
|
||||||
|
episodes: e.episodes,
|
||||||
|
expired_at: e.expired_at,
|
||||||
|
valid_at: e.valid_at,
|
||||||
|
invalid_at: e.invalid_at
|
||||||
|
})[..$limit] AS matches
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
results, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
query_params,
|
||||||
|
edges=[edge.model_dump() for edge in edges],
|
||||||
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
|
result['search_edge_uuid']: [
|
||||||
|
get_entity_edge_from_record(record) for record in result['matches']
|
||||||
|
]
|
||||||
|
for result in results
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]
|
||||||
|
|
||||||
|
return invalidation_edges
|
||||||
|
|
||||||
|
|
||||||
# takes in a list of rankings of uuids
|
# takes in a list of rankings of uuids
|
||||||
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
|
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
|
||||||
scores: dict[str, float] = defaultdict(float)
|
scores: dict[str, float] = defaultdict(float)
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from pydantic import BaseModel
|
||||||
from typing_extensions import Any
|
from typing_extensions import Any
|
||||||
|
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
|
|
@ -128,16 +129,18 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
||||||
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
|
await tx.run(
|
||||||
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||||
|
)
|
||||||
|
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||||
extracted_nodes_bulk = await semaphore_gather(
|
extracted_nodes_bulk = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
extract_nodes(llm_client, episode, previous_episodes)
|
extract_nodes(clients, episode, previous_episodes)
|
||||||
for episode, previous_episodes in episode_tuples
|
for episode, previous_episodes in episode_tuples
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -150,7 +153,7 @@ async def extract_nodes_and_edges_bulk(
|
||||||
extracted_edges_bulk = await semaphore_gather(
|
extracted_edges_bulk = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
extract_edges(
|
extract_edges(
|
||||||
llm_client,
|
clients,
|
||||||
episode,
|
episode,
|
||||||
extracted_nodes_bulk[i],
|
extracted_nodes_bulk[i],
|
||||||
previous_episodes_list[i],
|
previous_episodes_list[i],
|
||||||
|
|
@ -189,7 +192,7 @@ async def dedupe_nodes_bulk(
|
||||||
|
|
||||||
existing_nodes_chunks: list[list[EntityNode]] = list(
|
existing_nodes_chunks: list[list[EntityNode]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
|
*[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -223,7 +226,7 @@ async def dedupe_edges_bulk(
|
||||||
|
|
||||||
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
|
*[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,15 @@ from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
||||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||||
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
|
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
|
|
@ -39,7 +42,7 @@ def build_episodic_edges(
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
created_at: datetime,
|
created_at: datetime,
|
||||||
) -> list[EpisodicEdge]:
|
) -> list[EpisodicEdge]:
|
||||||
edges: list[EpisodicEdge] = [
|
episodic_edges: list[EpisodicEdge] = [
|
||||||
EpisodicEdge(
|
EpisodicEdge(
|
||||||
source_node_uuid=episode.uuid,
|
source_node_uuid=episode.uuid,
|
||||||
target_node_uuid=node.uuid,
|
target_node_uuid=node.uuid,
|
||||||
|
|
@ -49,7 +52,9 @@ def build_episodic_edges(
|
||||||
for node in entity_nodes
|
for node in entity_nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
return edges
|
logger.debug(f'Built episodic edges: {episodic_edges}')
|
||||||
|
|
||||||
|
return episodic_edges
|
||||||
|
|
||||||
|
|
||||||
def build_community_edges(
|
def build_community_edges(
|
||||||
|
|
@ -71,7 +76,7 @@ def build_community_edges(
|
||||||
|
|
||||||
|
|
||||||
async def extract_edges(
|
async def extract_edges(
|
||||||
llm_client: LLMClient,
|
clients: GraphitiClients,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
|
|
@ -79,7 +84,9 @@ async def extract_edges(
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
EXTRACT_EDGES_MAX_TOKENS = 16384
|
extract_edges_max_tokens = 16384
|
||||||
|
llm_client = clients.llm_client
|
||||||
|
embedder = clients.embedder
|
||||||
|
|
||||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||||
|
|
||||||
|
|
@ -97,7 +104,7 @@ async def extract_edges(
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_edges.edge(context),
|
prompt_library.extract_edges.edge(context),
|
||||||
response_model=ExtractedEdges,
|
response_model=ExtractedEdges,
|
||||||
max_tokens=EXTRACT_EDGES_MAX_TOKENS,
|
max_tokens=extract_edges_max_tokens,
|
||||||
)
|
)
|
||||||
edges_data = llm_response.get('edges', [])
|
edges_data = llm_response.get('edges', [])
|
||||||
|
|
||||||
|
|
@ -145,6 +152,11 @@ async def extract_edges(
|
||||||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# calculate embeddings
|
||||||
|
await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])
|
||||||
|
|
||||||
|
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -193,13 +205,26 @@ async def dedupe_extracted_edges(
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_edges(
|
async def resolve_extracted_edges(
|
||||||
llm_client: LLMClient,
|
clients: GraphitiClients,
|
||||||
extracted_edges: list[EntityEdge],
|
extracted_edges: list[EntityEdge],
|
||||||
related_edges_lists: list[list[EntityEdge]],
|
|
||||||
existing_edges_lists: list[list[EntityEdge]],
|
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||||
|
driver = clients.driver
|
||||||
|
llm_client = clients.llm_client
|
||||||
|
|
||||||
|
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
|
||||||
|
driver, extracted_edges, SearchFilters(), 0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates(
|
||||||
|
driver, extracted_edges, SearchFilters()
|
||||||
|
)
|
||||||
|
|
||||||
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
|
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
|
||||||
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -213,7 +238,7 @@ async def resolve_extracted_edges(
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
for extracted_edge, related_edges, existing_edges in zip(
|
for extracted_edge, related_edges, existing_edges in zip(
|
||||||
extracted_edges, related_edges_lists, existing_edges_lists, strict=False
|
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -228,6 +253,8 @@ async def resolve_extracted_edges(
|
||||||
resolved_edges.append(resolved_edge)
|
resolved_edges.append(resolved_edge)
|
||||||
invalidated_edges.extend(invalidated_edge_chunk)
|
invalidated_edges.extend(invalidated_edge_chunk)
|
||||||
|
|
||||||
|
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
||||||
|
|
||||||
return resolved_edges, invalidated_edges
|
return resolved_edges, invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,7 @@ async def retrieve_episodes(
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
source: EpisodeType | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -132,13 +133,17 @@ async def retrieve_episodes(
|
||||||
Returns:
|
Returns:
|
||||||
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
||||||
"""
|
"""
|
||||||
group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
|
group_id_filter: LiteralString = (
|
||||||
|
'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
|
||||||
|
)
|
||||||
|
source_filter: LiteralString = 'AND e.source = $source' if source is not None else ''
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||||
"""
|
"""
|
||||||
+ group_id_filter
|
+ group_id_filter
|
||||||
|
+ source_filter
|
||||||
+ """
|
+ """
|
||||||
RETURN e.content AS content,
|
RETURN e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
|
|
@ -156,6 +161,7 @@ async def retrieve_episodes(
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
reference_time=reference_time,
|
reference_time=reference_time,
|
||||||
|
source=source,
|
||||||
num_episodes=last_n,
|
num_episodes=last_n,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from typing import Any
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
|
|
@ -29,6 +30,8 @@ from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
||||||
from graphiti_core.prompts.summarize_nodes import Summary
|
from graphiti_core.prompts.summarize_nodes import Summary
|
||||||
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
|
from graphiti_core.search.search_utils import get_relevant_nodes
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -116,12 +119,14 @@ async def extract_nodes_reflexion(
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes(
|
async def extract_nodes(
|
||||||
llm_client: LLMClient,
|
clients: GraphitiClients,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
|
llm_client = clients.llm_client
|
||||||
|
embedder = clients.embedder
|
||||||
extracted_node_names: list[str] = []
|
extracted_node_names: list[str] = []
|
||||||
custom_prompt = ''
|
custom_prompt = ''
|
||||||
entities_missed = True
|
entities_missed = True
|
||||||
|
|
@ -138,7 +143,6 @@ async def extract_nodes(
|
||||||
elif episode.source == EpisodeType.json:
|
elif episode.source == EpisodeType.json:
|
||||||
extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt)
|
extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt)
|
||||||
|
|
||||||
reflexion_iterations += 1
|
|
||||||
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
||||||
missing_entities = await extract_nodes_reflexion(
|
missing_entities = await extract_nodes_reflexion(
|
||||||
llm_client, episode, previous_episodes, extracted_node_names
|
llm_client, episode, previous_episodes, extracted_node_names
|
||||||
|
|
@ -149,6 +153,7 @@ async def extract_nodes(
|
||||||
custom_prompt = 'The following entities were missed in a previous extraction: '
|
custom_prompt = 'The following entities were missed in a previous extraction: '
|
||||||
for entity in missing_entities:
|
for entity in missing_entities:
|
||||||
custom_prompt += f'\n{entity},'
|
custom_prompt += f'\n{entity},'
|
||||||
|
reflexion_iterations += 1
|
||||||
|
|
||||||
node_classification_context = {
|
node_classification_context = {
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
|
|
@ -184,7 +189,7 @@ async def extract_nodes(
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
||||||
# Convert the extracted data into EntityNode objects
|
# Convert the extracted data into EntityNode objects
|
||||||
new_nodes = []
|
extracted_nodes = []
|
||||||
for name in extracted_node_names:
|
for name in extracted_node_names:
|
||||||
entity_type = node_classifications.get(name)
|
entity_type = node_classifications.get(name)
|
||||||
if entity_types is not None and entity_type not in entity_types:
|
if entity_types is not None and entity_type not in entity_types:
|
||||||
|
|
@ -203,10 +208,13 @@ async def extract_nodes(
|
||||||
summary='',
|
summary='',
|
||||||
created_at=utc_now(),
|
created_at=utc_now(),
|
||||||
)
|
)
|
||||||
new_nodes.append(new_node)
|
extracted_nodes.append(new_node)
|
||||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
||||||
return new_nodes
|
await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])
|
||||||
|
|
||||||
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
|
return extracted_nodes
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_extracted_nodes(
|
async def dedupe_extracted_nodes(
|
||||||
|
|
@ -260,13 +268,20 @@ async def dedupe_extracted_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_nodes(
|
async def resolve_extracted_nodes(
|
||||||
llm_client: LLMClient,
|
clients: GraphitiClients,
|
||||||
extracted_nodes: list[EntityNode],
|
extracted_nodes: list[EntityNode],
|
||||||
existing_nodes_lists: list[list[EntityNode]],
|
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
|
llm_client = clients.llm_client
|
||||||
|
driver = clients.driver
|
||||||
|
|
||||||
|
# Find relevant nodes already in the graph
|
||||||
|
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
|
||||||
|
driver, extracted_nodes, SearchFilters(), 0.8
|
||||||
|
)
|
||||||
|
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
resolved_nodes: list[EntityNode] = []
|
resolved_nodes: list[EntityNode] = []
|
||||||
results: list[tuple[EntityNode, dict[str, str]]] = list(
|
results: list[tuple[EntityNode, dict[str, str]]] = list(
|
||||||
|
|
@ -291,6 +306,8 @@ async def resolve_extracted_nodes(
|
||||||
uuid_map.update(result[1])
|
uuid_map.update(result[1])
|
||||||
resolved_nodes.append(result[0])
|
resolved_nodes.append(result[0])
|
||||||
|
|
||||||
|
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
||||||
|
|
||||||
return resolved_nodes, uuid_map
|
return resolved_nodes, uuid_map
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.10.5"
|
version = "0.10.6"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue