feat: add entity and entity type nodes to vector db (#126)

* feat: add entity and entity type nodes to vector db

* fix: use uuid5 as entity ids

* fix: id -> uuid and LanceDB collection model
This commit is contained in:
Boris 2024-08-01 14:21:39 +02:00 committed by GitHub
parent 5182051168
commit 26bca0184f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 496 additions and 203 deletions

View file

@ -62,7 +62,7 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
tasks = [
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(
save_data_chunks,

View file

@ -9,7 +9,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.utils import send_telemetry
class SearchType(Enum):
@ -46,9 +45,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List:
graph_client = await get_graph_engine()
graph = graph_client.graph
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,
@ -64,7 +60,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params, graph = graph)
task = search_func(**search_param.params)
search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently
@ -75,7 +71,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
send_telemetry("cognee.search")
return results
return results[0] if len(results) == 1 else results

View file

@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
async def graph(self):
return await self.get_session()
async def has_node(self, node_id: str) -> bool:
results = self.query(
"""
MATCH (n)
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
{"node_id": node_id}
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
node_id = node_id.replace(":", "_")
@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
RETURN COUNT(relationship) > 0 AS edge_exists
"""
edge_exists = await self.query(query)
return edge_exists
async def has_edges(self, edges):
query = """
UNWIND $edges AS edge
MATCH (a)-[r]->(b)
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
"""
try:
params = {
"edges": [{
"from_node": edge[0],
"to_node": edge[1],
"relationship_name": edge[2],
} for edge in edges],
}
results = await self.query(query, params)
return [result["edge_exists"] for result in results]
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface):
},
} for edge in edges]
results = await self.query(query, dict(edges = edges))
return results
try:
results = await self.query(query, dict(edges = edges))
return results
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error
async def get_edges(self, node_id: str):
query = """
@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface):
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN predecessor.id AS predecessor_id
"""
results = await self.query(
@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)-[r]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id
RETURN predecessor.id AS predecessor_id
"""
results = await self.query(
@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN successor.id AS successor_id
"""
results = await self.query(
@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface):
),
)
return [result["id"] for result in results]
return [result["successor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)<-[r]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id
RETURN successor.id AS successor_id
"""
results = await self.query(
@ -318,7 +370,7 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["successor_id"] for result in results]
async def get_neighbours(self, node_id: str) -> list[str]:
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))

View file

@ -2,6 +2,7 @@
import os
import json
import asyncio
import logging
from typing import Dict, Any, List
import aiofiles
@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface):
self.filename = filename
async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id)
async def add_node(
self,
@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface):
async def get_graph(self):
return self.graph
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label)
async def has_edges(self, edges):
result = []
for (from_node, to_node, edge_label) in edges:
if await self.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label))
return result
async def add_edge(
self,
from_node: str,
@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface):
if not self.graph.has_node(node_id):
return []
neighbour_ids = list(self.graph.neighbors(node_id))
predecessor_ids, successor_ids = await asyncio.gather(
self.get_predecessor_ids(node_id),
self.get_successor_ids(node_id),
)
neighbour_ids = predecessor_ids + successor_ids
if len(neighbour_ids) == 0:
return []

View file

@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface):
return [ScoredResult(
id = result["id"],
payload = result["payload"],
score = 1,
score = 0,
) for result in results.to_dict("index").values()]
async def search(
@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 10,
limit: int = 5,
with_vector: bool = False,
):
if query_text is None and query_vector is None:
@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
result_values = list(results.to_dict("index").values())
min_value = 100
max_value = 0
for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
return [ScoredResult(
id = str(result["id"]),
payload = result["payload"],
score = float(result["_distance"]),
) for result in results.to_dict("index").values()]
score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)]
async def batch_search(
self,

View file

@ -1,8 +1,7 @@
from uuid import UUID
from typing import Any, Dict
from pydantic import BaseModel
class ScoredResult(BaseModel):
id: str
score: float
score: float # Lower score is better
payload: Dict[str, Any]

View file

@ -1,9 +1,12 @@
import logging
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("QDrantAdapter")
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface):
points = [convert_to_qdrant_point(point) for point in data_points]
result = await client.upload_points(
collection_name = collection_name,
points = points
)
await client.close()
return result
try:
result = await client.upload_points(
collection_name = collection_name,
points = points
)
return result
except Exception as error:
logger.error("Error uploading data points to Qdrant: %s", str(error))
raise error
finally:
await client.close()
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client()
@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
limit: int = 5,
with_vector: bool = False
):
if query_text is None and query_vector is None:

View file

@ -1,10 +1,12 @@
import asyncio
import logging
from typing import List, Optional
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter")
class WeaviateAdapter(VectorDBInterface):
name = "Weaviate"
@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface):
vector = vector
)
objects = list(map(convert_to_weaviate_data_points, data_points))
data_points = list(map(convert_to_weaviate_data_points, data_points))
collection = self.get_collection(collection_name)
with collection.batch.dynamic() as batch:
for data_row in objects:
batch.add_object(
properties = data_row.properties,
vector = data_row.vector
)
return
# return self.get_collection(collection_name).data.insert_many(objects)
try:
if len(data_points) > 1:
return collection.data.insert_many(data_points)
else:
return collection.data.insert(data_points[0])
# with collection.batch.dynamic() as batch:
# for point in data_points:
# batch.add_object(
# uuid = point.uuid,
# properties = point.properties,
# vector = point.vector
# )
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter

View file

@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
The aim is to achieve simplicity and clarity in the knowledge graph.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist".
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, always label it as **"Date"**.
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.
@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination"""
Adhere to the rules strictly. Non-compliance will result in termination

View file

@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
vector_engine = get_vector_engine()
class Keyword(BaseModel):
id: str
uuid: str
text: str
chunk_id: str
document_id: str
@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_type_id),
payload = Keyword.parse_obj({
"id": str(classification_type_id),
"uuid": str(classification_type_id),
"text": classification_type_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_subtype_id),
payload = Keyword.parse_obj({
"id": str(classification_subtype_id),
"uuid": str(classification_subtype_id),
"text": classification_subtype_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
)
))
edges.append((
str(classification_type_id),
str(classification_subtype_id),
"contains",
str(classification_type_id),
"is_subtype_of",
dict(
relationship_name = "contains",
source_node_id = str(classification_type_id),

View file

@ -1,25 +1,77 @@
import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .extract_knowledge_graph import extract_content_graph
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
class EntityNode(BaseModel):
uuid: str
name: str
type: str
description: str
created_at: datetime
updated_at: datetime
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
graph_type_node_ids = list(set(type_ids))
graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
has_collection = await vector_engine.has_collection(collection_name)
if not has_collection:
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []
for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if type_node_id not in processed_nodes:
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
processed_nodes[type_node_id] = True
if entity_node_id not in processed_nodes:
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
processed_nodes[entity_node_id] = True
graph_node_edges = [
(edge.source_node_id, edge.target_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]
existing_edges = await graph_engine.has_edges([
*type_node_edges,
*entity_node_edges,
*type_entity_edges,
*graph_node_edges,
])
existing_edges_map = {}
existing_nodes_map = {}
for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
existing_nodes_map[edge[0]] = True
graph_nodes = []
graph_edges = []
data_points = []
for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]
@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model:
for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_name(node.name)
graph_nodes.append((
node_id,
dict(
id = node_id,
chunk_id = str(chunk.chunk_id),
document_id = str(chunk.document_id),
name = node.name,
type = node.type.lower().capitalize(),
type_node_id = generate_node_id(node.type)
type_node_name = generate_name(node.type)
if node_id not in existing_nodes_map:
node_data = dict(
uuid = node_id,
name = node_name,
type = node_name,
description = node.description,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
))
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains",
dict(
relationship_name = "contains",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
graph_nodes.append((
node_id,
dict(
**node_data,
properties = json.dumps(node.properties),
)
))
type_node_id = generate_node_id(node.type)
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, node_id)),
payload = node_data,
embed_field = "name",
))
if type_node_id not in existing_type_nodes_map:
node_name = node.type.lower().capitalize()
existing_nodes_map[node_id] = True
type_node = dict(
id = type_node_id,
name = node_name,
type = node_name,
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
if edge_key not in existing_edges_map:
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains_entity",
dict(
relationship_name = "contains_entity",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
node_id,
type_node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_edges_map[edge_key] = True
if type_node_id not in existing_nodes_map:
type_node_data = dict(
uuid = type_node_id,
name = type_node_name,
type = type_node_id,
description = type_node_name,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
graph_nodes.append((type_node_id, type_node))
existing_type_nodes_map[type_node_id] = type_node
graph_nodes.append((type_node_id, dict(
**type_node_data,
properties = json.dumps(node.properties)
)))
graph_edges.append((
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, type_node_id)),
payload = type_node_data,
embed_field = "name",
))
# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
type_node_id,
node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_nodes_map[type_node_id] = True
# Add relationship that came from graphs.
for edge in graph.edges:
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = edge.relationship_name,
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
await graph_engine.add_nodes(graph_nodes)
existing_edges_map[edge_key] = True
await graph_engine.add_edges(graph_edges)
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_name(edge.relationship_name)
edge_key = source_node_id + target_node_id + relationship_name
if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
dict(
relationship_name = generate_name(edge.relationship_name),
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
properties = json.dumps(edge.properties),
),
))
existing_edges_map[edge_key] = True
if len(data_points) > 0:
await vector_engine.create_data_points(collection_name, data_points)
if len(graph_nodes) > 0:
await graph_engine.add_nodes(graph_nodes)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks
def generate_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")
def generate_node_id(node_id: str) -> str:
return node_id.upper().replace(" ", "_").replace("'", "")
return node_id.lower().replace(" ", "_").replace("'", "")

View file

@ -0,0 +1,46 @@
import asyncio
from queue import Queue
from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
async def pipeline(data_queue):
async def queue_consumer():
while not data_queue.is_closed:
if not data_queue.empty():
yield data_queue.get()
else:
await asyncio.sleep(0.3)
async def add_one(num):
yield num + 1
async def multiply_by_two(num):
yield num * 2
tasks_run = run_tasks([
Task(queue_consumer),
Task(add_one),
Task(multiply_by_two),
])
results = [2, 4, 6, 8, 10, 12, 14, 16, 18]
index = 0
async for result in tasks_run:
print(result)
assert result == results[index]
index += 1
async def main():
data_queue = Queue()
data_queue.is_closed = False
async def queue_producer():
for i in range(0, 10):
data_queue.put(i)
await asyncio.sleep(0.1)
data_queue.is_closed = True
await asyncio.gather(pipeline(data_queue), queue_producer())
if __name__ == "__main__":
asyncio.run(main())

View file

@ -4,11 +4,13 @@ from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
async def run_tasks(tasks: [Task], data):
async def run_tasks(tasks: [Task], data = None):
if len(tasks) == 0:
yield data
return
args = [data] if data is not None else []
running_task = tasks[0]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
@ -19,7 +21,7 @@ async def run_tasks(tasks: [Task], data):
try:
results = []
async_iterator = running_task.run(data)
async_iterator = running_task.run(*args)
async for partial_result in async_iterator:
results.append(partial_result)
@ -51,7 +53,7 @@ async def run_tasks(tasks: [Task], data):
try:
results = []
for partial_result in running_task.run(data):
for partial_result in running_task.run(*args):
results.append(partial_result)
if len(results) == next_task_batch_size:
@ -79,7 +81,7 @@ async def run_tasks(tasks: [Task], data):
elif inspect.iscoroutinefunction(running_task.executable):
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
try:
task_result = await running_task.run(data)
task_result = await running_task.run(*args)
async for result in run_tasks(leftover_tasks, task_result):
yield result
@ -96,7 +98,7 @@ async def run_tasks(tasks: [Task], data):
elif inspect.isfunction(running_task.executable):
logger.info("Running function task: `%s`", running_task.executable.__name__)
try:
task_result = running_task.run(data)
task_result = running_task.run(*args)
async for result in run_tasks(leftover_tasks, task_result):
yield result

View file

@ -1,21 +1,18 @@
from typing import Union, Dict
import networkx as nx
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
async def search_adjacent(query: str) -> list[(str, str)]:
"""
Find the neighbours of a given node in the graph and return their ids and descriptions.
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): Unused in this implementation but could be used for future enhancements.
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
- query (str): The query string to filter nodes by.
Returns:
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
"""
node_id = other_param.get("node_id") if other_param else query
node_id = query
if node_id is None:
return {}
@ -23,16 +20,24 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param:
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["id"])
if exact_node is not None and "uuid" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
else:
vector_engine = get_vector_engine()
collection_name = "classification"
data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5)
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
if len(data_points) == 0:
if len(relevant_results) == 0:
return []
neighbours = await graph_engine.get_neighbours(data_points[0].id)
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
neighbours = []
for neighbour_ids in node_neighbours:
neighbours.extend(neighbour_ids)
return [node["name"] for node in neighbours]
return neighbours

View file

@ -1,18 +1,15 @@
import networkx as nx
from typing import Union
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
async def search_cypher(query: str):
"""
Use a Cypher query to search the graph and return the results.
"""
graph_config = get_graph_config()
if graph_config.graph_database_provider == "neo4j":
result = await graph.run(query)
graph_engine = await get_graph_engine()
result = await graph_engine.graph().run(query)
return result
else:
raise ValueError("Unsupported graph engine type.")
raise ValueError("Unsupported search type for the used graph engine.")

View file

@ -1,22 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_similarity(query: str) -> list[str, str]:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- query (str): The query string to filter nodes by.
Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes.
- list(chunk): A list of objects providing information about the chunks related to query.
"""
vector_engine = get_vector_engine()
similar_results = await vector_engine.search("chunks", query, limit = 5)
results = [{
"text": result.payload["text"],
"chunk_id": result.payload["chunk_id"],
} for result in similar_results]
results = [result.payload for result in similar_results]
return results

View file

@ -1,24 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.shared.data_models import ChunkSummaries
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_summary(query: str) -> list:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements.
- query (str): The query string to filter summaries by.
Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
- list[str, UUID]: A list of objects providing information about the summaries related to query.
"""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)
summaries = [{
"text": summary.payload["text"],
"chunk_id": summary.payload["chunk_id"],
} for summary in summaries_results]
summaries = [summary.payload for summary in summaries_results]
return summaries

View file

@ -1,21 +1,36 @@
import asyncio
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call
async def search_traverse(query: str):
node_id = query
rules = set()
graph_engine = await get_graph_engine()
vector_engine = get_vector_engine()
results = await vector_engine.search("classification", query_text = query, limit = 10)
exact_node = await graph_engine.extract_node(node_id)
rules = []
if exact_node is not None and "uuid" in exact_node:
edges = await graph_engine.get_edges(exact_node["uuid"])
if len(results) > 0:
for result in results:
graph_node_id = result.id
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
else:
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
edges = await graph_engine.get_edges(graph_node_id)
if len(relevant_results) > 0:
for result in relevant_results:
graph_node_id = result.id
for edge in edges:
rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
edges = await graph_engine.get_edges(graph_node_id)
return rules
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
return list(rules)

View file

@ -10,12 +10,14 @@ class Node(BaseModel):
name: str
type: str
description: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.")
class Edge(BaseModel):
"""Edge in a knowledge graph."""
source_node_id: str
target_node_id: str
relationship_name: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.")
class KnowledgeGraph(BaseModel):
"""Knowledge graph."""

View file

@ -1,4 +1,3 @@
import os
import logging
import pathlib
@ -38,21 +37,32 @@ async def main():
await cognee.cognify([dataset_name], root_node_id = "ROOT")
search_results = await cognee.search("TRAVERSE", { "query": "Text" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "Articles" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

View file

@ -33,21 +33,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

View file

@ -34,21 +34,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

View file

@ -32,21 +32,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")