don't return index labels (#887)

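* Don't return index labels: strip the internal per-group `Entity_<group_id>` label (group id with hyphens removed) from the labels returned on entity nodes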

* Update tests so expected entity labels no longer include the `Entity_<group_id>` index label
This commit is contained in:
Preston Rasmussen 2025-09-02 12:02:33 -04:00 committed by GitHub
parent 51e880fd57
commit 1460172568
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 55 additions and 30 deletions

View file

@ -627,6 +627,7 @@ class Graphiti:
# if group_id is None, use the default group id by the provider
group_id = group_id or get_default_group_id(self.driver.provider)
validate_group_id(group_id)
await build_dynamic_indexes(self.driver, group_id)
# Create default edge type map
edge_type_map_default = (
@ -1008,6 +1009,8 @@ class Graphiti:
if edge.fact_embedding is None:
await edge.generate_embedding(self.embedder)
await build_dynamic_indexes(self.driver, source_node.group_id)
nodes, uuid_map, _ = await resolve_extracted_nodes(
self.clients,
[source_node, target_node],

View file

@ -742,12 +742,17 @@ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityN
attributes.pop('created_at', None)
attributes.pop('labels', None)
labels = record.get('labels', [])
group_id = record.get('group_id')
if 'Entity_' + group_id.replace('-', '') in labels:
labels.remove('Entity_' + group_id.replace('-', ''))
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record.get('name_embedding'),
group_id=record['group_id'],
labels=record['labels'],
group_id=group_id,
labels=labels,
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
attributes=attributes,

View file

@ -325,12 +325,20 @@ async def node_search(
search_tasks = []
if NodeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
node_fulltext_search(
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
)
)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
node_similarity_search(
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
driver,
query_vector,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
config.use_local_indexes,
)
)
if NodeSearchMethod.bfs in config.search_methods:
@ -426,7 +434,9 @@ async def episode_search(
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
episode_fulltext_search(
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
),
]
)
)

View file

@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA,
MAX_SEARCH_DEPTH,
USE_HNSW,
)
DEFAULT_SEARCH_LIMIT = 10
@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class EpisodeSearchConfig(BaseModel):
@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class CommunitySearchConfig(BaseModel):
@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
use_local_indexes: bool = Field(default=USE_HNSW)
class SearchConfig(BaseModel):

View file

@ -543,6 +543,7 @@ async def node_fulltext_search(
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids, driver)
@ -600,7 +601,7 @@ async def node_fulltext_search(
else:
index_name = (
'node_name_and_summary'
if not USE_HNSW
if not use_local_indexes
else 'node_name_and_summary_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)
@ -637,6 +638,7 @@ async def node_similarity_search(
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
use_local_indexes: bool = False,
) -> list[EntityNode]:
filter_queries, filter_params = node_search_filter_query_constructor(
search_filter, driver.provider
@ -710,7 +712,7 @@ async def node_similarity_search(
)
else:
return []
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
index_name = 'group_entity_vector_' + (
group_ids[0].replace('-', '') if group_ids is not None else ''
)
@ -868,6 +870,7 @@ async def episode_fulltext_search(
_search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
use_local_indexes: bool = False,
) -> list[EpisodicNode]:
# BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids, driver)
@ -919,7 +922,7 @@ async def episode_fulltext_search(
else:
index_name = (
'episode_content'
if not USE_HNSW
if not use_local_indexes
else 'episode_content_'
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
)

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.19.0pre4"
version = "0.19.0"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -159,7 +159,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
entity_node_1 = EntityNode(
name='test_entity_1',
group_id=group_id,
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
labels=['Entity', 'Person'],
created_at=now,
summary='test_entity_1 summary',
attributes={'age': 30, 'location': 'New York'},
@ -169,7 +169,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
entity_node_2 = EntityNode(
name='test_entity_2',
group_id=group_id,
labels=['Entity', 'Entity_graphiti_test_group', 'Person2'],
labels=['Entity', 'Person2'],
created_at=now,
summary='test_entity_2 summary',
attributes={'age': 25, 'location': 'Los Angeles'},
@ -179,7 +179,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
entity_node_3 = EntityNode(
name='test_entity_3',
group_id=group_id,
labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'],
labels=['Entity', 'City', 'Location'],
created_at=now,
summary='test_entity_3 summary',
attributes={'age': 25, 'location': 'Los Angeles'},
@ -189,7 +189,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
entity_node_4 = EntityNode(
name='test_entity_4',
group_id=group_id,
labels=['Entity', 'Entity_graphiti_test_group'],
labels=['Entity'],
created_at=now,
summary='test_entity_4 summary',
attributes={'age': 25, 'location': 'Los Angeles'},

View file

@ -45,7 +45,7 @@ def sample_entity_node():
uuid=str(uuid4()),
name='Test Entity',
group_id=group_id,
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
labels=['Entity', 'Person'],
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Entity Summary',

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.19.0rc4"
version = "0.19.0"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },