diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3dae32a0..f643e344 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -627,6 +627,7 @@ class Graphiti: # if group_id is None, use the default group id by the provider group_id = group_id or get_default_group_id(self.driver.provider) validate_group_id(group_id) + await build_dynamic_indexes(self.driver, group_id) # Create default edge type map edge_type_map_default = ( @@ -1008,6 +1009,8 @@ class Graphiti: if edge.fact_embedding is None: await edge.generate_embedding(self.embedder) + await build_dynamic_indexes(self.driver, source_node.group_id) + nodes, uuid_map, _ = await resolve_extracted_nodes( self.clients, [source_node, target_node], diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 9558e711..ef28a94d 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -742,12 +742,17 @@ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityN attributes.pop('created_at', None) attributes.pop('labels', None) + labels = record.get('labels', []) + group_id = record.get('group_id') + if 'Entity_' + group_id.replace('-', '') in labels: + labels.remove('Entity_' + group_id.replace('-', '')) + entity_node = EntityNode( uuid=record['uuid'], name=record['name'], name_embedding=record.get('name_embedding'), - group_id=record['group_id'], - labels=record['labels'], + group_id=group_id, + labels=labels, created_at=parse_db_date(record['created_at']), # type: ignore summary=record['summary'], attributes=attributes, diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index ea255e82..4324809c 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -325,12 +325,20 @@ async def node_search( search_tasks = [] if NodeSearchMethod.bm25 in config.search_methods: search_tasks.append( - node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit) + node_fulltext_search( + driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes + ) ) if NodeSearchMethod.cosine_similarity in config.search_methods: search_tasks.append( node_similarity_search( - driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score + driver, + query_vector, + search_filter, + group_ids, + 2 * limit, + config.sim_min_score, + config.use_local_indexes, ) ) if NodeSearchMethod.bfs in config.search_methods: @@ -426,7 +434,9 @@ async def episode_search( search_results: list[list[EpisodicNode]] = list( await semaphore_gather( *[ - episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), + episode_fulltext_search( + driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes + ), ] ) ) diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index f24a3f3e..97c12642 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import ( DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA, MAX_SEARCH_DEPTH, + USE_HNSW, ) DEFAULT_SEARCH_LIMIT = 10 @@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel): sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) + use_local_indexes: bool = Field(default=USE_HNSW) class EpisodeSearchConfig(BaseModel): @@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel): sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) + use_local_indexes: bool = Field(default=USE_HNSW) class CommunitySearchConfig(BaseModel): @@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel): sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) + use_local_indexes: bool = Field(default=USE_HNSW) class SearchConfig(BaseModel): diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index b13c3188..05b5004c 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -211,11 +211,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -543,6 +543,7 @@ async def node_fulltext_search( search_filter: SearchFilters, group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, + use_local_indexes: bool = False, ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = fulltext_query(query, group_ids, driver) @@ -576,11 +577,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -600,7 +601,7 @@ async def node_fulltext_search( else: index_name = ( 'node_name_and_summary' - if not USE_HNSW + if not use_local_indexes else 'node_name_and_summary_' + (group_ids[0].replace('-', '') if group_ids is not None else '') ) @@ -637,6 +638,7 @@ async def node_similarity_search( group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, min_score: float = DEFAULT_MIN_SCORE, + use_local_indexes: bool = False, ) -> list[EntityNode]: filter_queries, filter_params = node_search_filter_query_constructor( search_filter, driver.provider @@ -688,11 +690,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -710,7 +712,7 @@ async def node_similarity_search( ) else: return [] - elif driver.provider == GraphProvider.NEO4J and USE_HNSW: + elif driver.provider == GraphProvider.NEO4J and use_local_indexes: index_name = 'group_entity_vector_' + ( group_ids[0].replace('-', '') if group_ids is not None else '' ) @@ -868,6 +870,7 @@ async def episode_fulltext_search( _search_filter: SearchFilters, group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, + use_local_indexes: bool = False, ) -> list[EpisodicNode]: # BM25 search to get top episodes fuzzy_query = fulltext_query(query, group_ids, driver) @@ -919,7 +922,7 @@ async def episode_fulltext_search( else: index_name = ( 'episode_content' - if not USE_HNSW + if not use_local_indexes else 'episode_content_' + (group_ids[0].replace('-', '') if group_ids is not None else '') ) diff --git a/pyproject.toml b/pyproject.toml index 029b05a2..f5eca68a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.19.0pre4" +version = "0.19.0" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/tests/test_graphiti_mock.py b/tests/test_graphiti_mock.py index 9260c747..f2ba89ec 100644 --- a/tests/test_graphiti_mock.py +++ b/tests/test_graphiti_mock.py @@ -159,7 +159,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross entity_node_1 = EntityNode( name='test_entity_1', group_id=group_id, - labels=['Entity', 'Entity_graphiti_test_group', 'Person'], + labels=['Entity', 'Person'], created_at=now, summary='test_entity_1 summary', attributes={'age': 30, 'location': 'New York'}, @@ -169,7 +169,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross entity_node_2 = EntityNode( name='test_entity_2', group_id=group_id, - labels=['Entity', 'Entity_graphiti_test_group', 'Person2'], + labels=['Entity', 'Person2'], created_at=now, summary='test_entity_2 summary', attributes={'age': 25, 'location': 'Los Angeles'}, @@ -179,7 +179,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross entity_node_3 = EntityNode( name='test_entity_3', group_id=group_id, - labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'], + labels=['Entity', 'City', 'Location'], created_at=now, summary='test_entity_3 summary', attributes={'age': 25, 'location': 'Los Angeles'}, @@ -189,7 +189,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross entity_node_4 = EntityNode( name='test_entity_4', group_id=group_id, - labels=['Entity', 'Entity_graphiti_test_group'], + labels=['Entity'], created_at=now, summary='test_entity_4 summary', attributes={'age': 25, 'location': 'Los Angeles'}, diff --git a/tests/test_node_int.py b/tests/test_node_int.py index 34226803..7e73b856 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -45,7 +45,7 @@ def sample_entity_node(): uuid=str(uuid4()), name='Test Entity', group_id=group_id, - labels=['Entity', 'Entity_graphiti_test_group', 'Person'], + labels=['Entity', 'Person'], created_at=created_at, name_embedding=[0.5] * 1024, summary='Entity Summary', diff --git a/uv.lock b/uv.lock index 9feb7cf6..30d4e4fe 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.19.0rc4" +version = "0.19.0" source = { editable = "." } dependencies = [ { name = "diskcache" },