don't return index labels (#887)
* don't return index labels
* update tests
This commit is contained in:
parent
51e880fd57
commit
1460172568
9 changed files with 55 additions and 30 deletions
|
|
@ -627,6 +627,7 @@ class Graphiti:
|
||||||
# if group_id is None, use the default group id by the provider
|
# if group_id is None, use the default group id by the provider
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||||
validate_group_id(group_id)
|
validate_group_id(group_id)
|
||||||
|
await build_dynamic_indexes(self.driver, group_id)
|
||||||
|
|
||||||
# Create default edge type map
|
# Create default edge type map
|
||||||
edge_type_map_default = (
|
edge_type_map_default = (
|
||||||
|
|
@ -1008,6 +1009,8 @@ class Graphiti:
|
||||||
if edge.fact_embedding is None:
|
if edge.fact_embedding is None:
|
||||||
await edge.generate_embedding(self.embedder)
|
await edge.generate_embedding(self.embedder)
|
||||||
|
|
||||||
|
await build_dynamic_indexes(self.driver, source_node.group_id)
|
||||||
|
|
||||||
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
||||||
self.clients,
|
self.clients,
|
||||||
[source_node, target_node],
|
[source_node, target_node],
|
||||||
|
|
|
||||||
|
|
@ -742,12 +742,17 @@ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityN
|
||||||
attributes.pop('created_at', None)
|
attributes.pop('created_at', None)
|
||||||
attributes.pop('labels', None)
|
attributes.pop('labels', None)
|
||||||
|
|
||||||
|
labels = record.get('labels', [])
|
||||||
|
group_id = record.get('group_id')
|
||||||
|
if 'Entity_' + group_id.replace('-', '') in labels:
|
||||||
|
labels.remove('Entity_' + group_id.replace('-', ''))
|
||||||
|
|
||||||
entity_node = EntityNode(
|
entity_node = EntityNode(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
name_embedding=record.get('name_embedding'),
|
name_embedding=record.get('name_embedding'),
|
||||||
group_id=record['group_id'],
|
group_id=group_id,
|
||||||
labels=record['labels'],
|
labels=labels,
|
||||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
|
|
|
||||||
|
|
@ -325,12 +325,20 @@ async def node_search(
|
||||||
search_tasks = []
|
search_tasks = []
|
||||||
if NodeSearchMethod.bm25 in config.search_methods:
|
if NodeSearchMethod.bm25 in config.search_methods:
|
||||||
search_tasks.append(
|
search_tasks.append(
|
||||||
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
node_fulltext_search(
|
||||||
|
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||||
search_tasks.append(
|
search_tasks.append(
|
||||||
node_similarity_search(
|
node_similarity_search(
|
||||||
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
driver,
|
||||||
|
query_vector,
|
||||||
|
search_filter,
|
||||||
|
group_ids,
|
||||||
|
2 * limit,
|
||||||
|
config.sim_min_score,
|
||||||
|
config.use_local_indexes,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if NodeSearchMethod.bfs in config.search_methods:
|
if NodeSearchMethod.bfs in config.search_methods:
|
||||||
|
|
@ -426,7 +434,9 @@ async def episode_search(
|
||||||
search_results: list[list[EpisodicNode]] = list(
|
search_results: list[list[EpisodicNode]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
episode_fulltext_search(
|
||||||
|
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
|
||||||
DEFAULT_MIN_SCORE,
|
DEFAULT_MIN_SCORE,
|
||||||
DEFAULT_MMR_LAMBDA,
|
DEFAULT_MMR_LAMBDA,
|
||||||
MAX_SEARCH_DEPTH,
|
MAX_SEARCH_DEPTH,
|
||||||
|
USE_HNSW,
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 10
|
DEFAULT_SEARCH_LIMIT = 10
|
||||||
|
|
@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSearchConfig(BaseModel):
|
class EpisodeSearchConfig(BaseModel):
|
||||||
|
|
@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchConfig(BaseModel):
|
class CommunitySearchConfig(BaseModel):
|
||||||
|
|
@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
use_local_indexes: bool = Field(default=USE_HNSW)
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -211,11 +211,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -543,6 +543,7 @@ async def node_fulltext_search(
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
use_local_indexes: bool = False,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||||
|
|
@ -576,11 +577,11 @@ async def node_fulltext_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -600,7 +601,7 @@ async def node_fulltext_search(
|
||||||
else:
|
else:
|
||||||
index_name = (
|
index_name = (
|
||||||
'node_name_and_summary'
|
'node_name_and_summary'
|
||||||
if not USE_HNSW
|
if not use_local_indexes
|
||||||
else 'node_name_and_summary_'
|
else 'node_name_and_summary_'
|
||||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||||
)
|
)
|
||||||
|
|
@ -637,6 +638,7 @@ async def node_similarity_search(
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
|
use_local_indexes: bool = False,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
filter_queries, filter_params = node_search_filter_query_constructor(
|
filter_queries, filter_params = node_search_filter_query_constructor(
|
||||||
search_filter, driver.provider
|
search_filter, driver.provider
|
||||||
|
|
@ -688,11 +690,11 @@ async def node_similarity_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -710,7 +712,7 @@ async def node_similarity_search(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
|
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
||||||
index_name = 'group_entity_vector_' + (
|
index_name = 'group_entity_vector_' + (
|
||||||
group_ids[0].replace('-', '') if group_ids is not None else ''
|
group_ids[0].replace('-', '') if group_ids is not None else ''
|
||||||
)
|
)
|
||||||
|
|
@ -868,6 +870,7 @@ async def episode_fulltext_search(
|
||||||
_search_filter: SearchFilters,
|
_search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
use_local_indexes: bool = False,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
# BM25 search to get top episodes
|
# BM25 search to get top episodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||||
|
|
@ -919,7 +922,7 @@ async def episode_fulltext_search(
|
||||||
else:
|
else:
|
||||||
index_name = (
|
index_name = (
|
||||||
'episode_content'
|
'episode_content'
|
||||||
if not USE_HNSW
|
if not use_local_indexes
|
||||||
else 'episode_content_'
|
else 'episode_content_'
|
||||||
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.19.0pre4"
|
version = "0.19.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
||||||
entity_node_1 = EntityNode(
|
entity_node_1 = EntityNode(
|
||||||
name='test_entity_1',
|
name='test_entity_1',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
labels=['Entity', 'Person'],
|
||||||
created_at=now,
|
created_at=now,
|
||||||
summary='test_entity_1 summary',
|
summary='test_entity_1 summary',
|
||||||
attributes={'age': 30, 'location': 'New York'},
|
attributes={'age': 30, 'location': 'New York'},
|
||||||
|
|
@ -169,7 +169,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
||||||
entity_node_2 = EntityNode(
|
entity_node_2 = EntityNode(
|
||||||
name='test_entity_2',
|
name='test_entity_2',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person2'],
|
labels=['Entity', 'Person2'],
|
||||||
created_at=now,
|
created_at=now,
|
||||||
summary='test_entity_2 summary',
|
summary='test_entity_2 summary',
|
||||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||||
|
|
@ -179,7 +179,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
||||||
entity_node_3 = EntityNode(
|
entity_node_3 = EntityNode(
|
||||||
name='test_entity_3',
|
name='test_entity_3',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
labels=['Entity', 'City', 'Entity_graphiti_test_group', 'Location'],
|
labels=['Entity', 'City', 'Location'],
|
||||||
created_at=now,
|
created_at=now,
|
||||||
summary='test_entity_3 summary',
|
summary='test_entity_3 summary',
|
||||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||||
|
|
@ -189,7 +189,7 @@ async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross
|
||||||
entity_node_4 = EntityNode(
|
entity_node_4 = EntityNode(
|
||||||
name='test_entity_4',
|
name='test_entity_4',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
labels=['Entity', 'Entity_graphiti_test_group'],
|
labels=['Entity'],
|
||||||
created_at=now,
|
created_at=now,
|
||||||
summary='test_entity_4 summary',
|
summary='test_entity_4 summary',
|
||||||
attributes={'age': 25, 'location': 'Los Angeles'},
|
attributes={'age': 25, 'location': 'Los Angeles'},
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ def sample_entity_node():
|
||||||
uuid=str(uuid4()),
|
uuid=str(uuid4()),
|
||||||
name='Test Entity',
|
name='Test Entity',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
labels=['Entity', 'Entity_graphiti_test_group', 'Person'],
|
labels=['Entity', 'Person'],
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
name_embedding=[0.5] * 1024,
|
name_embedding=[0.5] * 1024,
|
||||||
summary='Entity Summary',
|
summary='Entity Summary',
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.19.0rc4"
|
version = "0.19.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue