don't return index labels (#887)
* don't return index labels * update tests
This commit is contained in:
parent
51e880fd57
commit
1460172568
9 changed files with 55 additions and 30 deletions
|
|
@ -627,6 +627,7 @@ class Graphiti:
|
|||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_group_id(group_id)
|
||||
await build_dynamic_indexes(self.driver, group_id)
|
||||
|
||||
# Create default edge type map
|
||||
edge_type_map_default = (
|
||||
|
|
@ -1008,6 +1009,8 @@ class Graphiti:
|
|||
if edge.fact_embedding is None:
|
||||
await edge.generate_embedding(self.embedder)
|
||||
|
||||
await build_dynamic_indexes(self.driver, source_node.group_id)
|
||||
|
||||
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
||||
self.clients,
|
||||
[source_node, target_node],
|
||||
|
|
|
|||
|
|
@ -742,12 +742,17 @@ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityN
|
|||
attributes.pop('created_at', None)
|
||||
attributes.pop('labels', None)
|
||||
|
||||
labels = record.get('labels', [])
|
||||
group_id = record.get('group_id')
|
||||
if 'Entity_' + group_id.replace('-', '') in labels:
|
||||
labels.remove('Entity_' + group_id.replace('-', ''))
|
||||
|
||||
entity_node = EntityNode(
|
||||
uuid=record['uuid'],
|
||||
name=record['name'],
|
||||
name_embedding=record.get('name_embedding'),
|
||||
group_id=record['group_id'],
|
||||
labels=record['labels'],
|
||||
group_id=group_id,
|
||||
labels=labels,
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
summary=record['summary'],
|
||||
attributes=attributes,
|
||||
|
|
|
|||
|
|
@ -325,12 +325,20 @@ async def node_search(
|
|||
search_tasks = []
|
||||
if NodeSearchMethod.bm25 in config.search_methods:
|
||||
search_tasks.append(
|
||||
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
||||
node_fulltext_search(
|
||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||
)
|
||||
)
|
||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_tasks.append(
|
||||
node_similarity_search(
|
||||
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
||||
driver,
|
||||
query_vector,
|
||||
search_filter,
|
||||
group_ids,
|
||||
2 * limit,
|
||||
config.sim_min_score,
|
||||
config.use_local_indexes,
|
||||
)
|
||||
)
|
||||
if NodeSearchMethod.bfs in config.search_methods:
|
||||
|
|
@ -426,7 +434,9 @@ async def episode_search(
|
|||
search_results: list[list[EpisodicNode]] = list(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
||||
episode_fulltext_search(
|
||||
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
|
|||
DEFAULT_MIN_SCORE,
|
||||
DEFAULT_MMR_LAMBDA,
|
||||
MAX_SEARCH_DEPTH,
|
||||
USE_HNSW,
|
||||
)
|
||||
|
||||
DEFAULT_SEARCH_LIMIT = 10
|
||||
|
|
@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class EpisodeSearchConfig(BaseModel):
|
||||
|
|
@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class CommunitySearchConfig(BaseModel):
|
||||
|
|
@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
|
|||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -543,6 +543,7 @@ async def node_fulltext_search(
|
|||
search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
|
|
@ -600,7 +601,7 @@ async def node_fulltext_search(
|
|||
else:
|
||||
index_name = (
|
||||
'node_name_and_summary'
|
||||
if not USE_HNSW
|
||||
if not use_local_indexes
|
||||
else 'node_name_and_summary_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
|
|
@ -637,6 +638,7 @@ async def node_similarity_search(
|
|||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EntityNode]:
|
||||
filter_queries, filter_params = node_search_filter_query_constructor(
|
||||
search_filter, driver.provider
|
||||
|
|
@ -710,7 +712,7 @@ async def node_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
|
||||
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
||||
index_name = 'group_entity_vector_' + (
|
||||
group_ids[0].replace('-', '') if group_ids is not None else ''
|
||||
)
|
||||
|
|
@ -868,6 +870,7 @@ async def episode_fulltext_search(
|
|||
_search_filter: SearchFilters,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
use_local_indexes: bool = False,
|
||||
) -> list[EpisodicNode]:
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
|
|
@ -919,7 +922,7 @@ async def episode_fulltext_search(
|
|||
else:
|
||||
index_name = (
|
||||
'episode_content'
|
||||
if not USE_HNSW
|
||||
if not use_local_indexes
|
||||
else 'episode_content_'
|
||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.19.0pre4"
|
||||
version = "0.19.0"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
|||
entity_node_1 = EntityNode(
|
||||
name='test_entity_1',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
||||
labels=['Entity', 'Person'],
|
||||
created_at=now,
|
||||
summary='test_entity_1 summary',
|
||||
attributes={'age': 30, 'location': 'New York'},
|
||||
|
|
@ -169,7 +169,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
|||
entity_node_2 = EntityNode(
|
||||
name='test_entity_2',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person2'],
|
||||
labels=['Entity', 'Person2'],
|
||||
created_at=now,
|
||||
summary='test_entity_2 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
@ -179,7 +179,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
|||
entity_node_3 = EntityNode(
|
||||
name='test_entity_3',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'],
|
||||
labels=['Entity', 'City', 'Location'],
|
||||
created_at=now,
|
||||
summary='test_entity_3 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
@ -189,7 +189,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
|||
entity_node_4 = EntityNode(
|
||||
name='test_entity_4',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Entity_graphiti_test_group'],
|
||||
labels=['Entity'],
|
||||
created_at=now,
|
||||
summary='test_entity_4 summary',
|
||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def sample_entity_node():
|
|||
uuid=str(uuid4()),
|
||||
name='Test Entity',
|
||||
group_id=group_id,
|
||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
||||
labels=['Entity', 'Person'],
|
||||
created_at=created_at,
|
||||
name_embedding=[0.5] * 1024,
|
||||
summary='Entity Summary',
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.19.0rc4"
|
||||
version = "0.19.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue