add triplet results

This commit is contained in:
prestonrasmussen 2025-09-08 15:16:42 -04:00
parent 469703e180
commit 7ebc5123a5
2 changed files with 68 additions and 62 deletions

View file

@ -122,19 +122,24 @@ class AddBulkEpisodeResults(BaseModel):
community_edges: list[CommunityEdge] community_edges: list[CommunityEdge]
class AddTripletResults(BaseModel):
nodes: list[EntityNode]
edges: list[EntityEdge]
class Graphiti: class Graphiti:
def __init__( def __init__(
self, self,
uri: str | None = None, uri: str | None = None,
user: str | None = None, user: str | None = None,
password: str | None = None, password: str | None = None,
llm_client: LLMClient | None = None, llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None, embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None, cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True, store_raw_episode_content: bool = True,
graph_driver: GraphDriver | None = None, graph_driver: GraphDriver | None = None,
max_coroutines: int | None = None, max_coroutines: int | None = None,
ensure_ascii: bool = False, ensure_ascii: bool = False,
): ):
""" """
Initialize a Graphiti instance. Initialize a Graphiti instance.
@ -345,11 +350,11 @@ class Graphiti:
await build_indices_and_constraints(self.driver, delete_existing) await build_indices_and_constraints(self.driver, delete_existing)
async def retrieve_episodes( async def retrieve_episodes(
self, self,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
source: EpisodeType | None = None, source: EpisodeType | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -379,20 +384,20 @@ class Graphiti:
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
async def add_episode( async def add_episode(
self, self,
name: str, name: str,
episode_body: str, episode_body: str,
source_description: str, source_description: str,
reference_time: datetime, reference_time: datetime,
source: EpisodeType = EpisodeType.message, source: EpisodeType = EpisodeType.message,
group_id: str | None = None, group_id: str | None = None,
uuid: str | None = None, uuid: str | None = None,
update_communities: bool = False, update_communities: bool = False,
entity_types: dict[str, type[BaseModel]] | None = None, entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None, excluded_entity_types: list[str] | None = None,
previous_episode_uuids: list[str] | None = None, previous_episode_uuids: list[str] | None = None,
edge_types: dict[str, type[BaseModel]] | None = None, edge_types: dict[str, type[BaseModel]] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddEpisodeResults: ) -> AddEpisodeResults:
""" """
Process an episode and update the graph. Process an episode and update the graph.
@ -582,13 +587,13 @@ class Graphiti:
raise e raise e
async def add_episode_bulk( async def add_episode_bulk(
self, self,
bulk_episodes: list[RawEpisode], bulk_episodes: list[RawEpisode],
group_id: str | None = None, group_id: str | None = None,
entity_types: dict[str, type[BaseModel]] | None = None, entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None, excluded_entity_types: list[str] | None = None,
edge_types: dict[str, type[BaseModel]] | None = None, edge_types: dict[str, type[BaseModel]] | None = None,
edge_type_map: dict[tuple[str, str], list[str]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddBulkEpisodeResults: ) -> AddBulkEpisodeResults:
""" """
Process multiple episodes in bulk and update the graph. Process multiple episodes in bulk and update the graph.
@ -870,7 +875,7 @@ class Graphiti:
raise e raise e
async def build_communities( async def build_communities(
self, group_ids: list[str] | None = None self, group_ids: list[str] | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]: ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
""" """
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
@ -903,12 +908,12 @@ class Graphiti:
return community_nodes, community_edges return community_nodes, community_edges
async def search( async def search(
self, self,
query: str, query: str,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT, num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None, search_filter: SearchFilters | None = None,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -962,13 +967,13 @@ class Graphiti:
return edges return edges
async def _search( async def _search(
self, self,
query: str, query: str,
config: SearchConfig, config: SearchConfig,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None, search_filter: SearchFilters | None = None,
) -> SearchResults: ) -> SearchResults:
"""DEPRECATED""" """DEPRECATED"""
return await self.search_( return await self.search_(
@ -976,13 +981,13 @@ class Graphiti:
) )
async def search_( async def search_(
self, self,
query: str, query: str,
config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER, config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None, search_filter: SearchFilters | None = None,
) -> SearchResults: ) -> SearchResults:
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@ -1015,8 +1020,9 @@ class Graphiti:
return SearchResults(edges=edges, nodes=nodes) return SearchResults(edges=edges, nodes=nodes)
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, async def add_triplet(
target_node: EntityNode) -> AddEpisodeResults: self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> AddTripletResults:
if source_node.name_embedding is None: if source_node.name_embedding is None:
await source_node.generate_name_embedding(self.embedder) await source_node.generate_name_embedding(self.embedder)
if target_node.name_embedding is None: if target_node.name_embedding is None:
@ -1060,7 +1066,7 @@ class Graphiti:
await create_entity_node_embeddings(self.embedder, nodes) await create_entity_node_embeddings(self.embedder, nodes)
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder) await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
return AddEpisodeResults(edges=edges, nodes=nodes) return AddTripletResults(episode=None, edges=edges, nodes=nodes)
async def remove_episode(self, episode_uuid: str): async def remove_episode(self, episode_uuid: str):
# Find the episode to be deleted # Find the episode to be deleted

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.20.3" version = "0.20.4"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },