Merge remote-tracking branch 'origin/main' into COG-206

Author: Boris Arzentar, 2024-08-01 14:25:28 +02:00
Commit: 2717272403
29 changed files with 558 additions and 242 deletions


@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
from .api.v1.datasets.datasets import datasets
from .api.v1.search.search import search, SearchType
from .api.v1.prune import prune
# Pipelines
from .modules import pipelines


@ -8,7 +8,7 @@ import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


@ -1,18 +1,11 @@
import asyncio
import hashlib
import logging
import uuid
from typing import Union
from fastapi_users import fastapi_users
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
get_user_permissions, get_async_session_context, fast_api_users_init
# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users
get_async_session_context, fast_api_users_init
from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
@ -62,8 +55,6 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)
if out:
async with update_status_lock:
task_status = get_task_status([dataset_name])
@ -89,9 +80,9 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
root_node_id = "ROOT"
tasks = [
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as nodes in the graph db, extracting text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach them to chunk nodes
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as nodes in the graph db, extracting text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach them to chunk nodes
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(
save_data_chunks,


@ -11,7 +11,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.utils import send_telemetry
class SearchType(Enum):
@ -63,9 +62,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List:
graph_client = await get_graph_engine()
graph = graph_client.graph
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,
@ -81,7 +77,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params, graph = graph)
task = search_func(**search_param.params)
search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently
@ -92,7 +88,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
send_telemetry("cognee.search")
return results
return results[0] if len(results) == 1 else results
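
The dispatch above maps each SearchType to its handler and gathers the scheduled coroutines. A minimal, self-contained sketch of that pattern (the handler names and return values below are illustrative stand-ins, not the project's functions):

import asyncio
from enum import Enum
from typing import Callable, Dict

class SearchType(Enum):
    ADJACENT = "ADJACENT"
    SUMMARY = "SUMMARY"

# Illustrative stand-ins for the real search functions.
async def search_adjacent(query: str): return [f"adjacent:{query}"]
async def search_summary(query: str): return [f"summary:{query}"]

async def dispatch(queries):
    search_functions: Dict[SearchType, Callable] = {
        SearchType.ADJACENT: search_adjacent,
        SearchType.SUMMARY: search_summary,
    }
    # Schedule one coroutine per query, run them concurrently, then unwrap single results.
    search_tasks = [search_functions[search_type](query) for search_type, query in queries]
    results = await asyncio.gather(*search_tasks)
    return results[0] if len(results) == 1 else results

print(asyncio.run(dispatch([(SearchType.ADJACENT, "node_a")])))  # ['adjacent:node_a']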


@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
async def graph(self):
return await self.get_session()
async def has_node(self, node_id: str) -> bool:
results = self.query(
"""
MATCH (n)
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
{"node_id": node_id}
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
node_id = node_id.replace(":", "_")
@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
RETURN COUNT(relationship) > 0 AS edge_exists
"""
edge_exists = await self.query(query)
return edge_exists
async def has_edges(self, edges):
query = """
UNWIND $edges AS edge
MATCH (a)-[r]->(b)
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
"""
try:
params = {
"edges": [{
"from_node": edge[0],
"to_node": edge[1],
"relationship_name": edge[2],
} for edge in edges],
}
results = await self.query(query, params)
return [result["edge_exists"] for result in results]
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface):
},
} for edge in edges]
results = await self.query(query, dict(edges = edges))
return results
try:
results = await self.query(query, dict(edges = edges))
return results
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error
async def get_edges(self, node_id: str):
query = """
@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface):
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN predecessor.id AS predecessor_id
"""
results = await self.query(
@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)-[r]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id
RETURN predecessor.id AS predecessor_id
"""
results = await self.query(
@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN successor.id AS successor_id
"""
results = await self.query(
@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface):
),
)
return [result["id"] for result in results]
return [result["successor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)<-[r]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id
RETURN successor.id AS successor_id
"""
results = await self.query(
@ -318,12 +370,12 @@ class Neo4jAdapter(GraphDBInterface):
)
)
return [result["id"] for result in results]
return [result["successor_id"] for result in results]
async def get_neighbours(self, node_id: str) -> list[str]:
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))
return [*results[0], *results[1]]
return [*predecessor_ids, *successor_ids]
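
A hedged usage sketch of the rewritten traversal helpers; `adapter` is assumed to be an already-configured Neo4jAdapter backed by a running Neo4j database, and the node id and edge label are illustrative:

async def explore_node(adapter, node_id: str):
    # The rewritten queries match on the node's `id` property via parameters
    # instead of interpolating identifiers into labels.
    predecessors = await adapter.get_predecessor_ids(node_id)
    successors = await adapter.get_successor_ids(node_id, edge_label = "contains_entity")
    # get_neighbours gathers predecessor and successor ids concurrently.
    neighbours = await adapter.get_neighbours(node_id)
    return predecessors, successors, neighbours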
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
query = f"""


@ -2,6 +2,7 @@
import os
import json
import asyncio
import logging
from typing import Dict, Any, List
import aiofiles
@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface):
self.filename = filename
async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id)
async def add_node(
self,
@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface):
async def get_graph(self):
return self.graph
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label)
async def has_edges(self, edges):
result = []
for (from_node, to_node, edge_label) in edges:
if await self.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label))
return result
async def add_edge(
self,
from_node: str,
@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface):
if not self.graph.has_node(node_id):
return []
neighbour_ids = list(self.graph.neighbors(node_id))
predecessor_ids, successor_ids = await asyncio.gather(
self.get_predecessor_ids(node_id),
self.get_successor_ids(node_id),
)
neighbour_ids = predecessor_ids + successor_ids
if len(neighbour_ids) == 0:
return []
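
A small, self-contained sketch of the (from_node, to_node, relationship_name) check added above, assuming the adapter stores a MultiDiGraph (which the key argument to has_edge implies); the node and edge names are illustrative:

import networkx as nx

graph = nx.MultiDiGraph()
graph.add_edge("chunk_1", "entity_jake", key = "contains_entity")
graph.add_edge("entity_jake", "person", key = "is_entity_type")

edges_to_check = [
    ("chunk_1", "entity_jake", "contains_entity"),
    ("entity_jake", "person", "is_entity_type"),
    ("chunk_1", "person", "contains_entity"),  # not present in the graph
]

# Like has_edges above, keep only the tuples whose edge actually exists.
existing = [
    (from_node, to_node, label)
    for (from_node, to_node, label) in edges_to_check
    if graph.has_edge(from_node, to_node, key = label)
]
print(existing)  # [('chunk_1', 'entity_jake', 'contains_entity'), ('entity_jake', 'person', 'is_entity_type')]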


@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface):
return [ScoredResult(
id = result["id"],
payload = result["payload"],
score = 1,
score = 0,
) for result in results.to_dict("index").values()]
async def search(
@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 10,
limit: int = 5,
with_vector: bool = False,
):
if query_text is None and query_vector is None:
@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
result_values = list(results.to_dict("index").values())
min_value = 100
max_value = 0
for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
return [ScoredResult(
id = str(result["id"]),
payload = result["payload"],
score = float(result["_distance"]),
) for result in results.to_dict("index").values()]
score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)]
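
The code above maps raw `_distance` values onto a 0-1 range where lower is better. A hedged, standalone sketch of the same idea, with a guard for the degenerate case where all distances are equal (that guard is not part of the diff):

def normalize_distances(distances: list[float]) -> list[float]:
    min_value = min(distances)
    max_value = max(distances)
    value_range = max_value - min_value
    if value_range == 0:
        # All results are equally distant; treat them all as the best score (0.0, lower is better).
        return [0.0 for _ in distances]
    return [(value - min_value) / value_range for value in distances]

print(normalize_distances([0.2, 0.5, 0.8]))  # [0.0, 0.5, 1.0]
print(normalize_distances([0.3, 0.3]))       # [0.0, 0.0]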
async def batch_search(
self,


@ -1,8 +1,7 @@
from uuid import UUID
from typing import Any, Dict
from pydantic import BaseModel
class ScoredResult(BaseModel):
id: str
score: float
score: float # Lower score is better
payload: Dict[str, Any]
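
Because lower scores are better, consumers can rank results with a plain ascending sort on `score`. A small hedged sketch, assuming ScoredResult is imported from its models module:

results = [
    ScoredResult(id = "b", score = 0.42, payload = { "text": "second" }),
    ScoredResult(id = "a", score = 0.05, payload = { "text": "first" }),
]
ranked = sorted(results, key = lambda result: result.score)
print([result.id for result in ranked])  # ['a', 'b']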


@ -1,9 +1,12 @@
import logging
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("QDrantAdapter")
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface):
points = [convert_to_qdrant_point(point) for point in data_points]
result = await client.upload_points(
collection_name = collection_name,
points = points
)
await client.close()
return result
try:
result = await client.upload_points(
collection_name = collection_name,
points = points
)
return result
except Exception as error:
logger.error("Error uploading data points to Qdrant: %s", str(error))
raise error
finally:
await client.close()
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client()
@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
limit: int = 5,
with_vector: bool = False
):
if query_text is None and query_vector is None:


@ -1,10 +1,12 @@
import asyncio
import logging
from typing import List, Optional
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter")
class WeaviateAdapter(VectorDBInterface):
name = "Weaviate"
@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface):
vector = vector
)
objects = list(map(convert_to_weaviate_data_points, data_points))
data_points = list(map(convert_to_weaviate_data_points, data_points))
collection = self.get_collection(collection_name)
with collection.batch.dynamic() as batch:
for data_row in objects:
batch.add_object(
properties = data_row.properties,
vector = data_row.vector
)
return
# return self.get_collection(collection_name).data.insert_many(objects)
try:
if len(data_points) > 1:
return collection.data.insert_many(data_points)
else:
return collection.data.insert(data_points[0])
# with collection.batch.dynamic() as batch:
# for point in data_points:
# batch.add_object(
# uuid = point.uuid,
# properties = point.properties,
# vector = point.vector
# )
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter


@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
The aim is to achieve simplicity and clarity in the knowledge graph.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist".
- Avoid using more specific terms like "Mathematician" or "Scientist"; keep those as a "profession" property instead.
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, always label it as **"Date"**.
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.
@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination"""
Adhere to the rules strictly. Non-compliance will result in termination


@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
vector_engine = get_vector_engine()
class Keyword(BaseModel):
id: str
uuid: str
text: str
chunk_id: str
document_id: str
@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_type_id),
payload = Keyword.parse_obj({
"id": str(classification_type_id),
"uuid": str(classification_type_id),
"text": classification_type_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_subtype_id),
payload = Keyword.parse_obj({
"id": str(classification_subtype_id),
"uuid": str(classification_subtype_id),
"text": classification_subtype_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
)
))
edges.append((
str(classification_type_id),
str(classification_subtype_id),
"contains",
str(classification_type_id),
"is_subtype_of",
dict(
relationship_name = "contains",
source_node_id = str(classification_type_id),


@ -1,25 +1,77 @@
import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .extract_knowledge_graph import extract_content_graph
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
class EntityNode(BaseModel):
uuid: str
name: str
type: str
description: str
created_at: datetime
updated_at: datetime
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
graph_type_node_ids = list(set(type_ids))
graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
has_collection = await vector_engine.has_collection(collection_name)
if not has_collection:
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []
for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if type_node_id not in processed_nodes:
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
processed_nodes[type_node_id] = True
if entity_node_id not in processed_nodes:
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
processed_nodes[entity_node_id] = True
graph_node_edges = [
(edge.source_node_id, edge.target_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]
existing_edges = await graph_engine.has_edges([
*type_node_edges,
*entity_node_edges,
*type_entity_edges,
*graph_node_edges,
])
existing_edges_map = {}
existing_nodes_map = {}
for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
existing_nodes_map[edge[0]] = True
graph_nodes = []
graph_edges = []
data_points = []
for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]
@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model:
for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_name(node.name)
graph_nodes.append((
node_id,
dict(
id = node_id,
chunk_id = str(chunk.chunk_id),
document_id = str(chunk.document_id),
name = node.name,
type = node.type.lower().capitalize(),
type_node_id = generate_node_id(node.type)
type_node_name = generate_name(node.type)
if node_id not in existing_nodes_map:
node_data = dict(
uuid = node_id,
name = node_name,
type = node_name,
description = node.description,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
))
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains",
dict(
relationship_name = "contains",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
graph_nodes.append((
node_id,
dict(
**node_data,
properties = json.dumps(node.properties),
)
))
type_node_id = generate_node_id(node.type)
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, node_id)),
payload = node_data,
embed_field = "name",
))
if type_node_id not in existing_type_nodes_map:
node_name = node.type.lower().capitalize()
existing_nodes_map[node_id] = True
type_node = dict(
id = type_node_id,
name = node_name,
type = node_name,
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
if edge_key not in existing_edges_map:
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains_entity",
dict(
relationship_name = "contains_entity",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
node_id,
type_node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_edges_map[edge_key] = True
if type_node_id not in existing_nodes_map:
type_node_data = dict(
uuid = type_node_id,
name = type_node_name,
type = type_node_id,
description = type_node_name,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
graph_nodes.append((type_node_id, type_node))
existing_type_nodes_map[type_node_id] = type_node
graph_nodes.append((type_node_id, dict(
**type_node_data,
properties = json.dumps(node.properties)
)))
graph_edges.append((
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, type_node_id)),
payload = type_node_data,
embed_field = "name",
))
# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
type_node_id,
node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_nodes_map[type_node_id] = True
# Add relationship that came from graphs.
for edge in graph.edges:
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = edge.relationship_name,
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
await graph_engine.add_nodes(graph_nodes)
existing_edges_map[edge_key] = True
await graph_engine.add_edges(graph_edges)
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_name(edge.relationship_name)
edge_key = source_node_id + target_node_id + relationship_name
if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
dict(
relationship_name = generate_name(edge.relationship_name),
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
properties = json.dumps(edge.properties),
),
))
existing_edges_map[edge_key] = True
if len(data_points) > 0:
await vector_engine.create_data_points(collection_name, data_points)
if len(graph_nodes) > 0:
await graph_engine.add_nodes(graph_nodes)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks
def generate_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")
def generate_node_id(node_id: str) -> str:
return node_id.upper().replace(" ", "_").replace("'", "")
return node_id.lower().replace(" ", "_").replace("'", "")
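
A self-contained sketch of the id normalization and edge de-duplication pattern used above; the entity and relationship names are illustrative only:

def generate_name(name: str) -> str:
    return name.lower().replace(" ", "_").replace("'", "")

def generate_node_id(node_id: str) -> str:
    return node_id.lower().replace(" ", "_").replace("'", "")

existing_edges_map: dict[str, bool] = {}
graph_edges = []

candidate_edges = [
    ("Jake", "Person", "is entity type"),
    ("Jake", "Person", "is entity type"),  # duplicate, will be skipped
]

for source, target, relationship in candidate_edges:
    source_id = generate_node_id(source)
    target_id = generate_node_id(target)
    relationship_name = generate_name(relationship)
    edge_key = source_id + target_id + relationship_name
    if edge_key not in existing_edges_map:
        graph_edges.append((source_id, target_id, relationship_name))
        existing_edges_map[edge_key] = True

print(graph_edges)  # [('jake', 'person', 'is_entity_type')]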


@ -1,2 +1,3 @@
from .tasks.Task import Task
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel


@ -8,27 +8,29 @@ async def main():
for i in range(num):
yield i + 1
async def add_one(num):
yield num + 1
async def multiply_by_two(nums):
async def add_one(nums):
for num in nums:
yield num * 2
yield num + 1
async def add_one_to_batched_data(num):
async def multiply_by_two(num):
yield num * 2
async def add_one_single(num):
yield num + 1
pipeline = run_tasks([
Task(number_generator, task_config = {"batch_size": 1}),
Task(number_generator),
Task(add_one, task_config = {"batch_size": 5}),
Task(multiply_by_two, task_config = {"batch_size": 1}),
Task(add_one_to_batched_data),
Task(add_one_single),
], 10)
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
print("\n")
print(result)
print("\n")
assert result == results[index]
index += 1
if __name__ == "__main__":
asyncio.run(main())
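
A hedged walk-through of the expected results above, assuming each task's output is grouped by the following task's batch_size before being forwarded:

# The generator's output is grouped into batches of five (add_one's batch_size),
# each element gets +1, is then forwarded one at a time (multiply_by_two's
# batch_size is 1), doubled, and finally gets +1 again.
batch_one = [i + 1 for i in range(1, 6)]    # [2, 3, 4, 5, 6]
batch_two = [i + 1 for i in range(6, 11)]   # [7, 8, 9, 10, 11]
expected = [(n * 2) + 1 for n in batch_one + batch_two]
print(expected)  # [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]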


@ -0,0 +1,46 @@
import asyncio
from queue import Queue
from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
async def pipeline(data_queue):
async def queue_consumer():
while not data_queue.is_closed:
if not data_queue.empty():
yield data_queue.get()
else:
await asyncio.sleep(0.3)
async def add_one(num):
yield num + 1
async def multiply_by_two(num):
yield num * 2
tasks_run = run_tasks([
Task(queue_consumer),
Task(add_one),
Task(multiply_by_two),
])
results = [2, 4, 6, 8, 10, 12, 14, 16, 18]
index = 0
async for result in tasks_run:
print(result)
assert result == results[index]
index += 1
async def main():
data_queue = Queue()
data_queue.is_closed = False
async def queue_producer():
for i in range(0, 10):
data_queue.put(i)
await asyncio.sleep(0.1)
data_queue.is_closed = True
await asyncio.gather(pipeline(data_queue), queue_producer())
if __name__ == "__main__":
asyncio.run(main())


@ -4,29 +4,30 @@ from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
async def run_tasks(tasks: [Task], data):
async def run_tasks(tasks: [Task], data = None):
if len(tasks) == 0:
yield data
return
args = [data] if data is not None else []
running_task = tasks[0]
batch_size = running_task.task_config["batch_size"]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
# next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable):
logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
logger.info("Running async generator task: `%s`", running_task.executable.__name__)
try:
results = []
async_iterator = running_task.run(data)
async_iterator = running_task.run(*args)
async for partial_result in async_iterator:
results.append(partial_result)
if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result
results = []
@ -37,7 +38,7 @@ async def run_tasks(tasks: [Task], data):
results = []
logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running async generator task: `%s`\n%s\n",
@ -48,15 +49,15 @@ async def run_tasks(tasks: [Task], data):
raise error
elif inspect.isgeneratorfunction(running_task.executable):
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Running generator task: `%s`", running_task.executable.__name__)
try:
results = []
for partial_result in running_task.run(data):
for partial_result in running_task.run(*args):
results.append(partial_result)
if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result
results = []
@ -67,7 +68,7 @@ async def run_tasks(tasks: [Task], data):
results = []
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Finished generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running generator task: `%s`\n%s\n",
@ -78,13 +79,35 @@ async def run_tasks(tasks: [Task], data):
raise error
elif inspect.iscoroutinefunction(running_task.executable):
task_result = await running_task.run(data)
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
try:
task_result = await running_task.run(*args)
async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result
logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running coroutine task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)
elif inspect.isfunction(running_task.executable):
task_result = running_task.run(data)
logger.info("Running function task: `%s`", running_task.executable.__name__)
try:
task_result = running_task.run(*args)
async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result
logger.info("Finished function task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running function task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)
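
A stripped-down, hedged sketch of the recursion above; the project's implementation additionally dispatches on async generators, sync generators, coroutines, and plain functions, and logs progress:

import asyncio

async def run_batched(tasks, data = None, batch_sizes = None):
    # Simplified rule: group the current task's output by the next task's batch size,
    # then forward each batch down the remaining tasks.
    if len(tasks) == 0:
        yield data
        return
    batch_sizes = batch_sizes or [1] * len(tasks)
    next_batch_size = batch_sizes[1] if len(tasks) > 1 else 1
    results = []
    async for partial_result in tasks[0](data):
        results.append(partial_result)
        if len(results) == next_batch_size:
            payload = results[0] if next_batch_size == 1 else results
            async for result in run_batched(tasks[1:], payload, batch_sizes[1:]):
                yield result
            results = []

async def numbers(_):
    for i in range(4):
        yield i + 1

async def add_all(nums):
    yield sum(nums)

async def main():
    async for value in run_batched([numbers, add_all], batch_sizes = [1, 2]):
        print(value)  # 3, then 7 (sums of the batches [1, 2] and [3, 4])

asyncio.run(main())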


@ -1,21 +1,18 @@
from typing import Union, Dict
import networkx as nx
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
async def search_adjacent(query: str) -> list[(str, str)]:
"""
Find the neighbours of a given node in the graph and return their ids and descriptions.
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): Unused in this implementation but could be used for future enhancements.
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
- query (str): The query string to filter nodes by.
Returns:
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
"""
node_id = other_param.get("node_id") if other_param else query
node_id = query
if node_id is None:
return {}
@ -23,16 +20,24 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param:
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["id"])
if exact_node is not None and "uuid" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
else:
vector_engine = get_vector_engine()
collection_name = "classification"
data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5)
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
if len(data_points) == 0:
if len(relevant_results) == 0:
return []
neighbours = await graph_engine.get_neighbours(data_points[0].id)
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
neighbours = []
for neighbour_ids in node_neighbours:
neighbours.extend(neighbour_ids)
return [node["name"] for node in neighbours]
return neighbours
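
The relevance cut above keeps only vector hits whose normalized score is under 0.5 (lower is better) and caps them at five. A small illustrative sketch of that filter, using a stand-in result type:

from dataclasses import dataclass

@dataclass
class Result:
    id: str
    score: float

results = [Result("a", 0.1), Result("b", 0.45), Result("c", 0.7), Result("d", 0.2)]
relevant_results = [result for result in results if result.score < 0.5][:5]
print([result.id for result in relevant_results])  # ['a', 'b', 'd']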


@ -1,18 +1,15 @@
import networkx as nx
from typing import Union
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
async def search_cypher(query: str):
"""
Use a Cypher query to search the graph and return the results.
"""
graph_config = get_graph_config()
if graph_config.graph_database_provider == "neo4j":
result = await graph.run(query)
graph_engine = await get_graph_engine()
result = await graph_engine.graph().run(query)
return result
else:
raise ValueError("Unsupported graph engine type.")
raise ValueError("Unsupported search type for the used graph engine.")


@ -1,22 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_similarity(query: str) -> list[str, str]:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- query (str): The query string to filter nodes by.
Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes.
- list(chunk): A list of objects providing information about the chunks related to query.
"""
vector_engine = get_vector_engine()
similar_results = await vector_engine.search("chunks", query, limit = 5)
results = [{
"text": result.payload["text"],
"chunk_id": result.payload["chunk_id"],
} for result in similar_results]
results = [result.payload for result in similar_results]
return results


@ -1,24 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.shared.data_models import ChunkSummaries
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_summary(query: str) -> list:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements.
- query (str): The query string to filter summaries by.
Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
- list[str, UUID]: A list of objects providing information about the summaries related to query.
"""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)
summaries = [{
"text": summary.payload["text"],
"chunk_id": summary.payload["chunk_id"],
} for summary in summaries_results]
summaries = [summary.payload for summary in summaries_results]
return summaries


@ -1,21 +1,36 @@
import asyncio
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call
async def search_traverse(query: str):
node_id = query
rules = set()
graph_engine = await get_graph_engine()
vector_engine = get_vector_engine()
results = await vector_engine.search("classification", query_text = query, limit = 10)
exact_node = await graph_engine.extract_node(node_id)
rules = []
if exact_node is not None and "uuid" in exact_node:
edges = await graph_engine.get_edges(exact_node["uuid"])
if len(results) > 0:
for result in results:
graph_node_id = result.id
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
else:
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
edges = await graph_engine.get_edges(graph_node_id)
if len(relevant_results) > 0:
for result in relevant_results:
graph_node_id = result.id
for edge in edges:
rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
edges = await graph_engine.get_edges(graph_node_id)
return rules
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
return list(rules)

cognee/pipelines.py (new file)

@ -0,0 +1,5 @@
# Don't add any more code here, this file is used only for the purpose
# of enabling imports from `cognee.pipelines` module.
# `from cognee.pipelines import Task` for example.
from .modules.pipelines import *
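
With the re-export module above in place, pipeline primitives become importable from the package root. A hedged usage sketch, assuming cognee is installed and that Task forwards its positional argument to the wrapped callable:

import asyncio
from cognee.pipelines import Task, run_tasks

async def to_upper(text):
    yield text.upper()

async def main():
    async for result in run_tasks([Task(to_upper)], "hello"):
        print(result)  # HELLO

asyncio.run(main())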


@ -10,12 +10,14 @@ class Node(BaseModel):
name: str
type: str
description: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.")
class Edge(BaseModel):
"""Edge in a knowledge graph."""
source_node_id: str
target_node_id: str
relationship_name: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.")
class KnowledgeGraph(BaseModel):
"""Knowledge graph."""


@ -1,4 +1,3 @@
import os
import logging
import pathlib
@ -38,21 +37,32 @@ async def main():
await cognee.cognify([dataset_name], root_node_id = "ROOT")
search_results = await cognee.search("TRAVERSE", { "query": "Text" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "Articles" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")


@ -33,21 +33,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")


@ -34,21 +34,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")


@ -32,21 +32,32 @@ async def main():
await cognee.cognify([dataset_name])
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\Large language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")