Add mmr reranking (#180)
* mmr start * add mmr function * normalize * add mmr options to search * update communities * build communities * format * clean up normalization * normalize in mmr * update
This commit is contained in:
parent
5508dba1b3
commit
49aeaf75f2
11 changed files with 215 additions and 88 deletions
|
|
@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
|
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
|
||||||
return result.data[0].embedding[: self.config.embedding_dim]
|
return result.data[0].embedding[: self.config.embedding_dim]
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
|
||||||
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
result = await self.client.embed(input, model=self.config.embedding_model)
|
result = await self.client.embed(input, model=self.config.embedding_model)
|
||||||
return result.embeddings[0][: self.config.embedding_dim]
|
return result.embeddings[0][: self.config.embedding_dim]
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
||||||
from graphiti_core.search.search_config_recipes import (
|
from graphiti_core.search.search_config_recipes import (
|
||||||
|
|
@ -576,11 +576,20 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def build_communities(self):
|
async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
|
||||||
|
"""
|
||||||
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||||
|
the content of these communities.
|
||||||
|
----------
|
||||||
|
group_ids : list[str] | None
|
||||||
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
||||||
|
"""
|
||||||
# Clear existing communities
|
# Clear existing communities
|
||||||
await remove_communities(self.driver)
|
await remove_communities(self.driver)
|
||||||
|
|
||||||
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
community_nodes, community_edges = await build_communities(
|
||||||
|
self.driver, self.llm_client, group_ids
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
||||||
|
|
@ -589,6 +598,8 @@ class Graphiti:
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||||
|
|
||||||
|
return community_nodes
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from neo4j import time as neo4j_time
|
from neo4j import time as neo4j_time
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,3 +53,15 @@ def lucene_sanitize(query: str) -> str:
|
||||||
|
|
||||||
sanitized = query.translate(escape_map)
|
sanitized = query.translate(escape_map)
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_l2(embedding: list[float]) -> list[float]:
    """
    Scale an embedding to unit L2 (Euclidean) length.

    Parameters
    ----------
    embedding : list[float]
        A single vector. A 2-D list of vectors is also handled: each row is
        normalised independently and a nested list is returned.

    Returns
    -------
    list[float]
        The normalised vector(s) as plain Python lists. A vector whose norm is 0
        is returned unchanged, since it has no direction to preserve.
    """
    embedding_array = np.array(embedding)

    if embedding_array.ndim == 1:
        norm = np.linalg.norm(embedding_array)
        if norm == 0:
            return embedding_array.tolist()
        return (embedding_array / norm).tolist()

    norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
    # Divide all-zero rows by 1 rather than by 0. The result for those rows is
    # still the original (all-zero) row, but no 0/0 is ever evaluated, so NumPy
    # does not emit an "invalid value encountered in divide" RuntimeWarning.
    safe_norm = np.where(norm == 0, 1.0, norm)
    return (embedding_array / safe_norm).tolist()
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
episode_mentions_reranker,
|
episode_mentions_reranker,
|
||||||
|
maximal_marginal_relevance,
|
||||||
node_distance_reranker,
|
node_distance_reranker,
|
||||||
node_fulltext_search,
|
node_fulltext_search,
|
||||||
node_similarity_search,
|
node_similarity_search,
|
||||||
|
|
@ -117,12 +118,14 @@ async def edge_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
query_vector = await embedder.create(input=[query])
|
||||||
|
|
||||||
search_results: list[list[EntityEdge]] = list(
|
search_results: list[list[EntityEdge]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
|
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
|
||||||
edge_similarity_search(
|
edge_similarity_search(
|
||||||
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
|
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -135,6 +138,15 @@ async def edge_search(
|
||||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||||
|
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == EdgeReranker.mmr:
|
||||||
|
search_result_uuids_and_vectors = [
|
||||||
|
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
||||||
|
for result in search_results
|
||||||
|
for edge in result
|
||||||
|
]
|
||||||
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
|
)
|
||||||
elif config.reranker == EdgeReranker.node_distance:
|
elif config.reranker == EdgeReranker.node_distance:
|
||||||
if center_node_uuid is None:
|
if center_node_uuid is None:
|
||||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
|
@ -175,12 +187,14 @@ async def node_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
query_vector = await embedder.create(input=[query])
|
||||||
|
|
||||||
search_results: list[list[EntityNode]] = list(
|
search_results: list[list[EntityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||||
node_similarity_search(
|
node_similarity_search(
|
||||||
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -192,6 +206,15 @@ async def node_search(
|
||||||
reranked_uuids: list[str] = []
|
reranked_uuids: list[str] = []
|
||||||
if config.reranker == NodeReranker.rrf:
|
if config.reranker == NodeReranker.rrf:
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == NodeReranker.mmr:
|
||||||
|
search_result_uuids_and_vectors = [
|
||||||
|
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
||||||
|
for result in search_results
|
||||||
|
for node in result
|
||||||
|
]
|
||||||
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
|
)
|
||||||
elif config.reranker == NodeReranker.episode_mentions:
|
elif config.reranker == NodeReranker.episode_mentions:
|
||||||
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
||||||
elif config.reranker == NodeReranker.node_distance:
|
elif config.reranker == NodeReranker.node_distance:
|
||||||
|
|
@ -217,12 +240,14 @@ async def community_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
query_vector = await embedder.create(input=[query])
|
||||||
|
|
||||||
search_results: list[list[CommunityNode]] = list(
|
search_results: list[list[CommunityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||||
community_similarity_search(
|
community_similarity_search(
|
||||||
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -236,6 +261,18 @@ async def community_search(
|
||||||
reranked_uuids: list[str] = []
|
reranked_uuids: list[str] = []
|
||||||
if config.reranker == CommunityReranker.rrf:
|
if config.reranker == CommunityReranker.rrf:
|
||||||
reranked_uuids = rrf(search_result_uuids)
|
reranked_uuids = rrf(search_result_uuids)
|
||||||
|
elif config.reranker == CommunityReranker.mmr:
|
||||||
|
search_result_uuids_and_vectors = [
|
||||||
|
(
|
||||||
|
community.uuid,
|
||||||
|
community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
|
||||||
|
)
|
||||||
|
for result in search_results
|
||||||
|
for community in result
|
||||||
|
]
|
||||||
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
|
)
|
||||||
|
|
||||||
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
|
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 10
|
DEFAULT_SEARCH_LIMIT = 10
|
||||||
|
|
||||||
|
|
@ -43,31 +44,40 @@ class EdgeReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
episode_mentions = 'episode_mentions'
|
episode_mentions = 'episode_mentions'
|
||||||
|
mmr = 'mmr'
|
||||||
|
|
||||||
|
|
||||||
class NodeReranker(Enum):
|
class NodeReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
episode_mentions = 'episode_mentions'
|
episode_mentions = 'episode_mentions'
|
||||||
|
mmr = 'mmr'
|
||||||
|
|
||||||
|
|
||||||
class CommunityReranker(Enum):
|
class CommunityReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
|
mmr = 'mmr'
|
||||||
|
|
||||||
|
|
||||||
class EdgeSearchConfig(BaseModel):
|
class EdgeSearchConfig(BaseModel):
|
||||||
search_methods: list[EdgeSearchMethod]
|
search_methods: list[EdgeSearchMethod]
|
||||||
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
||||||
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
|
||||||
|
|
||||||
class NodeSearchConfig(BaseModel):
|
class NodeSearchConfig(BaseModel):
|
||||||
search_methods: list[NodeSearchMethod]
|
search_methods: list[NodeSearchMethod]
|
||||||
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
||||||
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchConfig(BaseModel):
|
class CommunitySearchConfig(BaseModel):
|
||||||
search_methods: list[CommunitySearchMethod]
|
search_methods: list[CommunitySearchMethod]
|
||||||
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
||||||
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hybrid (BM25 + cosine similarity) search across edges, nodes and communities,
# with every result set reranked by MMR.
COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.mmr,
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
    ),
    node_config=NodeSearchConfig(
        reranker=NodeReranker.mmr,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    ),
    community_config=CommunitySearchConfig(
        reranker=CommunityReranker.mmr,
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
    ),
)
|
||||||
|
|
||||||
# performs a hybrid search over edges with rrf reranking
|
# performs a hybrid search over edges with rrf reranking
|
||||||
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
edge_config=EdgeSearchConfig(
|
edge_config=EdgeSearchConfig(
|
||||||
|
|
@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# performs a hybrid search over edges with mmr reranking
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.mmr,
    )
)

# Backward-compatible alias: this recipe was originally published with a
# lowercase suffix, unlike the other upper-case recipe constants. Prefer
# EDGE_HYBRID_SEARCH_MMR in new code.
EDGE_HYBRID_SEARCH_mmr = EDGE_HYBRID_SEARCH_MMR
|
||||||
|
|
||||||
# performs a hybrid search over edges with node distance reranking
|
# performs a hybrid search over edges with node distance reranking
|
||||||
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
edge_config=EdgeSearchConfig(
|
edge_config=EdgeSearchConfig(
|
||||||
|
|
@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hybrid (BM25 + cosine similarity) search over nodes, reranked by MMR.
NODE_HYBRID_SEARCH_MMR = SearchConfig(
    node_config=NodeSearchConfig(
        reranker=NodeReranker.mmr,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    ),
)
|
||||||
|
|
||||||
# performs a hybrid search over nodes with node distance reranking
|
# performs a hybrid search over nodes with node distance reranking
|
||||||
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
node_config=NodeSearchConfig(
|
node_config=NodeSearchConfig(
|
||||||
|
|
@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
||||||
reranker=CommunityReranker.rrf,
|
reranker=CommunityReranker.rrf,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hybrid (BM25 + cosine similarity) search over communities, reranked by MMR.
COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
    community_config=CommunitySearchConfig(
        reranker=CommunityReranker.mmr,
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
    ),
)
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,11 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.helpers import lucene_sanitize
|
from graphiti_core.helpers import lucene_sanitize, normalize_l2
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
EntityNode,
|
EntityNode,
|
||||||
|
|
@ -34,6 +35,8 @@ from graphiti_core.nodes import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
|
DEFAULT_MMR_LAMBDA = 0.5
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
@ -53,10 +56,10 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
|
|
||||||
async def get_episodes_by_mentions(
|
async def get_episodes_by_mentions(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
episode_uuids: list[str] = []
|
episode_uuids: list[str] = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
|
@ -68,7 +71,7 @@ async def get_episodes_by_mentions(
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(
|
async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -91,7 +94,7 @@ async def get_mentioned_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_communities_by_nodes(
|
async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: AsyncDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -114,12 +117,12 @@ async def get_communities_by_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -159,12 +162,13 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
async def edge_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
query = Query("""
|
||||||
|
|
@ -174,7 +178,7 @@ async def edge_similarity_search(
|
||||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||||
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
WHERE score > 0.6
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -199,6 +203,7 @@ async def edge_similarity_search(
|
||||||
target_uuid=target_node_uuid,
|
target_uuid=target_node_uuid,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -207,10 +212,10 @@ async def edge_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_fulltext_search(
|
async def node_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -239,10 +244,11 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_similarity_search(
|
async def node_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -251,7 +257,7 @@ async def node_similarity_search(
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
WHERE score > 0.6
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
|
|
@ -265,6 +271,7 @@ async def node_similarity_search(
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
)
|
)
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -272,10 +279,10 @@ async def node_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_fulltext_search(
|
async def community_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# BM25 search to get top communities
|
# BM25 search to get top communities
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -304,10 +311,11 @@ async def community_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_similarity_search(
|
async def community_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over community names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -316,7 +324,7 @@ async def community_similarity_search(
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||||
WHERE score > 0.6
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
|
|
@ -330,6 +338,7 @@ async def community_similarity_search(
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
)
|
)
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -337,11 +346,11 @@ async def community_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_node_search(
|
async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search for nodes using both text queries and embeddings.
|
Perform a hybrid search for nodes using both text queries and embeddings.
|
||||||
|
|
@ -404,8 +413,8 @@ async def hybrid_node_search(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
async def get_relevant_nodes(
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve relevant nodes based on the provided list of EntityNodes.
|
Retrieve relevant nodes based on the provided list of EntityNodes.
|
||||||
|
|
@ -442,11 +451,11 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
relevant_edges: list[EntityEdge] = []
|
relevant_edges: list[EntityEdge] = []
|
||||||
|
|
@ -503,7 +512,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# filter out node_uuid center node node uuid
|
# filter out node_uuid center node node uuid
|
||||||
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
||||||
|
|
@ -570,3 +579,24 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
||||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||||
|
|
||||||
return sorted_uuids
|
return sorted_uuids
|
||||||
|
|
||||||
|
|
||||||
|
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
) -> list[str]:
    """
    Rerank candidates by a maximal-marginal-relevance (MMR) score.

    Every candidate is scored in a single pass as

        mmr_lambda * cos(candidate, query)
            - (1 - mmr_lambda) * max(cos(candidate, other) for other != candidate)

    so a candidate gains score for being close to the query and loses score for
    being close to another candidate. All vectors are L2-normalised first, which
    puts both terms on the same cosine scale.

    Parameters
    ----------
    query_vector : list[float]
        Embedding of the search query.
    candidates : list[tuple[str, list[float]]]
        (uuid, embedding) pairs. The same uuid may occur more than once (for
        example when several result lists are concatenated); only its first
        occurrence is used.
    mmr_lambda : float
        Weight of the relevance term; (1 - mmr_lambda) weights the redundancy
        penalty.

    Returns
    -------
    list[str]
        The distinct candidate uuids, highest score first. Ties keep their input
        order. An empty candidate list yields an empty list.
    """
    # A uuid found by more than one search must be ranked once; a duplicate would
    # also be a perfect match for itself and distort the redundancy term.
    embeddings_by_uuid: dict[str, list[float]] = {}
    for uuid, embedding in candidates:
        embeddings_by_uuid.setdefault(uuid, embedding)

    if not embeddings_by_uuid:
        return []

    uuids = list(embeddings_by_uuid)
    # Normalise vector by vector (the 1-D path of normalize_l2), so zero vectors
    # are passed through unchanged.
    candidate_matrix = np.array([normalize_l2(embeddings_by_uuid[uuid]) for uuid in uuids])
    query_array = np.array(normalize_l2(query_vector))

    relevance = candidate_matrix @ query_array

    if len(uuids) > 1:
        pairwise = candidate_matrix @ candidate_matrix.T
        # Exclude each candidate's similarity with itself, otherwise the maximum
        # is 1 for every non-zero vector and the term carries no information.
        np.fill_diagonal(pairwise, -np.inf)
        redundancy = pairwise.max(axis=1)
    else:
        # A lone candidate has nothing to be redundant with.
        redundancy = np.zeros(1)

    scores = mmr_lambda * relevance - (1 - mmr_lambda) * redundancy

    # Stable descending order: equal scores keep their input order.
    ranking = np.argsort(-scores, kind='stable')
    return [uuids[index] for index in ranking]
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from graphiti_core.utils.maintenance.edge_operations import build_community_edge
|
||||||
|
|
||||||
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,31 +23,20 @@ class Neighbor(BaseModel):
|
||||||
edge_count: int
|
edge_count: int
|
||||||
|
|
||||||
|
|
||||||
async def build_community_projection(driver: AsyncDriver) -> str:
|
async def get_community_clusters(
|
||||||
records, _, _ = await driver.execute_query("""
|
driver: AsyncDriver, group_ids: list[str] | None
|
||||||
CALL gds.graph.project("communities", "Entity",
|
) -> list[list[EntityNode]]:
|
||||||
{RELATES_TO: {
|
|
||||||
type: "RELATES_TO",
|
|
||||||
orientation: "UNDIRECTED",
|
|
||||||
properties: {weight: {property: "*", aggregation: "COUNT"}}
|
|
||||||
}}
|
|
||||||
)
|
|
||||||
YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
|
|
||||||
""")
|
|
||||||
|
|
||||||
return records[0]['graph']
|
|
||||||
|
|
||||||
|
|
||||||
async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
|
|
||||||
community_clusters: list[list[EntityNode]] = []
|
community_clusters: list[list[EntityNode]] = []
|
||||||
|
|
||||||
group_id_values, _, _ = await driver.execute_query("""
|
if group_ids is None:
|
||||||
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
group_id_values, _, _ = await driver.execute_query("""
|
||||||
RETURN
|
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||||
collect(DISTINCT n.group_id) AS group_ids
|
RETURN
|
||||||
""")
|
collect(DISTINCT n.group_id) AS group_ids
|
||||||
|
""")
|
||||||
|
|
||||||
|
group_ids = group_id_values[0]['group_ids']
|
||||||
|
|
||||||
group_ids = group_id_values[0]['group_ids']
|
|
||||||
for group_id in group_ids:
|
for group_id in group_ids:
|
||||||
projection: dict[str, list[Neighbor]] = {}
|
projection: dict[str, list[Neighbor]] = {}
|
||||||
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
|
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
|
||||||
|
|
@ -197,9 +185,9 @@ async def build_community(
|
||||||
|
|
||||||
|
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
driver: AsyncDriver, llm_client: LLMClient
|
driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community_clusters = await get_community_clusters(driver)
|
community_clusters = await get_community_clusters(driver, group_ids)
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
|
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.8"
|
version = "0.3.9"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
|
|
@ -85,9 +85,7 @@ async def test_graphiti_init():
|
||||||
|
|
||||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||||
|
|
||||||
results = await graphiti._search(
|
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
|
||||||
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
|
|
||||||
)
|
|
||||||
pretty_results = {
|
pretty_results = {
|
||||||
'edges': [edge.fact for edge in results.edges],
|
'edges': [edge.fact for edge in results.edges],
|
||||||
'nodes': [node.name for node in results.nodes],
|
'nodes': [node.name for node in results.nodes],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue