Add episode refactor (#399)
* partial refactor
* get relevant nodes refactor
* load edges updates
* refactor triplets
* not there yet
* node search update
* working refactor
* updates
* mypy
* mypy
This commit is contained in:
parent 15efa37da1
commit a26b25dc06
13 changed files with 380 additions and 302 deletions
@@ -1,4 +1,10 @@
OPENAI_API_KEY=
NEO4J_URI=
NEO4J_PORT=
NEO4J_USER=
NEO4J_PASSWORD=
DEFAULT_DATABASE=
USE_PARALLEL_RUNTIME=
SEMAPHORE_LIMIT=
GITHUB_SHA=
MAX_REFLEXION_ITERATIONS=
@@ -37,6 +37,21 @@ from graphiti_core.nodes import Node

logger = logging.getLogger(__name__)

ENTITY_EDGE_RETURN: LiteralString = """
        RETURN
            e.uuid AS uuid,
            startNode(e).uuid AS source_node_uuid,
            endNode(e).uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at"""


class Edge(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -234,20 +249,8 @@ class EntityEdge(Edge):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN
                e.uuid AS uuid,
                n.uuid AS source_node_uuid,
                m.uuid AS target_node_uuid,
                e.created_at AS created_at,
                e.name AS name,
                e.group_id AS group_id,
                e.fact AS fact,
                e.fact_embedding AS fact_embedding,
                e.episodes AS episodes,
                e.expired_at AS expired_at,
                e.valid_at AS valid_at,
                e.invalid_at AS invalid_at
            """,
            """
            + ENTITY_EDGE_RETURN,
            uuid=uuid,
            database_=DEFAULT_DATABASE,
            routing_='r',
@@ -268,20 +271,8 @@ class EntityEdge(Edge):
            """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            WHERE e.uuid IN $uuids
            RETURN
                e.uuid AS uuid,
                n.uuid AS source_node_uuid,
                m.uuid AS target_node_uuid,
                e.created_at AS created_at,
                e.name AS name,
                e.group_id AS group_id,
                e.fact AS fact,
                e.fact_embedding AS fact_embedding,
                e.episodes AS episodes,
                e.expired_at AS expired_at,
                e.valid_at AS valid_at,
                e.invalid_at AS invalid_at
            """,
            """
            + ENTITY_EDGE_RETURN,
            uuids=uuids,
            database_=DEFAULT_DATABASE,
            routing_='r',
@@ -308,20 +299,8 @@ class EntityEdge(Edge):
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + ENTITY_EDGE_RETURN
            + """
            RETURN
                e.uuid AS uuid,
                n.uuid AS source_node_uuid,
                m.uuid AS target_node_uuid,
                e.created_at AS created_at,
                e.name AS name,
                e.group_id AS group_id,
                e.fact AS fact,
                e.fact_embedding AS fact_embedding,
                e.episodes AS episodes,
                e.expired_at AS expired_at,
                e.valid_at AS valid_at,
                e.invalid_at AS invalid_at
            ORDER BY e.uuid DESC
            """
            + limit_query,
@@ -340,22 +319,12 @@ class EntityEdge(Edge):

    @classmethod
    async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
        query: LiteralString = """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
            RETURN DISTINCT
                e.uuid AS uuid,
                n.uuid AS source_node_uuid,
                m.uuid AS target_node_uuid,
                e.created_at AS created_at,
                e.name AS name,
                e.group_id AS group_id,
                e.fact AS fact,
                e.fact_embedding AS fact_embedding,
                e.episodes AS episodes,
                e.expired_at AS expired_at,
                e.valid_at AS valid_at,
                e.invalid_at AS invalid_at
            """
        query: LiteralString = (
            """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
            """
            + ENTITY_EDGE_RETURN
        )
        records, _, _ = await driver.execute_query(
            query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
        )
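The ENTITY_EDGE_RETURN constant above replaces four near-identical RETURN clauses. Because it aliases startNode(e)/endNode(e) rather than the outer node variables, it composes with any pattern that binds the relationship to `e`. A minimal sketch of the pattern, assuming `driver` and `uuid` are in scope:

```python
# Sketch: build a full query by prepending a query-specific MATCH to the
# shared RETURN clause.
query = (
    """
    MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
    """
    + ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
    query, uuid=uuid, database_=DEFAULT_DATABASE, routing_='r'
)
```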
@@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
@@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
    RELEVANT_SCHEMA_LIMIT,
    get_mentioned_nodes,
    get_relevant_edges,
    get_relevant_nodes,
)
from graphiti_core.utils.bulk_utils import (
    RawEpisode,
@@ -150,6 +150,13 @@ class Graphiti:
        else:
            self.cross_encoder = OpenAIRerankerClient()

        self.clients = GraphitiClients(
            driver=self.driver,
            llm_client=self.llm_client,
            embedder=self.embedder,
            cross_encoder=self.cross_encoder,
        )

    async def close(self):
        """
        Close the connection to the Neo4j database.
@@ -222,6 +229,7 @@ class Graphiti:
        reference_time: datetime,
        last_n: int = EPISODE_WINDOW_LEN,
        group_ids: list[str] | None = None,
        source: EpisodeType | None = None,
    ) -> list[EpisodicNode]:
        """
        Retrieve the last n episodic nodes from the graph.
@@ -248,7 +256,7 @@ class Graphiti:
        The actual retrieval is performed by the `retrieve_episodes` function
        from the `graphiti_core.utils` module.
        """
        return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
        return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)

    async def add_episode(
        self,
@@ -314,15 +322,16 @@ class Graphiti:
        """
        try:
            start = time()

            entity_edges: list[EntityEdge] = []
            now = utc_now()

            validate_entity_types(entity_types)

            previous_episodes = (
                await self.retrieve_episodes(
                    reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
                    reference_time,
                    last_n=RELEVANT_SCHEMA_LIMIT,
                    group_ids=[group_id],
                    source=source,
                )
                if previous_episode_uuids is None
                else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
@@ -346,132 +355,35 @@ class Graphiti:
            # Extract entities as nodes

            extracted_nodes = await extract_nodes(
                self.llm_client, episode, previous_episodes, entity_types
            )
            logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

            # Calculate Embeddings

            await semaphore_gather(
                *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
                self.clients, episode, previous_episodes, entity_types
            )

            # Find relevant nodes already in the graph
            existing_nodes_lists: list[list[EntityNode]] = list(
                await semaphore_gather(
                    *[
                        get_relevant_nodes(self.driver, SearchFilters(), [node])
                        for node in extracted_nodes
                    ]
                )
            )

            # Resolve extracted nodes with nodes already in the graph and extract facts
            logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

            (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
            # Extract edges and resolve nodes
            (nodes, uuid_map), extracted_edges = await semaphore_gather(
                resolve_extracted_nodes(
                    self.llm_client,
                    self.clients,
                    extracted_nodes,
                    existing_nodes_lists,
                    episode,
                    previous_episodes,
                    entity_types,
                ),
                extract_edges(
                    self.llm_client, episode, extracted_nodes, previous_episodes, group_id
                ),
                extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
            )
            logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
            nodes = mentioned_nodes

            extracted_edges_with_resolved_pointers = resolve_edge_pointers(
                extracted_edges, uuid_map
            )

            # calculate embeddings
            await semaphore_gather(
                *[
                    edge.generate_embedding(self.embedder)
                    for edge in extracted_edges_with_resolved_pointers
                ]
            )

            # Resolve extracted edges with related edges already in the graph
            related_edges_list: list[list[EntityEdge]] = list(
                await semaphore_gather(
                    *[
                        get_relevant_edges(
                            self.driver,
                            [edge],
                            edge.source_node_uuid,
                            edge.target_node_uuid,
                            RELEVANT_SCHEMA_LIMIT,
                        )
                        for edge in extracted_edges_with_resolved_pointers
                    ]
                )
            )
            logger.debug(
                f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
            )
            logger.debug(
                f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
            )

            existing_source_edges_list: list[list[EntityEdge]] = list(
                await semaphore_gather(
                    *[
                        get_relevant_edges(
                            self.driver,
                            [edge],
                            edge.source_node_uuid,
                            None,
                            RELEVANT_SCHEMA_LIMIT,
                        )
                        for edge in extracted_edges_with_resolved_pointers
                    ]
                )
            )

            existing_target_edges_list: list[list[EntityEdge]] = list(
                await semaphore_gather(
                    *[
                        get_relevant_edges(
                            self.driver,
                            [edge],
                            None,
                            edge.target_node_uuid,
                            RELEVANT_SCHEMA_LIMIT,
                        )
                        for edge in extracted_edges_with_resolved_pointers
                    ]
                )
            )

            existing_edges_list: list[list[EntityEdge]] = [
                source_lst + target_lst
                for source_lst, target_lst in zip(
                    existing_source_edges_list, existing_target_edges_list, strict=False
                )
            ]

            resolved_edges, invalidated_edges = await resolve_extracted_edges(
                self.llm_client,
                self.clients,
                extracted_edges_with_resolved_pointers,
                related_edges_list,
                existing_edges_list,
                episode,
                previous_episodes,
            )

            entity_edges.extend(resolved_edges + invalidated_edges)
            entity_edges = resolved_edges + invalidated_edges

            logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

            episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

            logger.debug(f'Built episodic edges: {episodic_edges}')
            episodic_edges = build_episodic_edges(nodes, episode, now)

            episode.entity_edges = [edge.uuid for edge in entity_edges]

@@ -565,7 +477,7 @@ class Graphiti:
                extracted_nodes,
                extracted_edges,
                episodic_edges,
            ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)

            # Generate embeddings
            await semaphore_gather(
@@ -684,9 +596,7 @@ class Graphiti:

        edges = (
            await search(
                self.driver,
                self.embedder,
                self.cross_encoder,
                self.clients,
                query,
                group_ids,
                search_config,
@@ -728,9 +638,7 @@ class Graphiti:
        """

        return await search(
            self.driver,
            self.embedder,
            self.cross_encoder,
            self.clients,
            query,
            group_ids,
            config,
@@ -761,26 +669,17 @@ class Graphiti:
        await edge.generate_embedding(self.embedder)

        resolved_nodes, uuid_map = await resolve_extracted_nodes(
            self.llm_client,
            self.clients,
            [source_node, target_node],
            [
                await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
                await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
            ],
        )

        updated_edge = resolve_edge_pointers([edge], uuid_map)[0]

        related_edges = await get_relevant_edges(
            self.driver,
            [updated_edge],
            source_node_uuid=resolved_nodes[0].uuid,
            target_node_uuid=resolved_nodes[1].uuid,
        )
        related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)

        resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
        resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])

        contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
        contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
        invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)

        await add_nodes_and_edges_bulk(
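Taken together, the hunks above reduce the add_episode hot path to roughly the following sequence. This is a sketch of the new control flow only (error handling, timing, and the bulk save are omitted), assuming the helper signatures shown in the later hunks of this commit:

```python
# Sketch of the refactored add_episode flow. Helpers now receive the
# GraphitiClients bundle and fetch their own graph context internally.
extracted_nodes = await extract_nodes(self.clients, episode, previous_episodes, entity_types)

# Node resolution and edge extraction run concurrently.
(nodes, uuid_map), extracted_edges = await semaphore_gather(
    resolve_extracted_nodes(
        self.clients, extracted_nodes, episode, previous_episodes, entity_types
    ),
    extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
)

extracted_edges_with_resolved_pointers = resolve_edge_pointers(extracted_edges, uuid_map)
resolved_edges, invalidated_edges = await resolve_extracted_edges(
    self.clients, extracted_edges_with_resolved_pointers, episode, previous_episodes
)
entity_edges = resolved_edges + invalidated_edges
episodic_edges = build_episodic_edges(nodes, episode, now)
```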
graphiti_core/graphiti_types.py (new file, 31 lines)

@@ -0,0 +1,31 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from neo4j import AsyncDriver
from pydantic import BaseModel, ConfigDict

from graphiti_core.cross_encoder import CrossEncoderClient
from graphiti_core.embedder import EmbedderClient
from graphiti_core.llm_client import LLMClient


class GraphitiClients(BaseModel):
    driver: AsyncDriver
    llm_client: LLMClient
    embedder: EmbedderClient
    cross_encoder: CrossEncoderClient

    model_config = ConfigDict(arbitrary_types_allowed=True)
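GraphitiClients is a plain Pydantic container; arbitrary_types_allowed lets it hold the driver and client objects directly. A hedged sketch of constructing one outside of Graphiti.__init__, mirroring the wiring in the graphiti.py hunk above (connection details are placeholders):

```python
from neo4j import AsyncGraphDatabase

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.embedder import OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.llm_client import OpenAIClient

driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))

clients = GraphitiClients(
    driver=driver,
    llm_client=OpenAIClient(),
    embedder=OpenAIEmbedder(),
    cross_encoder=OpenAIRerankerClient(),
)

# Helper functions in this commit unpack what they need, e.g.:
llm_client = clients.llm_client
embedder = clients.embedder
```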
@@ -22,15 +22,20 @@ from datetime import datetime
import numpy as np
from dotenv import load_dotenv
from neo4j import time as neo4j_time
from typing_extensions import LiteralString

load_dotenv()

DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1))
DEFAULT_PAGE_LIMIT = 20

RUNTIME_QUERY: LiteralString = (
    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    return neo_date.to_native() if neo_date else None
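Hoisting the parallel-runtime prefix into RUNTIME_QUERY removes the runtime_query local that several search functions rebuilt on every call (see the search_utils hunks below). The composition is a plain string prepend; a sketch with an illustrative query body:

```python
# When USE_PARALLEL_RUNTIME is truthy the composed text starts with
# 'CYPHER runtime = parallel parallelRuntimeSupport=all'; otherwise the
# prefix is the empty string and the query is unchanged.
query = (
    RUNTIME_QUERY
    + """
    MATCH (n:Entity)
    RETURN n.uuid AS uuid
    """
)
```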
@@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
    SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
    WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
    RETURN r.uuid AS uuid
    RETURN edge.uuid AS uuid
"""

COMMUNITY_EDGE_SAVE = """
@@ -22,8 +22,8 @@ from neo4j import AsyncDriver

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
@@ -62,17 +62,21 @@ logger = logging.getLogger(__name__)


async def search(
    driver: AsyncDriver,
    embedder: EmbedderClient,
    cross_encoder: CrossEncoderClient,
    clients: GraphitiClients,
    query: str,
    group_ids: list[str] | None,
    config: SearchConfig,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    query_vector: list[float] | None = None,
) -> SearchResults:
    start = time()

    driver = clients.driver
    embedder = clients.embedder
    cross_encoder = clients.cross_encoder

    if query.strip() == '':
        return SearchResults(
            edges=[],
@@ -80,7 +84,11 @@ async def search(
            episodes=[],
            communities=[],
        )
    query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
    query_vector = (
        query_vector
        if query_vector is not None
        else await embedder.create(input_data=[query.replace('\n', ' ')])
    )

    # if group_ids is empty, set it to None
    group_ids = group_ids if group_ids else None
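search() now accepts an optional precomputed query_vector and only calls the embedder when one is not supplied. A hypothetical caller that reuses one embedding across two searches (config_a and config_b are stand-in SearchConfig values):

```python
# Embed once, search twice; without the query_vector argument each call
# would embed the query again.
vector = await clients.embedder.create(input_data=[query.replace('\n', ' ')])

results_a = await search(clients, query, group_ids, config_a, SearchFilters(), query_vector=vector)
results_b = await search(clients, query, group_ids, config_b, SearchFilters(), query_vector=vector)
```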
@@ -26,7 +26,7 @@ from typing_extensions import LiteralString
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import (
    DEFAULT_DATABASE,
    USE_PARALLEL_RUNTIME,
    RUNTIME_QUERY,
    lucene_sanitize,
    normalize_l2,
    semaphore_gather,
@@ -207,10 +207,6 @@ async def edge_similarity_search(
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
    # vector similarity search over embedded facts
    runtime_query: LiteralString = (
        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
    )

    query_params: dict[str, Any] = {}

    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
@@ -230,9 +226,10 @@ async def edge_similarity_search(
        group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'

    query: LiteralString = (
        """
        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
        """
        RUNTIME_QUERY
        + """
        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
        """
        + group_filter_query
        + filter_query
        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -256,7 +253,7 @@ async def edge_similarity_search(
    )

    records, _, _ = await driver.execute_query(
        runtime_query + query,
        query,
        query_params,
        search_vector=search_vector,
        source_uuid=source_node_uuid,
@@ -344,10 +341,10 @@ async def node_fulltext_search(

    query = (
        """
        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
        YIELD node AS n, score
        WHERE n:Entity
        """
        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
        YIELD node AS n, score
        WHERE n:Entity
        """
        + filter_query
        + ENTITY_NODE_RETURN
        + """
@@ -378,10 +375,6 @@ async def node_similarity_search(
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
    # vector similarity search over entity names
    runtime_query: LiteralString = (
        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
    )

    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
@@ -393,7 +386,7 @@ async def node_similarity_search(
    query_params.update(filter_params)

    records, _, _ = await driver.execute_query(
        runtime_query
        RUNTIME_QUERY
        + """
        MATCH (n:Entity)
        """
@@ -542,10 +535,6 @@ async def community_similarity_search(
    min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
    # vector similarity search over entity names
    runtime_query: LiteralString = (
        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
    )

    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
@@ -554,7 +543,7 @@ async def community_similarity_search(
    query_params['group_ids'] = group_ids

    records, _, _ = await driver.execute_query(
        runtime_query
        RUNTIME_QUERY
        + """
        MATCH (comm:Community)
        """
@@ -660,86 +649,204 @@ async def hybrid_node_search(

async def get_relevant_nodes(
    driver: AsyncDriver,
    search_filter: SearchFilters,
    nodes: list[EntityNode],
) -> list[EntityNode]:
    """
    Retrieve relevant nodes based on the provided list of EntityNodes.
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    if len(nodes) == 0:
        return []

    This method performs a hybrid search using both the names and embeddings
    of the input nodes to find relevant nodes in the graph database.
    group_id = nodes[0].group_id

    Parameters
    ----------
    nodes : list[EntityNode]
        A list of EntityNode objects to use as the basis for the search.
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    # vector similarity search over entity names
    query_params: dict[str, Any] = {}

    Returns
    -------
    list[EntityNode]
        A list of EntityNode objects that are deemed relevant based on the input nodes.
    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
    query_params.update(filter_params)

    Notes
    -----
    This method uses the hybrid_node_search function to perform the search,
    which combines fulltext search and vector similarity search.
    It extracts the names and name embeddings (if available) from the input nodes
    to use as search criteria.
    """
    relevant_nodes = await hybrid_node_search(
        [node.name for node in nodes],
        [node.name_embedding for node in nodes if node.name_embedding is not None],
        driver,
        search_filter,
        [node.group_id for node in nodes],
    query = (
        RUNTIME_QUERY
        + """UNWIND $nodes AS node
        MATCH (n:Entity {group_id: $group_id})
        """
        + filter_query
        + """
        WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
        WHERE score > $min_score
        WITH node, n, score
        ORDER BY score DESC
        RETURN node.uuid AS search_node_uuid,
            collect({
                uuid: n.uuid,
                name: n.name,
                name_embedding: n.name_embedding,
                group_id: n.group_id,
                created_at: n.created_at,
                summary: n.summary,
                labels: labels(n),
                attributes: properties(n)
            })[..$limit] AS matches
        """
    )

    results, _, _ = await driver.execute_query(
        query,
        query_params,
        nodes=[
            {'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding}
            for node in nodes
        ],
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record) for record in result['matches']
        ]
        for result in results
    }

    relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]

    return relevant_nodes


async def get_relevant_edges(
    driver: AsyncDriver,
    edges: list[EntityEdge],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    start = time()
    relevant_edges: list[EntityEdge] = []
    relevant_edge_uuids = set()
) -> list[list[EntityEdge]]:
    if len(edges) == 0:
        return []

    results = await semaphore_gather(
        *[
            edge_similarity_search(
                driver,
                edge.fact_embedding,
                source_node_uuid,
                target_node_uuid,
                SearchFilters(),
                [edge.group_id],
                limit,
            )
            for edge in edges
            if edge.fact_embedding is not None
        ]
    query_params: dict[str, Any] = {}

    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
    query_params.update(filter_params)

    query = (
        RUNTIME_QUERY
        + """UNWIND $edges AS edge
        MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
        """
        + filter_query
        + """
        WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
        WHERE score > $min_score
        WITH edge, e, score
        ORDER BY score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: e.fact_embedding,
                episodes: e.episodes,
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at
            })[..$limit] AS matches
        """
    )

    for result in results:
        for edge in result:
            if edge.uuid in relevant_edge_uuids:
                continue
    results, _, _ = await driver.execute_query(
        query,
        query_params,
        edges=[edge.model_dump() for edge in edges],
        limit=limit,
        min_score=min_score,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )
    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record) for record in result['matches']
        ]
        for result in results
    }

            relevant_edge_uuids.add(edge.uuid)
            relevant_edges.append(edge)

    end = time()
    logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges


async def get_edge_invalidation_candidates(
    driver: AsyncDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    if len(edges) == 0:
        return []

    query_params: dict[str, Any] = {}

    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
    query_params.update(filter_params)

    query = (
        RUNTIME_QUERY
        + """UNWIND $edges AS edge
        MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
        WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
        """
        + filter_query
        + """
        WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
        WHERE score > $min_score
        WITH edge, e, score
        ORDER BY score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: e.fact_embedding,
                episodes: e.episodes,
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at
            })[..$limit] AS matches
        """
    )

    results, _, _ = await driver.execute_query(
        query,
        query_params,
        edges=[edge.model_dump() for edge in edges],
        limit=limit,
        min_score=min_score,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record) for record in result['matches']
        ]
        for result in results
    }

    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]

    return invalidation_edges


# takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
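The three rewritten helpers share one shape: a single UNWIND query scores every input against its candidates in one round trip, and the result is re-aligned to the input list, so result[i] always corresponds to input[i] (an empty list when nothing cleared min_score). A calling sketch for get_relevant_nodes; note the query matches on nodes[0].group_id, so the inputs are assumed to share a group:

```python
candidate_lists = await get_relevant_nodes(driver, extracted_nodes, SearchFilters())

for node, candidates in zip(extracted_nodes, candidate_lists, strict=True):
    # candidates for this specific extracted node, best score first
    print(node.name, [c.name for c in candidates])
```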
@@ -26,6 +26,7 @@ from pydantic import BaseModel
from typing_extensions import Any

from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
@@ -128,16 +129,18 @@ async def add_nodes_and_edges_bulk_tx(

    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
    await tx.run(
        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
    )
    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])


async def extract_nodes_and_edges_bulk(
    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
    clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
    extracted_nodes_bulk = await semaphore_gather(
        *[
            extract_nodes(llm_client, episode, previous_episodes)
            extract_nodes(clients, episode, previous_episodes)
            for episode, previous_episodes in episode_tuples
        ]
    )
@@ -150,7 +153,7 @@ async def extract_nodes_and_edges_bulk(
    extracted_edges_bulk = await semaphore_gather(
        *[
            extract_edges(
                llm_client,
                clients,
                episode,
                extracted_nodes_bulk[i],
                previous_episodes_list[i],
@@ -189,7 +192,7 @@ async def dedupe_nodes_bulk(

    existing_nodes_chunks: list[list[EntityNode]] = list(
        await semaphore_gather(
            *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
            *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
        )
    )

@@ -223,7 +226,7 @@ async def dedupe_edges_bulk(

    relevant_edges_chunks: list[list[EntityEdge]] = list(
        await semaphore_gather(
            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
            *[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
        )
    )
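The dict(edge) → edge.model_dump() switch matters for nested and typed fields: iterating a Pydantic model with dict() is shallow, while model_dump() serializes recursively into plain Python values the Neo4j driver can accept as parameters. A small illustration with a hypothetical model (not from this codebase):

```python
from datetime import datetime, timezone

from pydantic import BaseModel


class Inner(BaseModel):
    at: datetime


class Outer(BaseModel):
    name: str
    inner: Inner


outer = Outer(name='a', inner=Inner(at=datetime.now(timezone.utc)))

dict(outer)['inner']         # Inner model instance: shallow conversion
outer.model_dump()['inner']  # {'at': datetime(...)}: recursive plain dict
```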
@@ -19,12 +19,15 @@ from datetime import datetime
from time import time

from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
    extract_edge_dates,
@@ -39,7 +42,7 @@ def build_episodic_edges(
    episode: EpisodicNode,
    created_at: datetime,
) -> list[EpisodicEdge]:
    edges: list[EpisodicEdge] = [
    episodic_edges: list[EpisodicEdge] = [
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
@@ -49,7 +52,9 @@ def build_episodic_edges(
        for node in entity_nodes
    ]

    return edges
    logger.debug(f'Built episodic edges: {episodic_edges}')

    return episodic_edges


def build_community_edges(
@@ -71,7 +76,7 @@ def build_community_edges(


async def extract_edges(
    llm_client: LLMClient,
    clients: GraphitiClients,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
@@ -79,7 +84,9 @@ async def extract_edges(
) -> list[EntityEdge]:
    start = time()

    EXTRACT_EDGES_MAX_TOKENS = 16384
    extract_edges_max_tokens = 16384
    llm_client = clients.llm_client
    embedder = clients.embedder

    node_uuids_by_name_map = {node.name: node.uuid for node in nodes}

@@ -97,7 +104,7 @@ async def extract_edges(
    llm_response = await llm_client.generate_response(
        prompt_library.extract_edges.edge(context),
        response_model=ExtractedEdges,
        max_tokens=EXTRACT_EDGES_MAX_TOKENS,
        max_tokens=extract_edges_max_tokens,
    )
    edges_data = llm_response.get('edges', [])

@@ -145,6 +152,11 @@ async def extract_edges(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    # calculate embeddings
    await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])

    logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')

    return edges


@@ -193,13 +205,26 @@ async def dedupe_extracted_edges(


async def resolve_extracted_edges(
    llm_client: LLMClient,
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    related_edges_lists: list[list[EntityEdge]],
    existing_edges_lists: list[list[EntityEdge]],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    driver = clients.driver
    llm_client = clients.llm_client

    related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
        driver, extracted_edges, SearchFilters(), 0.8
    )

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
    )

    edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates(
        driver, extracted_edges, SearchFilters()
    )

    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
        await semaphore_gather(
@@ -213,7 +238,7 @@ async def resolve_extracted_edges(
                previous_episodes,
            )
            for extracted_edge, related_edges, existing_edges in zip(
                extracted_edges, related_edges_lists, existing_edges_lists, strict=False
                extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False
            )
        ]
    )
@@ -228,6 +253,8 @@ async def resolve_extracted_edges(
        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)

    logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

    return resolved_edges, invalidated_edges
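resolve_extracted_edges now owns its retrieval: it pulls the driver from clients, fetches dedup candidates with get_relevant_edges(driver, extracted_edges, SearchFilters(), 0.8), and swaps the caller-supplied existing_edges_lists for get_edge_invalidation_candidates. The call site shrinks accordingly (a sketch, assuming a clients bundle in scope):

```python
resolved_edges, invalidated_edges = await resolve_extracted_edges(
    clients,
    extracted_edges_with_resolved_pointers,
    episode,
    previous_episodes,
)
# resolved_edges: deduplicated new facts; invalidated_edges: prior edges
# found to be contradicted by them.
```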
@@ -117,6 +117,7 @@ async def retrieve_episodes(
    reference_time: datetime,
    last_n: int = EPISODE_WINDOW_LEN,
    group_ids: list[str] | None = None,
    source: EpisodeType | None = None,
) -> list[EpisodicNode]:
    """
    Retrieve the last n episodic nodes from the graph.
@@ -132,13 +133,17 @@ async def retrieve_episodes(
    Returns:
        list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
    """
    group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
    group_id_filter: LiteralString = (
        'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
    )
    source_filter: LiteralString = 'AND e.source = $source' if source is not None else ''

    query: LiteralString = (
        """
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        """
        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
        """
        + group_id_filter
        + source_filter
        + """
        RETURN e.content AS content,
            e.created_at AS created_at,
@@ -156,6 +161,7 @@ async def retrieve_episodes(
    result = await driver.execute_query(
        query,
        reference_time=reference_time,
        source=source,
        num_episodes=last_n,
        group_ids=group_ids,
        database_=DEFAULT_DATABASE,
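The new source parameter narrows the episode window to one EpisodeType; leaving it None preserves the previous behavior. An example call with placeholder values:

```python
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.datetime_utils import utc_now

# Last five message-type episodes for one group, as of now.
episodes = await retrieve_episodes(
    driver,
    reference_time=utc_now(),
    last_n=5,
    group_ids=['group-1'],
    source=EpisodeType.message,
)
```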
@@ -22,6 +22,7 @@ from typing import Any
import pydantic
from pydantic import BaseModel

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@@ -29,6 +30,8 @@ from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now

logger = logging.getLogger(__name__)
@@ -116,12 +119,14 @@ async def extract_nodes_reflexion(


async def extract_nodes(
    llm_client: LLMClient,
    clients: GraphitiClients,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
    start = time()
    llm_client = clients.llm_client
    embedder = clients.embedder
    extracted_node_names: list[str] = []
    custom_prompt = ''
    entities_missed = True
@@ -138,7 +143,6 @@ async def extract_nodes(
        elif episode.source == EpisodeType.json:
            extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt)

        reflexion_iterations += 1
        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
            missing_entities = await extract_nodes_reflexion(
                llm_client, episode, previous_episodes, extracted_node_names
@@ -149,6 +153,7 @@ async def extract_nodes(
            custom_prompt = 'The following entities were missed in a previous extraction: '
            for entity in missing_entities:
                custom_prompt += f'\n{entity},'
        reflexion_iterations += 1

    node_classification_context = {
        'episode_content': episode.content,
@@ -184,7 +189,7 @@ async def extract_nodes(
    end = time()
    logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    extracted_nodes = []
    for name in extracted_node_names:
        entity_type = node_classifications.get(name)
        if entity_types is not None and entity_type not in entity_types:
@@ -203,10 +208,13 @@ async def extract_nodes(
            summary='',
            created_at=utc_now(),
        )
        new_nodes.append(new_node)
        extracted_nodes.append(new_node)
        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

    return new_nodes
    await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])

    logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
    return extracted_nodes


async def dedupe_extracted_nodes(
@@ -260,13 +268,20 @@ async def dedupe_extracted_nodes(


async def resolve_extracted_nodes(
    llm_client: LLMClient,
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    existing_nodes_lists: list[list[EntityNode]],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
    llm_client = clients.llm_client
    driver = clients.driver

    # Find relevant nodes already in the graph
    existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
        driver, extracted_nodes, SearchFilters(), 0.8
    )

    uuid_map: dict[str, str] = {}
    resolved_nodes: list[EntityNode] = []
    results: list[tuple[EntityNode, dict[str, str]]] = list(
@@ -291,6 +306,8 @@ async def resolve_extracted_nodes(
        uuid_map.update(result[1])
        resolved_nodes.append(result[0])

    logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')

    return resolved_nodes, uuid_map
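resolve_extracted_nodes likewise drops the existing_nodes_lists parameter and looks up candidates itself via get_relevant_nodes(driver, extracted_nodes, SearchFilters(), 0.8). A sketch of the new call shape (keyword usage here is illustrative):

```python
resolved_nodes, uuid_map = await resolve_extracted_nodes(
    clients,
    extracted_nodes,
    episode=episode,
    previous_episodes=previous_episodes,
    entity_types=entity_types,
)
# uuid_map maps each provisional extracted-node uuid to the uuid of the
# existing node it was deduplicated into; resolve_edge_pointers uses it
# to rewire extracted edges onto the surviving nodes.
```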
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.10.5"
version = "0.10.6"
authors = [
    { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
    { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },