From 49aeaf75f2a9bf6fbea56c8bbfcb1b5f52d5b5a1 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 8 Oct 2024 13:55:10 -0400 Subject: [PATCH] Add mmr reranking (#180) * mmr start * add mmr function * normalize * add mmr options to search * update communities * build communities * format * clean up normalization * normalize in mmr * update --- graphiti_core/embedder/openai.py | 2 +- graphiti_core/embedder/voyage.py | 2 +- graphiti_core/graphiti.py | 17 ++- graphiti_core/helpers.py | 13 ++ graphiti_core/search/search.py | 43 +++++- graphiti_core/search/search_config.py | 10 ++ graphiti_core/search/search_config_recipes.py | 40 ++++++ graphiti_core/search/search_utils.py | 132 +++++++++++------- .../utils/maintenance/community_operations.py | 38 ++--- pyproject.toml | 2 +- tests/test_graphiti_int.py | 4 +- 11 files changed, 215 insertions(+), 88 deletions(-) diff --git a/graphiti_core/embedder/openai.py b/graphiti_core/embedder/openai.py index a209dba1..1e961d6e 100644 --- a/graphiti_core/embedder/openai.py +++ b/graphiti_core/embedder/openai.py @@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient): self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) async def create( - self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: result = await self.client.embeddings.create(input=input, model=self.config.embedding_model) return result.data[0].embedding[: self.config.embedding_dim] diff --git a/graphiti_core/embedder/voyage.py b/graphiti_core/embedder/voyage.py index f0fca309..4c2d2509 100644 --- a/graphiti_core/embedder/voyage.py +++ b/graphiti_core/embedder/voyage.py @@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient): self.client = voyageai.AsyncClient(api_key=config.api_key) async def create( - self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + self, input: 
str | List[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: result = await self.client.embed(input, model=self.config.embedding_model) return result.embeddings[0][: self.config.embedding_dim] diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 07507e0d..2bf1480b 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,7 @@ from pydantic import BaseModel from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.llm_client import LLMClient, OpenAIClient -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults from graphiti_core.search.search_config_recipes import ( @@ -576,11 +576,20 @@ class Graphiti: except Exception as e: raise e - async def build_communities(self): + async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]: + """ + Use a community clustering algorithm to find communities of nodes. Create community nodes summarising + the content of these communities. + ---------- + group_ids : list[str] | None + Optional. Create communities only for the listed group_ids. If None, the entire graph will be used.
+ """ # Clear existing communities await remove_communities(self.driver) - community_nodes, community_edges = await build_communities(self.driver, self.llm_client) + community_nodes, community_edges = await build_communities( + self.driver, self.llm_client, group_ids + ) await asyncio.gather( *[node.generate_name_embedding(self.embedder) for node in community_nodes] @@ -589,6 +598,8 @@ class Graphiti: await asyncio.gather(*[node.save(self.driver) for node in community_nodes]) await asyncio.gather(*[edge.save(self.driver) for edge in community_edges]) + return community_nodes + async def search( self, query: str, diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 314babe5..d04e94d2 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -16,6 +16,7 @@ limitations under the License. from datetime import datetime +import numpy as np from neo4j import time as neo4j_time @@ -52,3 +53,15 @@ def lucene_sanitize(query: str) -> str: sanitized = query.translate(escape_map) return sanitized + + +def normalize_l2(embedding: list[float]) -> list[float]: + embedding_array = np.array(embedding) + if embedding_array.ndim == 1: + norm = np.linalg.norm(embedding_array) + if norm == 0: + return embedding_array.tolist() + return (embedding_array / norm).tolist() + else: + norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True) + return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist() diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 4d4f34de..e6e81b18 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import ( edge_fulltext_search, edge_similarity_search, episode_mentions_reranker, + maximal_marginal_relevance, node_distance_reranker, node_fulltext_search, node_similarity_search, @@ -117,12 +118,14 @@ async def edge_search( if config is None: return [] + query_vector = await 
embedder.create(input=[query]) + search_results: list[list[EntityEdge]] = list( await asyncio.gather( *[ edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit), edge_similarity_search( - driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit + driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score ), ] ) @@ -135,6 +138,15 @@ async def edge_search( search_result_uuids = [[edge.uuid for edge in result] for result in search_results] reranked_uuids = rrf(search_result_uuids) + elif config.reranker == EdgeReranker.mmr: + search_result_uuids_and_vectors = [ + (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024) + for result in search_results + for edge in result + ] + reranked_uuids = maximal_marginal_relevance( + query_vector, search_result_uuids_and_vectors, config.mmr_lambda + ) elif config.reranker == EdgeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') @@ -175,12 +187,14 @@ async def node_search( if config is None: return [] + query_vector = await embedder.create(input=[query]) + search_results: list[list[EntityNode]] = list( await asyncio.gather( *[ node_fulltext_search(driver, query, group_ids, 2 * limit), node_similarity_search( - driver, await embedder.create(input=[query]), group_ids, 2 * limit + driver, query_vector, group_ids, 2 * limit, config.sim_min_score ), ] ) @@ -192,6 +206,15 @@ async def node_search( reranked_uuids: list[str] = [] if config.reranker == NodeReranker.rrf: reranked_uuids = rrf(search_result_uuids) + elif config.reranker == NodeReranker.mmr: + search_result_uuids_and_vectors = [ + (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024) + for result in search_results + for node in result + ] + reranked_uuids = maximal_marginal_relevance( + query_vector, search_result_uuids_and_vectors, config.mmr_lambda + ) elif config.reranker == 
NodeReranker.episode_mentions: reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) elif config.reranker == NodeReranker.node_distance: @@ -217,12 +240,14 @@ async def community_search( if config is None: return [] + query_vector = await embedder.create(input=[query]) + search_results: list[list[CommunityNode]] = list( await asyncio.gather( *[ community_fulltext_search(driver, query, group_ids, 2 * limit), community_similarity_search( - driver, await embedder.create(input=[query]), group_ids, 2 * limit + driver, query_vector, group_ids, 2 * limit, config.sim_min_score ), ] ) @@ -236,6 +261,18 @@ async def community_search( reranked_uuids: list[str] = [] if config.reranker == CommunityReranker.rrf: reranked_uuids = rrf(search_result_uuids) + elif config.reranker == CommunityReranker.mmr: + search_result_uuids_and_vectors = [ + ( + community.uuid, + community.name_embedding if community.name_embedding is not None else [0.0] * 1024, + ) + for result in search_results + for community in result + ] + reranked_uuids = maximal_marginal_relevance( + query_vector, search_result_uuids_and_vectors, config.mmr_lambda + ) reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index 21acca51..badee7c6 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -20,6 +20,7 @@ from pydantic import BaseModel, Field from graphiti_core.edges import EntityEdge from graphiti_core.nodes import CommunityNode, EntityNode +from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA DEFAULT_SEARCH_LIMIT = 10 @@ -43,31 +44,40 @@ class EdgeReranker(Enum): rrf = 'reciprocal_rank_fusion' node_distance = 'node_distance' episode_mentions = 'episode_mentions' + mmr = 'mmr' class NodeReranker(Enum): rrf = 'reciprocal_rank_fusion' node_distance = 'node_distance' episode_mentions = 'episode_mentions' + 
mmr = 'mmr' class CommunityReranker(Enum): rrf = 'reciprocal_rank_fusion' + mmr = 'mmr' class EdgeSearchConfig(BaseModel): search_methods: list[EdgeSearchMethod] reranker: EdgeReranker = Field(default=EdgeReranker.rrf) + sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) + mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) class NodeSearchConfig(BaseModel): search_methods: list[NodeSearchMethod] reranker: NodeReranker = Field(default=NodeReranker.rrf) + sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) + mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) class CommunitySearchConfig(BaseModel): search_methods: list[CommunitySearchMethod] reranker: CommunityReranker = Field(default=CommunityReranker.rrf) + sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) + mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) class SearchConfig(BaseModel): diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py index 8396307b..2264cde0 100644 --- a/graphiti_core/search/search_config_recipes.py +++ b/graphiti_core/search/search_config_recipes.py @@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig( ), ) +# Performs a hybrid search with mmr reranking over edges, nodes, and communities +COMBINED_HYBRID_SEARCH_MMR = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.mmr, + ), + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.mmr, + ), + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], + reranker=CommunityReranker.mmr, + ), +) + # performs a hybrid search over edges with rrf reranking EDGE_HYBRID_SEARCH_RRF = SearchConfig( edge_config=EdgeSearchConfig( @@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig( ) ) +# performs a hybrid search over edges 
with mmr reranking +EDGE_HYBRID_SEARCH_mmr = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.mmr, + ) +) + # performs a hybrid search over edges with node distance reranking EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( edge_config=EdgeSearchConfig( @@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig( ) ) +# performs a hybrid search over nodes with mmr reranking +NODE_HYBRID_SEARCH_MMR = SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.mmr, + ) +) + # performs a hybrid search over nodes with node distance reranking NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( node_config=NodeSearchConfig( @@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig( reranker=CommunityReranker.rrf, ) ) + +# performs a hybrid search over communities with mmr reranking +COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig( + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], + reranker=CommunityReranker.mmr, + ) +) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 9af18d81..bc1e825f 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -19,10 +19,11 @@ import logging from collections import defaultdict from time import time +import numpy as np from neo4j import AsyncDriver, Query from graphiti_core.edges import EntityEdge, get_entity_edge_from_record -from graphiti_core.helpers import lucene_sanitize +from graphiti_core.helpers import lucene_sanitize, normalize_l2 from graphiti_core.nodes import ( CommunityNode, EntityNode, @@ -34,6 +35,8 @@ from graphiti_core.nodes import ( logger = logging.getLogger(__name__) RELEVANT_SCHEMA_LIMIT = 3 +DEFAULT_MIN_SCORE = 0.6 +DEFAULT_MMR_LAMBDA = 0.5 def fulltext_query(query: str, 
group_ids: list[str] | None = None): @@ -53,10 +56,10 @@ def fulltext_query(query: str, group_ids: list[str] | None = None): async def get_episodes_by_mentions( - driver: AsyncDriver, - nodes: list[EntityNode], - edges: list[EntityEdge], - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + nodes: list[EntityNode], + edges: list[EntityEdge], + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EpisodicNode]: episode_uuids: list[str] = [] for edge in edges: @@ -68,7 +71,7 @@ async def get_episodes_by_mentions( async def get_mentioned_nodes( - driver: AsyncDriver, episodes: list[EpisodicNode] + driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] records, _, _ = await driver.execute_query( @@ -91,7 +94,7 @@ async def get_mentioned_nodes( async def get_communities_by_nodes( - driver: AsyncDriver, nodes: list[EntityNode] + driver: AsyncDriver, nodes: list[EntityNode] ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] records, _, _ = await driver.execute_query( @@ -114,12 +117,12 @@ async def get_communities_by_nodes( async def edge_fulltext_search( - driver: AsyncDriver, - query: str, - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + query: str, + source_node_uuid: str | None, + target_node_uuid: str | None, + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = fulltext_query(query, group_ids) @@ -159,12 +162,13 @@ async def edge_fulltext_search( async def edge_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - source_node_uuid: str | None, - target_node_uuid: str | None, - group_ids: list[str] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + 
group_ids: list[str] | None = None, + limit: int = RELEVANT_SCHEMA_LIMIT, + min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityEdge]: # vector similarity search over embedded facts query = Query(""" @@ -174,7 +178,7 @@ async def edge_similarity_search( AND ($source_uuid IS NULL OR n.uuid = $source_uuid) AND ($target_uuid IS NULL OR m.uuid = $target_uuid) WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score - WHERE score > 0.6 + WHERE score > $min_score RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -199,6 +203,7 @@ async def edge_similarity_search( target_uuid=target_node_uuid, group_ids=group_ids, limit=limit, + min_score=min_score, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -207,10 +212,10 @@ async def edge_similarity_search( async def node_fulltext_search( - driver: AsyncDriver, - query: str, - group_ids: list[str] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + query: str, + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = fulltext_query(query, group_ids) @@ -239,10 +244,11 @@ async def node_fulltext_search( async def node_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - group_ids: list[str] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + search_vector: list[float], + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, + min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -251,7 +257,7 @@ async def node_similarity_search( MATCH (n:Entity) WHERE $group_ids IS NULL OR n.group_id IN $group_ids WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score - WHERE score > 0.6 + WHERE score > $min_score RETURN n.uuid As uuid, n.group_id AS group_id, @@ -265,6 +271,7 @@ async def node_similarity_search( 
search_vector=search_vector, group_ids=group_ids, limit=limit, + min_score=min_score, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -272,10 +279,10 @@ async def node_similarity_search( async def community_fulltext_search( - driver: AsyncDriver, - query: str, - group_ids: list[str] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + query: str, + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, ) -> list[CommunityNode]: # BM25 search to get top communities fuzzy_query = fulltext_query(query, group_ids) @@ -304,10 +311,11 @@ async def community_fulltext_search( async def community_similarity_search( - driver: AsyncDriver, - search_vector: list[float], - group_ids: list[str] | None = None, - limit=RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + search_vector: list[float], + group_ids: list[str] | None = None, + limit=RELEVANT_SCHEMA_LIMIT, + min_score=DEFAULT_MIN_SCORE, ) -> list[CommunityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -316,7 +324,7 @@ async def community_similarity_search( MATCH (comm:Community) WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score - WHERE score > 0.6 + WHERE score > $min_score RETURN comm.uuid As uuid, comm.group_id AS group_id, @@ -330,6 +338,7 @@ async def community_similarity_search( search_vector=search_vector, group_ids=group_ids, limit=limit, + min_score=min_score, ) communities = [get_community_node_from_record(record) for record in records] @@ -337,11 +346,11 @@ async def community_similarity_search( async def hybrid_node_search( - queries: list[str], - embeddings: list[list[float]], - driver: AsyncDriver, - group_ids: list[str] | None = None, - limit: int = RELEVANT_SCHEMA_LIMIT, + queries: list[str], + embeddings: list[list[float]], + driver: AsyncDriver, + group_ids: list[str] | None = None, + limit: int = 
RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: """ Perform a hybrid search for nodes using both text queries and embeddings. @@ -404,8 +413,8 @@ async def hybrid_node_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: """ Retrieve relevant nodes based on the provided list of EntityNodes. @@ -442,11 +451,11 @@ async def get_relevant_nodes( async def get_relevant_edges( - driver: AsyncDriver, - edges: list[EntityEdge], - source_node_uuid: str | None, - target_node_uuid: str | None, - limit: int = RELEVANT_SCHEMA_LIMIT, + driver: AsyncDriver, + edges: list[EntityEdge], + source_node_uuid: str | None, + target_node_uuid: str | None, + limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -503,7 +512,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str + driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str ) -> list[str]: # filter out node_uuid center node node uuid filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids)) @@ -570,3 +579,31 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) return sorted_uuids + + +def maximal_marginal_relevance( + query_vector: list[float], + candidates: list[tuple[str, list[float]]], + mmr_lambda: float = DEFAULT_MMR_LAMBDA, +) -> list[str]: + """ + Rerank candidate uuids by maximal marginal relevance: cosine similarity to the query, + penalised by the highest cosine similarity to any other candidate. + A uuid that occurs more than once in candidates is scored once. + """ + unique_candidates = list(dict(candidates).items()) + query_unit = normalize_l2(query_vector) + unit_vectors = [normalize_l2(vector) for _, vector in unique_candidates] + + candidates_with_mmr: list[tuple[str, float]] = [] + for i, (uuid, _) in enumerate(unique_candidates): + max_sim = max( + (np.dot(unit_vectors[i], other) for j, other in enumerate(unit_vectors) if j != i), + default=0.0, + ) + mmr = mmr_lambda * np.dot(unit_vectors[i], query_unit) - (1 - mmr_lambda) * max_sim + candidates_with_mmr.append((uuid, mmr)) + + candidates_with_mmr.sort(reverse=True, key=lambda c: c[1]) + + return
[candidate[0] for candidate in candidates_with_mmr] diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index eb637617..599ce9ab 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -15,7 +15,6 @@ from graphiti_core.utils.maintenance.edge_operations import build_community_edge MAX_COMMUNITY_BUILD_CONCURRENCY = 10 - logger = logging.getLogger(__name__) @@ -24,31 +23,20 @@ class Neighbor(BaseModel): edge_count: int -async def build_community_projection(driver: AsyncDriver) -> str: - records, _, _ = await driver.execute_query(""" - CALL gds.graph.project("communities", "Entity", - {RELATES_TO: { - type: "RELATES_TO", - orientation: "UNDIRECTED", - properties: {weight: {property: "*", aggregation: "COUNT"}} - }} - ) - YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges - """) - - return records[0]['graph'] - - -async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]: +async def get_community_clusters( + driver: AsyncDriver, group_ids: list[str] | None +) -> list[list[EntityNode]]: community_clusters: list[list[EntityNode]] = [] - group_id_values, _, _ = await driver.execute_query(""" - MATCH (n:Entity WHERE n.group_id IS NOT NULL) - RETURN - collect(DISTINCT n.group_id) AS group_ids - """) + if group_ids is None: + group_id_values, _, _ = await driver.execute_query(""" + MATCH (n:Entity WHERE n.group_id IS NOT NULL) + RETURN + collect(DISTINCT n.group_id) AS group_ids + """) + + group_ids = group_id_values[0]['group_ids'] - group_ids = group_id_values[0]['group_ids'] for group_id in group_ids: projection: dict[str, list[Neighbor]] = {} nodes = await EntityNode.get_by_group_ids(driver, [group_id]) @@ -197,9 +185,9 @@ async def build_community( async def build_communities( - driver: AsyncDriver, llm_client: LLMClient + driver: AsyncDriver, llm_client: LLMClient, 
group_ids: list[str] | None ) -> tuple[list[CommunityNode], list[CommunityEdge]]: - community_clusters = await get_community_clusters(driver) + community_clusters = await get_community_clusters(driver, group_ids) semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY) diff --git a/pyproject.toml b/pyproject.toml index b39a8499..4e06c34a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.3.8" +version = "0.3.9" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 4c0a34b5..30e886c2 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -85,9 +85,7 @@ async def test_graphiti_init(): logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) - results = await graphiti._search( - 'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None - ) + results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None) pretty_results = { 'edges': [edge.fact for edge in results.edges], 'nodes': [node.name for node in results.nodes],