Add mmr reranking (#180)

* mmr start

* add mmr function

* normalize

* add mmr options to search

* update communities

* build communities

* format

* clean up normalization

* normalize in mmr

* update
This commit is contained in:
Preston Rasmussen 2024-10-08 13:55:10 -04:00 committed by GitHub
parent 5508dba1b3
commit 49aeaf75f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 215 additions and 88 deletions

View file

@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
async def create( async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]: ) -> list[float]:
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model) result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
return result.data[0].embedding[: self.config.embedding_dim] return result.data[0].embedding[: self.config.embedding_dim]

View file

@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
self.client = voyageai.AsyncClient(api_key=config.api_key) self.client = voyageai.AsyncClient(api_key=config.api_key)
async def create( async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]: ) -> list[float]:
result = await self.client.embed(input, model=self.config.embedding_model) result = await self.client.embed(input, model=self.config.embedding_model)
return result.embeddings[0][: self.config.embedding_dim] return result.embeddings[0][: self.config.embedding_dim]

View file

@ -26,7 +26,7 @@ from pydantic import BaseModel
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import ( from graphiti_core.search.search_config_recipes import (
@ -576,11 +576,20 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
async def build_communities(self): async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.

Parameters
----------
group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If None, the entire graph will be used.
"""
# Clear existing communities # Clear existing communities
await remove_communities(self.driver) await remove_communities(self.driver)
community_nodes, community_edges = await build_communities(self.driver, self.llm_client) community_nodes, community_edges = await build_communities(
self.driver, self.llm_client, group_ids
)
await asyncio.gather( await asyncio.gather(
*[node.generate_name_embedding(self.embedder) for node in community_nodes] *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@ -589,6 +598,8 @@ class Graphiti:
await asyncio.gather(*[node.save(self.driver) for node in community_nodes]) await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges]) await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
return community_nodes
async def search( async def search(
self, self,
query: str, query: str,

View file

@ -16,6 +16,7 @@ limitations under the License.
from datetime import datetime from datetime import datetime
import numpy as np
from neo4j import time as neo4j_time from neo4j import time as neo4j_time
@ -52,3 +53,15 @@ def lucene_sanitize(query: str) -> str:
sanitized = query.translate(escape_map) sanitized = query.translate(escape_map)
return sanitized return sanitized
def normalize_l2(embedding: list[float]) -> list[float]:
    """Scale an embedding (or a batch of embeddings) to unit L2 norm.

    A 1-D input is treated as a single vector. Any other input is treated as
    a batch and each row (axis 1) is normalised independently. Vectors whose
    norm is zero are returned unchanged (all zeros) instead of producing NaNs.

    Parameters
    ----------
    embedding : list[float]
        A single vector, or a nested list holding one vector per row.

    Returns
    -------
    list[float]
        The normalised values as plain Python floats, with the same nesting
        as the input.
    """
    # Force a float dtype so integer input cannot leak ints into the result.
    embedding_array = np.asarray(embedding, dtype=float)

    if embedding_array.ndim == 1:
        norm = np.linalg.norm(embedding_array)
        if norm == 0:
            return embedding_array.tolist()
        return (embedding_array / norm).tolist()

    norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
    # Dividing by the raw norm and masking afterwards still evaluates 0 / 0 and
    # emits a RuntimeWarning; substituting 1 for zero norms leaves zero rows
    # untouched without ever performing the invalid division.
    safe_norm = np.where(norm == 0, 1.0, norm)
    return (embedding_array / safe_norm).tolist()

View file

@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
edge_fulltext_search, edge_fulltext_search,
edge_similarity_search, edge_similarity_search,
episode_mentions_reranker, episode_mentions_reranker,
maximal_marginal_relevance,
node_distance_reranker, node_distance_reranker,
node_fulltext_search, node_fulltext_search,
node_similarity_search, node_similarity_search,
@ -117,12 +118,14 @@ async def edge_search(
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityEdge]] = list( search_results: list[list[EntityEdge]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit), edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
edge_similarity_search( edge_similarity_search(
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
), ),
] ]
) )
@ -135,6 +138,15 @@ async def edge_search(
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids)
elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = [
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
for result in search_results
for edge in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == EdgeReranker.node_distance: elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
@ -175,12 +187,14 @@ async def node_search(
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[EntityNode]] = list( search_results: list[list[EntityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
node_fulltext_search(driver, query, group_ids, 2 * limit), node_fulltext_search(driver, query, group_ids, 2 * limit),
node_similarity_search( node_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit driver, query_vector, group_ids, 2 * limit, config.sim_min_score
), ),
] ]
) )
@ -192,6 +206,15 @@ async def node_search(
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf: if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids)
elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = [
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
for result in search_results
for node in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == NodeReranker.episode_mentions: elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance: elif config.reranker == NodeReranker.node_distance:
@ -217,12 +240,14 @@ async def community_search(
if config is None: if config is None:
return [] return []
query_vector = await embedder.create(input=[query])
search_results: list[list[CommunityNode]] = list( search_results: list[list[CommunityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[ *[
community_fulltext_search(driver, query, group_ids, 2 * limit), community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search( community_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit driver, query_vector, group_ids, 2 * limit, config.sim_min_score
), ),
] ]
) )
@ -236,6 +261,18 @@ async def community_search(
reranked_uuids: list[str] = [] reranked_uuids: list[str] = []
if config.reranker == CommunityReranker.rrf: if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids) reranked_uuids = rrf(search_result_uuids)
elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = [
(
community.uuid,
community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
)
for result in search_results
for community in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

View file

@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
DEFAULT_SEARCH_LIMIT = 10 DEFAULT_SEARCH_LIMIT = 10
@ -43,31 +44,40 @@ class EdgeReranker(Enum):
rrf = 'reciprocal_rank_fusion' rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance' node_distance = 'node_distance'
episode_mentions = 'episode_mentions' episode_mentions = 'episode_mentions'
mmr = 'mmr'
class NodeReranker(Enum): class NodeReranker(Enum):
rrf = 'reciprocal_rank_fusion' rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance' node_distance = 'node_distance'
episode_mentions = 'episode_mentions' episode_mentions = 'episode_mentions'
mmr = 'mmr'
class CommunityReranker(Enum): class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion' rrf = 'reciprocal_rank_fusion'
mmr = 'mmr'
class EdgeSearchConfig(BaseModel): class EdgeSearchConfig(BaseModel):
search_methods: list[EdgeSearchMethod] search_methods: list[EdgeSearchMethod]
reranker: EdgeReranker = Field(default=EdgeReranker.rrf) reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
class NodeSearchConfig(BaseModel): class NodeSearchConfig(BaseModel):
search_methods: list[NodeSearchMethod] search_methods: list[NodeSearchMethod]
reranker: NodeReranker = Field(default=NodeReranker.rrf) reranker: NodeReranker = Field(default=NodeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
class CommunitySearchConfig(BaseModel): class CommunitySearchConfig(BaseModel):
search_methods: list[CommunitySearchMethod] search_methods: list[CommunitySearchMethod]
reranker: CommunityReranker = Field(default=CommunityReranker.rrf) reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
class SearchConfig(BaseModel): class SearchConfig(BaseModel):

View file

@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
), ),
) )
# Performs a hybrid search (BM25 + cosine similarity) over edges, nodes, and communities,
# reranking each of the three result sets with the mmr (maximal marginal relevance) reranker
COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
),
)
# performs a hybrid search over edges with rrf reranking # performs a hybrid search over edges with rrf reranking
EDGE_HYBRID_SEARCH_RRF = SearchConfig( EDGE_HYBRID_SEARCH_RRF = SearchConfig(
edge_config=EdgeSearchConfig( edge_config=EdgeSearchConfig(
@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
) )
) )
# performs a hybrid search (BM25 + cosine similarity) over edges with mmr reranking
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.mmr,
    )
)

# Backward-compatible alias: the recipe was first published with a lowercase suffix,
# unlike every other recipe constant in this module. Prefer EDGE_HYBRID_SEARCH_MMR.
EDGE_HYBRID_SEARCH_mmr = EDGE_HYBRID_SEARCH_MMR
# performs a hybrid search over edges with node distance reranking # performs a hybrid search over edges with node distance reranking
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
edge_config=EdgeSearchConfig( edge_config=EdgeSearchConfig(
@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
) )
) )
# performs a hybrid search (BM25 + cosine similarity) over nodes,
# reranking the results with the mmr (maximal marginal relevance) reranker
NODE_HYBRID_SEARCH_MMR = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
)
)
# performs a hybrid search over nodes with node distance reranking # performs a hybrid search over nodes with node distance reranking
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
node_config=NodeSearchConfig( node_config=NodeSearchConfig(
@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
reranker=CommunityReranker.rrf, reranker=CommunityReranker.rrf,
) )
) )
# performs a hybrid search (BM25 + cosine similarity) over communities,
# reranking the results with the mmr (maximal marginal relevance) reranker
COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
)
)

View file

@ -19,10 +19,11 @@ import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
import numpy as np
from neo4j import AsyncDriver, Query from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import lucene_sanitize from graphiti_core.helpers import lucene_sanitize, normalize_l2
from graphiti_core.nodes import ( from graphiti_core.nodes import (
CommunityNode, CommunityNode,
EntityNode, EntityNode,
@ -34,6 +35,8 @@ from graphiti_core.nodes import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3 RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
def fulltext_query(query: str, group_ids: list[str] | None = None): def fulltext_query(query: str, group_ids: list[str] | None = None):
@ -53,10 +56,10 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
async def get_episodes_by_mentions( async def get_episodes_by_mentions(
driver: AsyncDriver, driver: AsyncDriver,
nodes: list[EntityNode], nodes: list[EntityNode],
edges: list[EntityEdge], edges: list[EntityEdge],
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
episode_uuids: list[str] = [] episode_uuids: list[str] = []
for edge in edges: for edge in edges:
@ -68,7 +71,7 @@ async def get_episodes_by_mentions(
async def get_mentioned_nodes( async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]: ) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -91,7 +94,7 @@ async def get_mentioned_nodes(
async def get_communities_by_nodes( async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode] driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]: ) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes] node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -114,12 +117,12 @@ async def get_communities_by_nodes(
async def edge_fulltext_search( async def edge_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# fulltext search over facts # fulltext search over facts
fuzzy_query = fulltext_query(query, group_ids) fuzzy_query = fulltext_query(query, group_ids)
@ -159,12 +162,13 @@ async def edge_fulltext_search(
async def edge_similarity_search( async def edge_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
query = Query(""" query = Query("""
@ -174,7 +178,7 @@ async def edge_similarity_search(
AND ($source_uuid IS NULL OR n.uuid = $source_uuid) AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid) AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > 0.6 WHERE score > $min_score
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -199,6 +203,7 @@ async def edge_similarity_search(
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
min_score=min_score,
) )
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -207,10 +212,10 @@ async def edge_similarity_search(
async def node_fulltext_search( async def node_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
# BM25 search to get top nodes # BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids) fuzzy_query = fulltext_query(query, group_ids)
@ -239,10 +244,11 @@ async def node_fulltext_search(
async def node_similarity_search( async def node_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]: ) -> list[EntityNode]:
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -251,7 +257,7 @@ async def node_similarity_search(
MATCH (n:Entity) MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > 0.6 WHERE score > $min_score
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.group_id AS group_id, n.group_id AS group_id,
@ -265,6 +271,7 @@ async def node_similarity_search(
search_vector=search_vector, search_vector=search_vector,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
min_score=min_score,
) )
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -272,10 +279,10 @@ async def node_similarity_search(
async def community_fulltext_search( async def community_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
# BM25 search to get top communities # BM25 search to get top communities
fuzzy_query = fulltext_query(query, group_ids) fuzzy_query = fulltext_query(query, group_ids)
@ -304,10 +311,11 @@ async def community_fulltext_search(
async def community_similarity_search( async def community_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -316,7 +324,7 @@ async def community_similarity_search(
MATCH (comm:Community) MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > 0.6 WHERE score > $min_score
RETURN RETURN
comm.uuid As uuid, comm.uuid As uuid,
comm.group_id AS group_id, comm.group_id AS group_id,
@ -330,6 +338,7 @@ async def community_similarity_search(
search_vector=search_vector, search_vector=search_vector,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
min_score=min_score,
) )
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
@ -337,11 +346,11 @@ async def community_similarity_search(
async def hybrid_node_search( async def hybrid_node_search(
queries: list[str], queries: list[str],
embeddings: list[list[float]], embeddings: list[list[float]],
driver: AsyncDriver, driver: AsyncDriver,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
Perform a hybrid search for nodes using both text queries and embeddings. Perform a hybrid search for nodes using both text queries and embeddings.
@ -404,8 +413,8 @@ async def hybrid_node_search(
async def get_relevant_nodes( async def get_relevant_nodes(
nodes: list[EntityNode], nodes: list[EntityNode],
driver: AsyncDriver, driver: AsyncDriver,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
Retrieve relevant nodes based on the provided list of EntityNodes. Retrieve relevant nodes based on the provided list of EntityNodes.
@ -442,11 +451,11 @@ async def get_relevant_nodes(
async def get_relevant_edges( async def get_relevant_edges(
driver: AsyncDriver, driver: AsyncDriver,
edges: list[EntityEdge], edges: list[EntityEdge],
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()
relevant_edges: list[EntityEdge] = [] relevant_edges: list[EntityEdge] = []
@ -503,7 +512,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
async def node_distance_reranker( async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]: ) -> list[str]:
# filter out node_uuid center node node uuid # filter out node_uuid center node node uuid
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids)) filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
@ -570,3 +579,24 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids return sorted_uuids
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
) -> list[str]:
    """Rerank candidates by maximal marginal relevance (MMR).

    Candidates are picked greedily. At each step the remaining candidate with
    the highest score

        mmr_lambda * cos(candidate, query)
            - (1 - mmr_lambda) * max(cos(candidate, s) for s already picked)

    is appended to the ranking, so results that are relevant to the query but
    dissimilar to what has already been chosen are preferred.

    Parameters
    ----------
    query_vector : list[float]
        Embedding of the search query.
    candidates : list[tuple[str, list[float]]]
        (uuid, embedding) pairs. The same uuid may appear more than once
        (e.g. when it was returned by several search methods); only its first
        occurrence is used.
    mmr_lambda : float
        Trade-off between relevance (1.0) and diversity (0.0).

    Returns
    -------
    list[str]
        Unique candidate uuids, best first. Empty when there are no candidates.
    """
    # Collapse repeated uuids first: otherwise a result found by two search
    # methods would be treated as its own near-duplicate and pushed down.
    vectors_by_uuid: dict[str, list[float]] = {}
    for uuid, vector in candidates:
        vectors_by_uuid.setdefault(uuid, vector)
    if not vectors_by_uuid:
        return []

    uuids = list(vectors_by_uuid)

    # Normalise once up front so every dot product below is a cosine similarity.
    unit_query = np.array(normalize_l2(query_vector))
    unit_candidates = np.array([normalize_l2(vector) for vector in vectors_by_uuid.values()])

    relevance = unit_candidates @ unit_query
    pairwise = unit_candidates @ unit_candidates.T

    # Highest similarity of each candidate to anything picked so far; nothing
    # is picked yet, so the first choice is driven by relevance alone.
    redundancy = np.zeros(len(uuids))
    remaining = list(range(len(uuids)))
    ranked: list[str] = []

    while remaining:
        scores = mmr_lambda * relevance[remaining] - (1 - mmr_lambda) * redundancy[remaining]
        # argmax returns the first maximum, so ties keep the input order.
        best = remaining.pop(int(np.argmax(scores)))
        ranked.append(uuids[best])
        if len(ranked) == 1:
            redundancy = pairwise[best].copy()
        else:
            redundancy = np.maximum(redundancy, pairwise[best])

    return ranked

View file

@ -15,7 +15,6 @@ from graphiti_core.utils.maintenance.edge_operations import build_community_edge
MAX_COMMUNITY_BUILD_CONCURRENCY = 10 MAX_COMMUNITY_BUILD_CONCURRENCY = 10
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,31 +23,20 @@ class Neighbor(BaseModel):
edge_count: int edge_count: int
async def build_community_projection(driver: AsyncDriver) -> str: async def get_community_clusters(
records, _, _ = await driver.execute_query(""" driver: AsyncDriver, group_ids: list[str] | None
CALL gds.graph.project("communities", "Entity", ) -> list[list[EntityNode]]:
{RELATES_TO: {
type: "RELATES_TO",
orientation: "UNDIRECTED",
properties: {weight: {property: "*", aggregation: "COUNT"}}
}}
)
YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
""")
return records[0]['graph']
async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
community_clusters: list[list[EntityNode]] = [] community_clusters: list[list[EntityNode]] = []
group_id_values, _, _ = await driver.execute_query(""" if group_ids is None:
MATCH (n:Entity WHERE n.group_id IS NOT NULL) group_id_values, _, _ = await driver.execute_query("""
RETURN MATCH (n:Entity WHERE n.group_id IS NOT NULL)
collect(DISTINCT n.group_id) AS group_ids RETURN
""") collect(DISTINCT n.group_id) AS group_ids
""")
group_ids = group_id_values[0]['group_ids']
group_ids = group_id_values[0]['group_ids']
for group_id in group_ids: for group_id in group_ids:
projection: dict[str, list[Neighbor]] = {} projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id]) nodes = await EntityNode.get_by_group_ids(driver, [group_id])
@ -197,9 +185,9 @@ async def build_community(
async def build_communities( async def build_communities(
driver: AsyncDriver, llm_client: LLMClient driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
) -> tuple[list[CommunityNode], list[CommunityEdge]]: ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver) community_clusters = await get_community_clusters(driver, group_ids)
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY) semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.3.8" version = "0.3.9"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",

View file

@ -85,9 +85,7 @@ async def test_graphiti_init():
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
results = await graphiti._search( results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
)
pretty_results = { pretty_results = {
'edges': [edge.fact for edge in results.edges], 'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes], 'nodes': [node.name for node in results.nodes],