make broader use of debug logs (#187)
This commit is contained in:
parent
f52b45b9b2
commit
6c3b32e620
10 changed files with 89 additions and 94 deletions
|
|
@ -51,7 +51,7 @@ class Edge(BaseModel, ABC):
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Deleted Edge: {self.uuid}')
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ class EpisodicEdge(Edge):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -178,7 +178,7 @@ class EntityEdge(Edge):
|
||||||
self.fact_embedding = await embedder.create(input=[text])
|
self.fact_embedding = await embedder.create(input=[text])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'embedded {text} in {end - start} ms')
|
logger.debug(f'embedded {text} in {end - start} ms')
|
||||||
|
|
||||||
return self.fact_embedding
|
return self.fact_embedding
|
||||||
|
|
||||||
|
|
@ -206,7 +206,7 @@ class EntityEdge(Edge):
|
||||||
invalid_at=self.invalid_at,
|
invalid_at=self.invalid_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -313,7 +313,7 @@ class CommunityEdge(Edge):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
|
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
|
||||||
return result.data[0].embedding[: self.config.embedding_dim]
|
return result.data[0].embedding[: self.config.embedding_dim]
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
|
||||||
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
result = await self.client.embed(input, model=self.config.embedding_model)
|
result = await self.client.embed(input, model=self.config.embedding_model)
|
||||||
return result.embeddings[0][: self.config.embedding_dim]
|
return result.embeddings[0][: self.config.embedding_dim]
|
||||||
|
|
|
||||||
|
|
@ -325,7 +325,7 @@ class Graphiti:
|
||||||
# Extract entities as nodes
|
# Extract entities as nodes
|
||||||
|
|
||||||
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
||||||
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
|
|
||||||
# Calculate Embeddings
|
# Calculate Embeddings
|
||||||
|
|
||||||
|
|
@ -340,7 +340,7 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
|
|
||||||
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
||||||
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
||||||
|
|
@ -348,7 +348,7 @@ class Graphiti:
|
||||||
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
||||||
nodes = mentioned_nodes
|
nodes = mentioned_nodes
|
||||||
|
|
||||||
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
||||||
|
|
@ -378,10 +378,10 @@ class Graphiti:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -433,11 +433,11 @@ class Graphiti:
|
||||||
|
|
||||||
entity_edges.extend(resolved_edges + invalidated_edges)
|
entity_edges.extend(resolved_edges + invalidated_edges)
|
||||||
|
|
||||||
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
||||||
|
|
||||||
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
||||||
|
|
||||||
logger.info(f'Built episodic edges: {episodic_edges}')
|
logger.debug(f'Built episodic edges: {episodic_edges}')
|
||||||
|
|
||||||
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
||||||
|
|
||||||
|
|
@ -563,7 +563,7 @@ class Graphiti:
|
||||||
edges = await dedupe_edges_bulk(
|
edges = await dedupe_edges_bulk(
|
||||||
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
||||||
)
|
)
|
||||||
logger.info(f'extracted edge length: {len(edges)}')
|
logger.debug(f'extracted edge length: {len(edges)}')
|
||||||
|
|
||||||
# invalidate edges
|
# invalidate edges
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ class Node(BaseModel, ABC):
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Deleted Node: {self.uuid}')
|
logger.debug(f'Deleted Node: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -135,7 +135,7 @@ class EpisodicNode(Node):
|
||||||
source=self.source.value,
|
source=self.source.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -217,7 +217,7 @@ class EntityNode(Node):
|
||||||
text = self.name.replace('\n', ' ')
|
text = self.name.replace('\n', ' ')
|
||||||
self.name_embedding = await embedder.create(input=[text])
|
self.name_embedding = await embedder.create(input=[text])
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'embedded {text} in {end - start} ms')
|
logger.debug(f'embedded {text} in {end - start} ms')
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
|
|
@ -236,7 +236,7 @@ class EntityNode(Node):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -320,7 +320,7 @@ class CommunityNode(Node):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -329,7 +329,7 @@ class CommunityNode(Node):
|
||||||
text = self.name.replace('\n', ' ')
|
text = self.name.replace('\n', ' ')
|
||||||
self.name_embedding = await embedder.create(input=[text])
|
self.name_embedding = await embedder.create(input=[text])
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'embedded {text} in {end - start} ms')
|
logger.debug(f'embedded {text} in {end - start} ms')
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,10 +56,10 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
|
|
||||||
async def get_episodes_by_mentions(
|
async def get_episodes_by_mentions(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
episode_uuids: list[str] = []
|
episode_uuids: list[str] = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
|
@ -71,7 +71,7 @@ async def get_episodes_by_mentions(
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(
|
async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -94,7 +94,7 @@ async def get_mentioned_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_communities_by_nodes(
|
async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: AsyncDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -117,12 +117,12 @@ async def get_communities_by_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -162,13 +162,13 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
async def edge_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
query = Query("""
|
||||||
|
|
@ -212,10 +212,10 @@ async def edge_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_fulltext_search(
|
async def node_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -244,11 +244,11 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_similarity_search(
|
async def node_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -279,10 +279,10 @@ async def node_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_fulltext_search(
|
async def community_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# BM25 search to get top communities
|
# BM25 search to get top communities
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
|
|
@ -311,11 +311,11 @@ async def community_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_similarity_search(
|
async def community_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -346,11 +346,11 @@ async def community_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_node_search(
|
async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search for nodes using both text queries and embeddings.
|
Perform a hybrid search for nodes using both text queries and embeddings.
|
||||||
|
|
@ -408,13 +408,13 @@ async def hybrid_node_search(
|
||||||
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
|
logger.debug(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
|
||||||
return relevant_nodes
|
return relevant_nodes
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
async def get_relevant_nodes(
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve relevant nodes based on the provided list of EntityNodes.
|
Retrieve relevant nodes based on the provided list of EntityNodes.
|
||||||
|
|
@ -451,11 +451,11 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
relevant_edges: list[EntityEdge] = []
|
relevant_edges: list[EntityEdge] = []
|
||||||
|
|
@ -491,7 +491,7 @@ async def get_relevant_edges(
|
||||||
relevant_edges.append(edge)
|
relevant_edges.append(edge)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
|
logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
return relevant_edges
|
return relevant_edges
|
||||||
|
|
||||||
|
|
@ -512,7 +512,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# filter out node_uuid center node node uuid
|
# filter out node_uuid center node node uuid
|
||||||
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
||||||
|
|
@ -582,18 +582,13 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
||||||
|
|
||||||
|
|
||||||
def maximal_marginal_relevance(
|
def maximal_marginal_relevance(
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
candidates: list[tuple[str, list[float]]],
|
candidates: list[tuple[str, list[float]]],
|
||||||
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
||||||
):
|
):
|
||||||
candidates_with_mmr: list[tuple[str, float]] = []
|
candidates_with_mmr: list[tuple[str, float]] = []
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
max_sim = max(
|
max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
|
||||||
[
|
|
||||||
np.dot(normalize_l2(candidate[1]), normalize_l2(c[1]))
|
|
||||||
for c in candidates
|
|
||||||
]
|
|
||||||
)
|
|
||||||
mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim
|
mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim
|
||||||
candidates_with_mmr.append((candidate[0], mmr))
|
candidates_with_mmr.append((candidate[0], mmr))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,7 @@ async def build_community(
|
||||||
)
|
)
|
||||||
community_edges = build_community_edges(community_cluster, community_node, now)
|
community_edges = build_community_edges(community_cluster, community_node, now)
|
||||||
|
|
||||||
logger.info((community_node, community_edges))
|
logger.debug((community_node, community_edges))
|
||||||
|
|
||||||
return community_node, community_edges
|
return community_node, community_edges
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ async def extract_edges(
|
||||||
edges_data = llm_response.get('edges', [])
|
edges_data = llm_response.get('edges', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
|
logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
# Convert the extracted data into EntityEdge objects
|
# Convert the extracted data into EntityEdge objects
|
||||||
edges = []
|
edges = []
|
||||||
|
|
@ -115,7 +115,7 @@ async def extract_edges(
|
||||||
invalid_at=None,
|
invalid_at=None,
|
||||||
)
|
)
|
||||||
edges.append(edge)
|
edges.append(edge)
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -144,7 +144,7 @@ async def dedupe_extracted_edges(
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
|
||||||
duplicate_data = llm_response.get('duplicates', [])
|
duplicate_data = llm_response.get('duplicates', [])
|
||||||
logger.info(f'Extracted unique edges: {duplicate_data}')
|
logger.debug(f'Extracted unique edges: {duplicate_data}')
|
||||||
|
|
||||||
duplicate_uuid_map: dict[str, str] = {}
|
duplicate_uuid_map: dict[str, str] = {}
|
||||||
for duplicate in duplicate_data:
|
for duplicate in duplicate_data:
|
||||||
|
|
@ -299,7 +299,7 @@ async def dedupe_extracted_edge(
|
||||||
edge = existing_edge
|
edge = existing_edge
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -326,7 +326,7 @@ async def dedupe_edge_list(
|
||||||
unique_edges_data = llm_response.get('unique_facts', [])
|
unique_edges_data = llm_response.get('unique_facts', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
||||||
|
|
||||||
# Get full edge data
|
# Get full edge data
|
||||||
unique_edges = []
|
unique_edges = []
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ async def extract_nodes(
|
||||||
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
|
logger.debug(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
|
||||||
# Convert the extracted data into EntityNode objects
|
# Convert the extracted data into EntityNode objects
|
||||||
new_nodes = []
|
new_nodes = []
|
||||||
for node_data in extracted_node_data:
|
for node_data in extracted_node_data:
|
||||||
|
|
@ -116,7 +116,7 @@ async def extract_nodes(
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
)
|
)
|
||||||
new_nodes.append(new_node)
|
new_nodes.append(new_node)
|
||||||
logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
||||||
return new_nodes
|
return new_nodes
|
||||||
|
|
||||||
|
|
@ -152,7 +152,7 @@ async def dedupe_extracted_nodes(
|
||||||
duplicate_data = llm_response.get('duplicates', [])
|
duplicate_data = llm_response.get('duplicates', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
|
logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for duplicate in duplicate_data:
|
for duplicate in duplicate_data:
|
||||||
|
|
@ -232,7 +232,7 @@ async def resolve_extracted_node(
|
||||||
uuid_map[extracted_node.uuid] = existing_node.uuid
|
uuid_map[extracted_node.uuid] = existing_node.uuid
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -266,7 +266,7 @@ async def dedupe_node_list(
|
||||||
nodes_data = llm_response.get('nodes', [])
|
nodes_data = llm_response.get('nodes', [])
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
|
logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
# Get full node data
|
# Get full node data
|
||||||
unique_nodes = []
|
unique_nodes = []
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ async def get_edge_contradictions(
|
||||||
contradicted_edges.append(contradicted_edge)
|
contradicted_edges.append(contradicted_edge)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(
|
logger.debug(
|
||||||
f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
|
f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue