fix: add summaries to the graph
This commit is contained in:
parent
897bbac699
commit
f569088a2e
16 changed files with 127 additions and 58 deletions
|
|
@ -338,7 +338,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return predecessors + successors
|
return predecessors + successors
|
||||||
|
|
||||||
async def get_connections(self, node_id: str) -> list:
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
predecessors_query = """
|
predecessors_query = """
|
||||||
MATCH (node)<-[relation]-(neighbour)
|
MATCH (node)<-[relation]-(neighbour)
|
||||||
WHERE node.id = $node_id
|
WHERE node.id = $node_id
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from re import A
|
from re import A
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
from uuid import UUID
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os as aiofiles_os
|
import aiofiles.os as aiofiles_os
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
@ -130,7 +131,7 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
||||||
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
||||||
|
|
||||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list:
|
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||||
if self.graph.has_node(node_id):
|
if self.graph.has_node(node_id):
|
||||||
if edge_label is None:
|
if edge_label is None:
|
||||||
return [
|
return [
|
||||||
|
|
@ -146,7 +147,7 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
async def get_successors(self, node_id: str, edge_label: str = None) -> list:
|
async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||||
if self.graph.has_node(node_id):
|
if self.graph.has_node(node_id):
|
||||||
if edge_label is None:
|
if edge_label is None:
|
||||||
return [
|
return [
|
||||||
|
|
@ -175,13 +176,13 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return neighbours
|
return neighbours
|
||||||
|
|
||||||
async def get_connections(self, node_id: str) -> list:
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
if not self.graph.has_node(node_id):
|
if not self.graph.has_node(node_id):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
node = self.graph.nodes[node_id]
|
node = self.graph.nodes[node_id]
|
||||||
|
|
||||||
if "uuid" not in node:
|
if "id" not in node:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
predecessors, successors = await asyncio.gather(
|
predecessors, successors = await asyncio.gather(
|
||||||
|
|
@ -192,14 +193,14 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
connections = []
|
connections = []
|
||||||
|
|
||||||
for neighbor in predecessors:
|
for neighbor in predecessors:
|
||||||
if "uuid" in neighbor:
|
if "id" in neighbor:
|
||||||
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"])
|
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
|
||||||
for edge_properties in edge_data.values():
|
for edge_properties in edge_data.values():
|
||||||
connections.append((neighbor, edge_properties, node))
|
connections.append((neighbor, edge_properties, node))
|
||||||
|
|
||||||
for neighbor in successors:
|
for neighbor in successors:
|
||||||
if "uuid" in neighbor:
|
if "id" in neighbor:
|
||||||
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"])
|
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
|
||||||
for edge_properties in edge_data.values():
|
for edge_properties in edge_data.values():
|
||||||
connections.append((node, edge_properties, neighbor))
|
connections.append((node, edge_properties, neighbor))
|
||||||
|
|
||||||
|
|
@ -245,6 +246,17 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
async with aiofiles.open(file_path, "r") as file:
|
async with aiofiles.open(file_path, "r") as file:
|
||||||
graph_data = json.loads(await file.read())
|
graph_data = json.loads(await file.read())
|
||||||
|
for node in graph_data["nodes"]:
|
||||||
|
node["id"] = UUID(node["id"])
|
||||||
|
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
|
for edge in graph_data["links"]:
|
||||||
|
edge["source"] = UUID(edge["source"])
|
||||||
|
edge["target"] = UUID(edge["target"])
|
||||||
|
edge["source_node_id"] = UUID(edge["source_node_id"])
|
||||||
|
edge["target_node_id"] = UUID(edge["target_node_id"])
|
||||||
|
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||||
else:
|
else:
|
||||||
# Log that the file does not exist and an empty graph is initialized
|
# Log that the file does not exist and an empty graph is initialized
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
from falkordb import FalkorDB
|
from falkordb import FalkorDB
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
async def extract_nodes(self, data_point_ids: list[str]):
|
async def extract_nodes(self, data_point_ids: list[str]):
|
||||||
return await self.retrieve(data_point_ids)
|
return await self.retrieve(data_point_ids)
|
||||||
|
|
||||||
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
|
predecessors_query = """
|
||||||
|
MATCH (node)<-[relation]-(neighbour)
|
||||||
|
WHERE node.id = $node_id
|
||||||
|
RETURN neighbour, relation, node
|
||||||
|
"""
|
||||||
|
successors_query = """
|
||||||
|
MATCH (node)-[relation]->(neighbour)
|
||||||
|
WHERE node.id = $node_id
|
||||||
|
RETURN node, relation, neighbour
|
||||||
|
"""
|
||||||
|
|
||||||
|
predecessors, successors = await asyncio.gather(
|
||||||
|
self.query(predecessors_query, dict(node_id = node_id)),
|
||||||
|
self.query(successors_query, dict(node_id = node_id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
for neighbour in predecessors:
|
||||||
|
neighbour = neighbour["relation"]
|
||||||
|
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||||
|
|
||||||
|
for neighbour in successors:
|
||||||
|
neighbour = neighbour["relation"]
|
||||||
|
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||||
|
|
||||||
|
return connections
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id = UUID(result.id),
|
id = UUID(result.uuid),
|
||||||
payload = result.properties,
|
payload = result.properties,
|
||||||
score = float(result.metadata.score)
|
score = float(result.metadata.score)
|
||||||
) for result in search_result.objects
|
) for result in search_result.objects
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ class TextChunker():
|
||||||
else:
|
else:
|
||||||
if len(self.paragraph_chunks) == 0:
|
if len(self.paragraph_chunks) == 0:
|
||||||
yield DocumentChunk(
|
yield DocumentChunk(
|
||||||
id = str(chunk_data["chunk_id"]),
|
id = chunk_data["chunk_id"],
|
||||||
text = chunk_data["text"],
|
text = chunk_data["text"],
|
||||||
word_count = chunk_data["word_count"],
|
word_count = chunk_data["word_count"],
|
||||||
is_part_of = self.document,
|
is_part_of = self.document,
|
||||||
|
|
@ -42,7 +42,7 @@ class TextChunker():
|
||||||
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
||||||
try:
|
try:
|
||||||
yield DocumentChunk(
|
yield DocumentChunk(
|
||||||
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
|
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||||
text = chunk_text,
|
text = chunk_text,
|
||||||
word_count = self.chunk_size,
|
word_count = self.chunk_size,
|
||||||
is_part_of = self.document,
|
is_part_of = self.document,
|
||||||
|
|
@ -59,7 +59,7 @@ class TextChunker():
|
||||||
if len(self.paragraph_chunks) > 0:
|
if len(self.paragraph_chunks) > 0:
|
||||||
try:
|
try:
|
||||||
yield DocumentChunk(
|
yield DocumentChunk(
|
||||||
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
|
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||||
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
||||||
word_count = self.chunk_size,
|
word_count = self.chunk_size,
|
||||||
is_part_of = self.document,
|
is_part_of = self.document,
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules import data
|
|
||||||
from cognee.modules.storage.utils import copy_model
|
from cognee.modules.storage.utils import copy_model
|
||||||
|
|
||||||
def get_graph_from_model(data_point: DataPoint, include_root = True):
|
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
|
|
||||||
|
|
@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
|
||||||
if isinstance(field_value, DataPoint):
|
if isinstance(field_value, DataPoint):
|
||||||
excluded_properties.add(field_name)
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
property_nodes, property_edges = get_graph_from_model(field_value, True)
|
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||||
nodes[:0] = property_nodes
|
|
||||||
edges[:0] = property_edges
|
for node in property_nodes:
|
||||||
|
if str(node.id) not in added_nodes:
|
||||||
|
nodes.append(node)
|
||||||
|
added_nodes[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[str(edge_key)] = True
|
||||||
|
|
||||||
for property_node in get_own_properties(property_nodes, property_edges):
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
edges.append((data_point.id, property_node.id, field_name, {
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
"source_node_id": data_point.id,
|
|
||||||
"target_node_id": property_node.id,
|
if str(edge_key) not in added_edges:
|
||||||
"relationship_name": field_name,
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
"source_node_id": data_point.id,
|
||||||
}))
|
"target_node_id": property_node.id,
|
||||||
|
"relationship_name": field_name,
|
||||||
|
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
}))
|
||||||
|
added_edges[str(edge_key)] = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(field_value, list):
|
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
|
||||||
if isinstance(field_value[0], DataPoint):
|
excluded_properties.add(field_name)
|
||||||
excluded_properties.add(field_name)
|
|
||||||
|
|
||||||
for item in field_value:
|
for item in field_value:
|
||||||
property_nodes, property_edges = get_graph_from_model(item, True)
|
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||||
nodes[:0] = property_nodes
|
|
||||||
edges[:0] = property_edges
|
|
||||||
|
|
||||||
for property_node in get_own_properties(property_nodes, property_edges):
|
for node in property_nodes:
|
||||||
|
if str(node.id) not in added_nodes:
|
||||||
|
nodes.append(node)
|
||||||
|
added_nodes[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[edge_key] = True
|
||||||
|
|
||||||
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
edges.append((data_point.id, property_node.id, field_name, {
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
"source_node_id": data_point.id,
|
"source_node_id": data_point.id,
|
||||||
"target_node_id": property_node.id,
|
"target_node_id": property_node.id,
|
||||||
|
|
@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
|
||||||
"type": "list"
|
"type": "list"
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
continue
|
added_edges[edge_key] = True
|
||||||
|
continue
|
||||||
|
|
||||||
data_point_properties[field_name] = field_value
|
data_point_properties[field_name] = field_value
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from ..tasks.Task import Task
|
||||||
|
|
||||||
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
||||||
|
|
||||||
async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
|
||||||
if len(tasks) == 0:
|
if len(tasks) == 0:
|
||||||
yield data
|
yield data
|
||||||
return
|
return
|
||||||
|
|
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
||||||
|
|
||||||
running_task = tasks[0]
|
running_task = tasks[0]
|
||||||
leftover_tasks = tasks[1:]
|
leftover_tasks = tasks[1:]
|
||||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
|
||||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(running_task.executable):
|
if inspect.isasyncgenfunction(running_task.executable):
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,13 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
|
||||||
|
|
||||||
exact_node = await graph_engine.extract_node(node_id)
|
exact_node = await graph_engine.extract_node(node_id)
|
||||||
|
|
||||||
if exact_node is not None and "uuid" in exact_node:
|
if exact_node is not None and "id" in exact_node:
|
||||||
node_connections = await graph_engine.get_connections(str(exact_node["uuid"]))
|
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
|
||||||
else:
|
else:
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
vector_engine.search("Entity_text", query_text = query, limit = 5),
|
vector_engine.search("Entity_name", query_text = query, limit = 5),
|
||||||
vector_engine.search("EntityType_text", query_text = query, limit = 5),
|
vector_engine.search("EntityType_name", query_text = query, limit = 5),
|
||||||
)
|
)
|
||||||
results = [*results[0], *results[1]]
|
results = [*results[0], *results[1]]
|
||||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||||
|
|
@ -37,7 +37,7 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
|
||||||
return []
|
return []
|
||||||
|
|
||||||
node_connections_results = await asyncio.gather(
|
node_connections_results = await asyncio.gather(
|
||||||
*[graph_engine.get_connections(str(result.payload["uuid"])) for result in relevant_results]
|
*[graph_engine.get_connections(result.id) for result in relevant_results]
|
||||||
)
|
)
|
||||||
|
|
||||||
node_connections = []
|
node_connections = []
|
||||||
|
|
@ -48,10 +48,10 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
|
||||||
unique_node_connections_map = {}
|
unique_node_connections_map = {}
|
||||||
unique_node_connections = []
|
unique_node_connections = []
|
||||||
for node_connection in node_connections:
|
for node_connection in node_connections:
|
||||||
if "uuid" not in node_connection[0] or "uuid" not in node_connection[2]:
|
if "id" not in node_connection[0] or "id" not in node_connection[2]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
unique_id = f"{node_connection[0]['uuid']} {node_connection[1]['relationship_name']} {node_connection[2]['uuid']}"
|
unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
|
||||||
|
|
||||||
if unique_id not in unique_node_connections_map:
|
if unique_id not in unique_node_connections_map:
|
||||||
unique_node_connections_map[unique_id] = True
|
unique_node_connections_map[unique_id] = True
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
|
||||||
added_data_points[str(new_point.id)] = True
|
added_data_points[str(new_point.id)] = True
|
||||||
data_points.append(new_point)
|
data_points.append(new_point)
|
||||||
|
|
||||||
data_points.append(data_point)
|
if (str(data_point.id) not in added_data_points):
|
||||||
|
data_points.append(data_point)
|
||||||
|
|
||||||
return data_points
|
return data_points
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document
|
||||||
|
|
||||||
class TextSummary(DataPoint):
|
class TextSummary(DataPoint):
|
||||||
text: str
|
text: str
|
||||||
chunk: DocumentChunk
|
made_from: DocumentChunk
|
||||||
|
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"index_fields": ["text"],
|
"index_fields": ["text"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||||
from cognee.modules.data.extraction.extract_summary import extract_summary
|
from cognee.modules.data.extraction.extract_summary import extract_summary
|
||||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.tasks.storage.index_data_points import get_data_points_from_model
|
||||||
from .models.TextSummary import TextSummary
|
from .models.TextSummary import TextSummary
|
||||||
|
|
||||||
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
|
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
|
||||||
|
|
@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
|
||||||
|
|
||||||
summaries = [
|
summaries = [
|
||||||
TextSummary(
|
TextSummary(
|
||||||
id = uuid5(chunk.id, "summary"),
|
id = uuid5(chunk.id, "TextSummary"),
|
||||||
chunk = chunk,
|
made_from = chunk,
|
||||||
text = chunk_summaries[chunk_index].summary,
|
text = chunk_summaries[chunk_index].summary,
|
||||||
) for (chunk_index, chunk) in enumerate(data_chunks)
|
) for (chunk_index, chunk) in enumerate(data_chunks)
|
||||||
]
|
]
|
||||||
|
|
||||||
add_data_points(summaries)
|
await add_data_points(summaries)
|
||||||
|
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||||
random_node_name = random_node.payload["name"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,8 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||||
random_node_name = random_node.payload["name"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,8 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||||
random_node_name = random_node.payload["name"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,8 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||||
random_node_name = random_node.payload["name"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
|
|
||||||
|
|
@ -35,8 +35,8 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||||
random_node_name = random_node.payload["name"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue