Revert episodes (#387)
* episode search fixes and optimizations

* remove extra return string

* Update graphiti_core/utils/maintenance/graph_data_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
5b24f591b1
commit
432d2295c6
3 changed files with 50 additions and 37 deletions
|
|
@@ -39,8 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ENTITY_NODE_RETURN: LiteralString = """
|
ENTITY_NODE_RETURN: LiteralString = """
|
||||||
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
|
||||||
WITH n, collect(e.uuid) AS episodes
|
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
|
|
@@ -49,8 +47,8 @@ ENTITY_NODE_RETURN: LiteralString = """
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS attributes,
|
properties(n) AS attributes
|
||||||
episodes"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
|
|
@@ -265,13 +263,35 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||||
|
RETURN DISTINCT
|
||||||
|
e.content AS content,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.uuid AS uuid,
|
||||||
|
e.name AS name,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
e.source_description AS source_description,
|
||||||
|
e.source AS source,
|
||||||
|
e.entity_edges AS entity_edges
|
||||||
|
""",
|
||||||
|
entity_node_uuid=entity_node_uuid,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
return episodes
|
||||||
|
|
||||||
|
|
||||||
class EntityNode(Node):
|
class EntityNode(Node):
|
||||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
||||||
episodes: list[str] | None = Field(
|
|
||||||
default=None, description='List of episode uuids that mention this node.'
|
|
||||||
)
|
|
||||||
attributes: dict[str, Any] = Field(
|
attributes: dict[str, Any] = Field(
|
||||||
default={}, description='Additional attributes of the node. Dependent on node labels'
|
default={}, description='Additional attributes of the node. Dependent on node labels'
|
||||||
)
|
)
|
||||||
|
|
@@ -312,8 +332,8 @@ class EntityNode(Node):
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
"""
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@@ -519,7 +539,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
attributes=record['attributes'],
|
attributes=record['attributes'],
|
||||||
episodes=record['episodes'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_node.attributes.pop('uuid', None)
|
entity_node.attributes.pop('uuid', None)
|
||||||
|
|
|
||||||
|
|
@@ -54,20 +54,6 @@ DEFAULT_MMR_LAMBDA = 0.5
|
||||||
MAX_SEARCH_DEPTH = 3
|
MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 32
|
MAX_QUERY_LENGTH = 32
|
||||||
|
|
||||||
SEARCH_ENTITY_NODE_RETURN: LiteralString = """
|
|
||||||
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
|
|
||||||
WITH n, score, collect(e.uuid) AS episodes
|
|
||||||
RETURN
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.name AS name,
|
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes,
|
|
||||||
episodes"""
|
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
group_ids_filter_list = (
|
group_ids_filter_list = (
|
||||||
|
|
@@ -245,8 +231,8 @@ async def edge_similarity_search(
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
|
|
@@ -358,12 +344,12 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE n:Entity
|
WHERE n:Entity
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ SEARCH_ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
"""
|
"""
|
||||||
|
|
@@ -416,7 +402,7 @@ async def node_similarity_search(
|
||||||
+ """
|
+ """
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score"""
|
WHERE score > $min_score"""
|
||||||
+ SEARCH_ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
|
||||||
|
|
@@ -132,10 +132,14 @@ async def retrieve_episodes(
|
||||||
Returns:
|
Returns:
|
||||||
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
|
||||||
"""
|
"""
|
||||||
result = await driver.execute_query(
|
group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
|
||||||
|
|
||||||
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||||
AND ($group_ids IS NULL) OR e.group_id in $group_ids
|
"""
|
||||||
|
+ group_id_filter
|
||||||
|
+ """
|
||||||
RETURN e.content AS content,
|
RETURN e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
|
|
@@ -144,9 +148,13 @@ async def retrieve_episodes(
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.source_description AS source_description,
|
e.source_description AS source_description,
|
||||||
e.source AS source
|
e.source AS source
|
||||||
ORDER BY e.created_at DESC
|
ORDER BY e.valid_at DESC
|
||||||
LIMIT $num_episodes
|
LIMIT $num_episodes
|
||||||
""",
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await driver.execute_query(
|
||||||
|
query,
|
||||||
reference_time=reference_time,
|
reference_time=reference_time,
|
||||||
num_episodes=last_n,
|
num_episodes=last_n,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue