Mmr optimizations (#481)

* update mmr calculations
* update search
* fixes and updates
* mypy

This commit is contained in:
parent 4198483993
commit 9baa9b7b8a

6 changed files with 146 additions and 76 deletions
@@ -23,6 +23,7 @@ from typing import Any
 
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 load_dotenv()
@@ -79,16 +80,10 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
 
 
-def normalize_l2(embedding: list[float]):
+def normalize_l2(embedding: list[float]) -> NDArray:
     embedding_array = np.array(embedding)
-    if embedding_array.ndim == 1:
-        norm = np.linalg.norm(embedding_array)
-        if norm == 0:
-            return [0.0] * len(embedding)
-        return (embedding_array / norm).tolist()
-    else:
-        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
-        return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
+    return np.where(norm == 0, embedding_array, embedding_array / norm)
 
 
 # Use this instead of asyncio.gather() to bound coroutines
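Note: a minimal sketch (not part of the commit) of how the reworked normalize_l2 above behaves, assuming plain NumPy; np.where leaves all-zero vectors untouched instead of dividing by zero.

import numpy as np

# Mirrors the new helper above: L2-normalize a 1-D embedding as an ndarray.
def normalize_l2(embedding: list[float]):
    embedding_array = np.array(embedding)
    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
    # np.where evaluates both branches, so a zero vector may emit a divide warning,
    # but the zero vector itself is returned unchanged.
    return np.where(norm == 0, embedding_array, embedding_array / norm)

print(normalize_l2([3.0, 4.0]))  # [0.6 0.8]
print(normalize_l2([0.0, 0.0]))  # [0. 0.]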
@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,26 +212,17 @@ async def edge_search(
 
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        await semaphore_gather(
-            *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
 
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
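The node_search and community_search MMR branches below get the same treatment: per-entity embedding loads behind semaphore_gather are replaced by one batched fetch, and the reranker now receives a uuid-to-embedding mapping plus the minimum score. A minimal usage sketch with illustrative uuids and 3-dimensional vectors (values are made up, not part of the diff):

from graphiti_core.search.search_utils import maximal_marginal_relevance

# Candidates keyed the way the new get_embeddings_for_* helpers return them.
candidates: dict[str, list[float]] = {
    'edge-1': [1.0, 0.0, 0.0],
    'edge-2': [0.9, 0.1, 0.0],  # near-duplicate of edge-1, so MMR ranks it below edge-1
    'edge-3': [0.0, 1.0, 0.0],
}
query_vector = [1.0, 0.0, 0.0]

reranked = maximal_marginal_relevance(query_vector, candidates, mmr_lambda=0.5, min_score=-2.0)
print(reranked)  # uuids sorted by MMR score; anything under min_score is dropped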
@@ -311,30 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        await semaphore_gather(
-            *[node.load_name_embedding(driver) for result in search_results for node in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
-
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
@@ -437,25 +424,12 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        await semaphore_gather(
-            *[
-                community.load_name_embedding(driver)
-                for result in search_results
-                for community in result
-            ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -336,10 +337,10 @@ async def node_fulltext_search(
 
     query = (
         """
        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
        YIELD node AS n, score
        WHERE n:Entity
        """
        + filter_query
        + ENTITY_NODE_RETURN
        + """
@@ -899,6 +900,7 @@ async def node_distance_reranker(
             node_uuids=filtered_uuids,
             center_uuid=center_node_uuid,
             database_=DEFAULT_DATABASE,
+            routing_='r',
         )
 
         for result in path_results:
@@ -939,6 +941,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
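The two hunks above add routing_='r' to read-only Cypher calls so the driver can route them to reader members of a Neo4j cluster instead of the leader. A standalone sketch of the same pattern (hypothetical URI and credentials, not part of the diff):

from neo4j import AsyncGraphDatabase, RoutingControl

async def count_entities() -> int:
    # RoutingControl.READ is the enum form of the routing_='r' literal used in the diff.
    driver = AsyncGraphDatabase.driver('neo4j://localhost:7687', auth=('neo4j', 'password'))
    try:
        records, _, _ = await driver.execute_query(
            'MATCH (n:Entity) RETURN count(n) AS c',
            database_='neo4j',
            routing_=RoutingControl.READ,
        )
        return records[0]['c']
    finally:
        await driver.close()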
@@ -952,15 +955,116 @@ async def episode_mentions_reranker(
 
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-):
-    candidates_with_mmr: list[tuple[str, float]] = []
-    for candidate in candidates:
-        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
-        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
-        candidates_with_mmr.append((candidate[0], mmr))
-
-    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
-
-    return list(set([candidate[0] for candidate in candidates_with_mmr]))
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+    WHERE n.uuid IN $node_uuids
+    RETURN DISTINCT
+        n.uuid AS uuid,
+        n.name_embedding AS name_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+    WHERE c.uuid IN $community_uuids
+    RETURN DISTINCT
+        c.uuid AS uuid,
+        c.name_embedding AS name_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+    WHERE e.uuid IN $edge_uuids
+    RETURN DISTINCT
+        e.uuid AS uuid,
+        e.fact_embedding AS fact_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
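The rewritten reranker scores each candidate as mmr_lambda * sim(query, candidate) + (mmr_lambda - 1) * max_sim, where max_sim is the candidate's highest similarity to any other candidate, all over L2-normalized vectors read from a precomputed similarity matrix. It then sorts by score and filters by min_score, instead of the old list(set(...)) return that discarded the sort order. A toy trace of that scoring with made-up, already unit-length vectors (not part of the diff):

import numpy as np

mmr_lambda = 0.5
query = np.array([1.0, 0.0])
candidates = {
    'a': np.array([1.0, 0.0]),  # highly relevant to the query
    'b': np.array([0.8, 0.6]),  # relevant but redundant with 'a'
    'c': np.array([0.0, 1.0]),  # diverse but irrelevant
}

uuids = list(candidates)
sim = np.zeros((len(uuids), len(uuids)))
for i, u1 in enumerate(uuids):
    for j, u2 in enumerate(uuids[:i]):
        sim[i, j] = sim[j, i] = np.dot(candidates[u1], candidates[u2])

scores = {
    u: float(mmr_lambda * np.dot(query, candidates[u]) + (mmr_lambda - 1) * np.max(sim[i, :]))
    for i, u in enumerate(uuids)
}
print(scores)  # roughly {'a': 0.1, 'b': 0.0, 'c': -0.3}: 'b' is pushed down for redundancy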
poetry.lock (generated): 9 changed lines
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -333,7 +333,7 @@ description = "Timeout context manager for asyncio programs"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
     {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
@@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main", "dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -2668,7 +2668,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
 optional = false
 python-versions = ">=3.9"
 groups = ["dev"]
-markers = "platform_python_implementation != \"PyPy\""
 files = [
     {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
     {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@@ -4500,7 +4499,7 @@ description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
     {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.11.6pre7"
+version = "0.11.6pre8"
 authors = [
     { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
     { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
@@ -65,9 +65,7 @@ async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
 
-    results = await graphiti.search_(
-        query='Who is the user?',
-    )
+    results = await graphiti.search_(query='Who is the user?')
 
     pretty_results = search_results_to_context_string(results)