add triplet results
This commit is contained in:
parent
469703e180
commit
7ebc5123a5
2 changed files with 68 additions and 62 deletions
|
|
@ -122,19 +122,24 @@ class AddBulkEpisodeResults(BaseModel):
|
|||
community_edges: list[CommunityEdge]
|
||||
|
||||
|
||||
class AddTripletResults(BaseModel):
|
||||
nodes: list[EntityNode]
|
||||
edges: list[EntityEdge]
|
||||
|
||||
|
||||
class Graphiti:
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
llm_client: LLMClient | None = None,
|
||||
embedder: EmbedderClient | None = None,
|
||||
cross_encoder: CrossEncoderClient | None = None,
|
||||
store_raw_episode_content: bool = True,
|
||||
graph_driver: GraphDriver | None = None,
|
||||
max_coroutines: int | None = None,
|
||||
ensure_ascii: bool = False,
|
||||
self,
|
||||
uri: str | None = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
llm_client: LLMClient | None = None,
|
||||
embedder: EmbedderClient | None = None,
|
||||
cross_encoder: CrossEncoderClient | None = None,
|
||||
store_raw_episode_content: bool = True,
|
||||
graph_driver: GraphDriver | None = None,
|
||||
max_coroutines: int | None = None,
|
||||
ensure_ascii: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize a Graphiti instance.
|
||||
|
|
@ -345,11 +350,11 @@ class Graphiti:
|
|||
await build_indices_and_constraints(self.driver, delete_existing)
|
||||
|
||||
async def retrieve_episodes(
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str] | None = None,
|
||||
source: EpisodeType | None = None,
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str] | None = None,
|
||||
source: EpisodeType | None = None,
|
||||
) -> list[EpisodicNode]:
|
||||
"""
|
||||
Retrieve the last n episodic nodes from the graph.
|
||||
|
|
@ -379,20 +384,20 @@ class Graphiti:
|
|||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
||||
|
||||
async def add_episode(
|
||||
self,
|
||||
name: str,
|
||||
episode_body: str,
|
||||
source_description: str,
|
||||
reference_time: datetime,
|
||||
source: EpisodeType = EpisodeType.message,
|
||||
group_id: str | None = None,
|
||||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
previous_episode_uuids: list[str] | None = None,
|
||||
edge_types: dict[str, type[BaseModel]] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
self,
|
||||
name: str,
|
||||
episode_body: str,
|
||||
source_description: str,
|
||||
reference_time: datetime,
|
||||
source: EpisodeType = EpisodeType.message,
|
||||
group_id: str | None = None,
|
||||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
previous_episode_uuids: list[str] | None = None,
|
||||
edge_types: dict[str, type[BaseModel]] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
) -> AddEpisodeResults:
|
||||
"""
|
||||
Process an episode and update the graph.
|
||||
|
|
@ -582,13 +587,13 @@ class Graphiti:
|
|||
raise e
|
||||
|
||||
async def add_episode_bulk(
|
||||
self,
|
||||
bulk_episodes: list[RawEpisode],
|
||||
group_id: str | None = None,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
edge_types: dict[str, type[BaseModel]] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
self,
|
||||
bulk_episodes: list[RawEpisode],
|
||||
group_id: str | None = None,
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
edge_types: dict[str, type[BaseModel]] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
) -> AddBulkEpisodeResults:
|
||||
"""
|
||||
Process multiple episodes in bulk and update the graph.
|
||||
|
|
@ -870,7 +875,7 @@ class Graphiti:
|
|||
raise e
|
||||
|
||||
async def build_communities(
|
||||
self, group_ids: list[str] | None = None
|
||||
self, group_ids: list[str] | None = None
|
||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||
"""
|
||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||
|
|
@ -903,12 +908,12 @@ class Graphiti:
|
|||
return community_nodes, community_edges
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
self,
|
||||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
|
@ -962,13 +967,13 @@ class Graphiti:
|
|||
return edges
|
||||
|
||||
async def _search(
|
||||
self,
|
||||
query: str,
|
||||
config: SearchConfig,
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
self,
|
||||
query: str,
|
||||
config: SearchConfig,
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
) -> SearchResults:
|
||||
"""DEPRECATED"""
|
||||
return await self.search_(
|
||||
|
|
@ -976,13 +981,13 @@ class Graphiti:
|
|||
)
|
||||
|
||||
async def search_(
|
||||
self,
|
||||
query: str,
|
||||
config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
self,
|
||||
query: str,
|
||||
config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
) -> SearchResults:
|
||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||
|
|
@ -1015,8 +1020,9 @@ class Graphiti:
|
|||
|
||||
return SearchResults(edges=edges, nodes=nodes)
|
||||
|
||||
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge,
|
||||
target_node: EntityNode) -> AddEpisodeResults:
|
||||
async def add_triplet(
|
||||
self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
||||
) -> AddTripletResults:
|
||||
if source_node.name_embedding is None:
|
||||
await source_node.generate_name_embedding(self.embedder)
|
||||
if target_node.name_embedding is None:
|
||||
|
|
@ -1060,7 +1066,7 @@ class Graphiti:
|
|||
await create_entity_node_embeddings(self.embedder, nodes)
|
||||
|
||||
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
|
||||
return AddEpisodeResults(edges=edges, nodes=nodes)
|
||||
return AddTripletResults(episode=None, edges=edges, nodes=nodes)
|
||||
|
||||
async def remove_episode(self, episode_uuid: str):
|
||||
# Find the episode to be deleted
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.20.3"
|
||||
version = "0.20.4"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue