Merge remote-tracking branch 'origin/main' into COG-206

Boris Arzentar 2024-08-01 14:25:28 +02:00
commit 2717272403
29 changed files with 558 additions and 242 deletions

View file

@@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
 from .api.v1.datasets.datasets import datasets
 from .api.v1.search.search import search, SearchType
 from .api.v1.prune import prune
+
+# Pipelines
+from .modules import pipelines

View file

@@ -8,7 +8,7 @@ import logging
 import sentry_sdk
 from typing import Dict, Any, List, Union, Optional, Literal
 from typing_extensions import Annotated
-from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
+from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
 from fastapi.responses import JSONResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel

View file

@@ -1,18 +1,11 @@
 import asyncio
-import hashlib
 import logging
 import uuid
 from typing import Union
-from fastapi_users import fastapi_users
-from sqlalchemy.ext.asyncio import AsyncSession
 from cognee.infrastructure.databases.graph import get_graph_config
-from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
 from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
-    get_user_permissions, get_async_session_context, fast_api_users_init
-# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
-# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users
+    get_async_session_context, fast_api_users_init
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.infrastructure.databases.relational.config import get_relationaldb_config
 from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument

@@ -62,8 +55,6 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = None
             out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)
             if out:
                 async with update_status_lock:
                     task_status = get_task_status([dataset_name])

@@ -89,9 +80,9 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = None
         root_node_id = "ROOT"

     tasks = [
-        Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as nodes in graph db, extract text chunks based on the document type
-        Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
-        Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
+        Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as nodes in graph db, extract text chunks based on the document type
+        Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
+        Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
         Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
         Task(
             save_data_chunks,

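The re-ordered task_config above is the interesting part: run_tasks (see cognee/modules/pipelines/operations/run_tasks.py later in this diff) now reads batch_size from the next task in the list, so a task's own config declares how many results its predecessor should accumulate before invoking it. A minimal sketch of that behaviour, assuming the Task/run_tasks API added in this commit; the three task functions are invented placeholders, not the real cognify tasks:

    import asyncio
    from cognee.modules.pipelines.operations.run_tasks import run_tasks
    from cognee.modules.pipelines.tasks.Task import Task

    async def produce(count):
        for i in range(count):
            yield i

    async def double_batch(nums):
        # Receives lists of 3 items, because of this task's own batch_size below.
        for num in nums:
            yield num * 2

    async def stringify(num):
        yield f"value: {num}"

    async def main():
        pipeline = run_tasks([
            Task(produce),
            Task(double_batch, task_config = { "batch_size": 3 }),
            Task(stringify),
        ], 9)

        async for result in pipeline:
            print(result)  # value: 0, value: 2, ... value: 16

    asyncio.run(main())
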
View file

@@ -11,7 +11,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
 from cognee.modules.search.vector.search_traverse import search_traverse
 from cognee.modules.search.graph.search_summary import search_summary
 from cognee.modules.search.graph.search_similarity import search_similarity
-from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.shared.utils import send_telemetry

 class SearchType(Enum):

@@ -63,9 +62,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
 async def specific_search(query_params: List[SearchParameters]) -> List:
-    graph_client = await get_graph_engine()
-    graph = graph_client.graph
-
     search_functions: Dict[SearchType, Callable] = {
         SearchType.ADJACENT: search_adjacent,
         SearchType.SUMMARY: search_summary,

@@ -81,7 +77,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
         search_func = search_functions.get(search_param.search_type)
         if search_func:
             # Schedule the coroutine for execution and store the task
-            task = search_func(**search_param.params, graph = graph)
+            task = search_func(**search_param.params)
             search_tasks.append(task)

     # Use asyncio.gather to run all scheduled tasks concurrently

@@ -92,7 +88,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
     send_telemetry("cognee.search")

-    return results
+    return results[0] if len(results) == 1 else results

View file

@@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
     async def graph(self):
         return await self.get_session()

+    async def has_node(self, node_id: str) -> bool:
+        results = await self.query(
+            """
+                MATCH (n)
+                WHERE n.id = $node_id
+                RETURN COUNT(n) > 0 AS node_exists
+            """,
+            {"node_id": node_id}
+        )
+        return results[0]["node_exists"] if len(results) > 0 else False
+
     async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
         node_id = node_id.replace(":", "_")
@@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):
         return await self.query(query, params)

+    async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
+        query = f"""
+            MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
+            RETURN COUNT(relationship) > 0 AS edge_exists
+        """
+
+        edge_exists = await self.query(query)
+
+        return edge_exists[0]["edge_exists"] if len(edge_exists) > 0 else False
+
+    async def has_edges(self, edges):
+        query = """
+            UNWIND $edges AS edge
+            MATCH (a)-[r]->(b)
+            WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
+            RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
+        """
+
+        try:
+            params = {
+                "edges": [{
+                    "from_node": edge[0],
+                    "to_node": edge[1],
+                    "relationship_name": edge[2],
+                } for edge in edges],
+            }
+
+            results = await self.query(query, params)
+
+            return [result["edge_exists"] for result in results]
+        except Neo4jError as error:
+            logger.error("Neo4j query error: %s", error, exc_info = True)
+            raise error
+
     async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
         serialized_properties = self.serialize_properties(edge_properties)
         from_node = from_node.replace(":", "_")

@@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface):
             },
         } for edge in edges]

-        results = await self.query(query, dict(edges = edges))
-        return results
+        try:
+            results = await self.query(query, dict(edges = edges))
+            return results
+        except Neo4jError as error:
+            logger.error("Neo4j query error: %s", error, exc_info = True)
+            raise error

     async def get_edges(self, node_id: str):
         query = """
@@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface):
     async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
         if edge_label is not None:
             query = """
-                MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
-                RETURN predecessor.id AS id
+                MATCH (node)<-[r]-(predecessor)
+                WHERE node.id = $node_id AND type(r) = $edge_label
+                RETURN predecessor.id AS predecessor_id
             """

             results = await self.query(

@@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface):
                 )
             )

-            return [result["id"] for result in results]
+            return [result["predecessor_id"] for result in results]
         else:
             query = """
-                MATCH (node:`{node_id}`)-[r]->(predecessor)
-                RETURN predecessor.id AS id
+                MATCH (node)<-[r]-(predecessor)
+                WHERE node.id = $node_id
+                RETURN predecessor.id AS predecessor_id
             """

             results = await self.query(

@@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface):
                 )
             )

-            return [result["id"] for result in results]
+            return [result["predecessor_id"] for result in results]

     async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
         if edge_label is not None:
             query = """
-                MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
-                RETURN successor.id AS id
+                MATCH (node)-[r]->(successor)
+                WHERE node.id = $node_id AND type(r) = $edge_label
+                RETURN successor.id AS successor_id
             """

             results = await self.query(

@@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface):
                 ),
             )

-            return [result["id"] for result in results]
+            return [result["successor_id"] for result in results]
         else:
             query = """
-                MATCH (node:`{node_id}`)<-[r]-(successor)
-                RETURN successor.id AS id
+                MATCH (node)-[r]->(successor)
+                WHERE node.id = $node_id
+                RETURN successor.id AS successor_id
             """

             results = await self.query(

@@ -318,12 +370,12 @@ class Neo4jAdapter(GraphDBInterface):
                 )
             )

-            return [result["id"] for result in results]
+            return [result["successor_id"] for result in results]

     async def get_neighbours(self, node_id: str) -> list[str]:
-        results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
-        return [*results[0], *results[1]]
+        predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))
+        return [*predecessor_ids, *successor_ids]

     async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
         query = f"""

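One detail worth noting in the traversal rewrite: the old queries interpolated the node id and edge label into the Cypher text as backticked labels, while the new ones pass them as $node_id/$edge_label parameters, which sidesteps escaping problems and lets Neo4j reuse cached query plans. The same pattern with the official neo4j driver, as a hedged sketch (connection details are placeholders):

    from neo4j import AsyncGraphDatabase

    async def get_predecessor_ids(node_id: str, edge_label: str) -> list[str]:
        driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth = ("neo4j", "password"))

        query = """
        MATCH (node)<-[r]-(predecessor)
        WHERE node.id = $node_id AND type(r) = $edge_label
        RETURN predecessor.id AS predecessor_id
        """

        async with driver.session() as session:
            result = await session.run(query, node_id = node_id, edge_label = edge_label)
            predecessor_ids = [record["predecessor_id"] async for record in result]

        await driver.close()
        return predecessor_ids
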
View file

@@ -2,6 +2,7 @@
 import os
 import json
+import asyncio
 import logging
 from typing import Dict, Any, List
 import aiofiles

@@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface):
         self.filename = filename

+    async def has_node(self, node_id: str) -> bool:
+        return self.graph.has_node(node_id)

     async def add_node(
         self,

@@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface):
     async def get_graph(self):
         return self.graph

+    async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
+        return self.graph.has_edge(from_node, to_node, key = edge_label)
+
+    async def has_edges(self, edges):
+        result = []
+
+        for (from_node, to_node, edge_label) in edges:
+            if await self.has_edge(from_node, to_node, edge_label):
+                result.append((from_node, to_node, edge_label))
+
+        return result
+
     async def add_edge(
         self,
         from_node: str,

@@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface):
         if not self.graph.has_node(node_id):
             return []

-        neighbour_ids = list(self.graph.neighbors(node_id))
+        predecessor_ids, successor_ids = await asyncio.gather(
+            self.get_predecessor_ids(node_id),
+            self.get_successor_ids(node_id),
+        )
+        neighbour_ids = predecessor_ids + successor_ids

         if len(neighbour_ids) == 0:
             return []

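Unlike the Neo4j adapter's has_edges, which returns one boolean per candidate edge, this NetworkX version returns the subset of edges that already exist. The underlying primitive is MultiDiGraph.has_edge with the edge label as the key, which can be checked standalone:

    import networkx as nx

    graph = nx.MultiDiGraph()
    graph.add_edge("chunk_1", "person", key = "contains_entity_type")

    candidate_edges = [
        ("chunk_1", "person", "contains_entity_type"),
        ("chunk_1", "city", "contains_entity_type"),
    ]

    # Mirrors NetworkXAdapter.has_edges: keep only edges already present.
    existing = [edge for edge in candidate_edges if graph.has_edge(edge[0], edge[1], key = edge[2])]
    print(existing)  # [('chunk_1', 'person', 'contains_entity_type')]
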
View file

@@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface):
         return [ScoredResult(
             id = result["id"],
             payload = result["payload"],
-            score = 1,
+            score = 0,
         ) for result in results.to_dict("index").values()]

     async def search(

@@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface):
         collection_name: str,
         query_text: str = None,
         query_vector: List[float] = None,
-        limit: int = 10,
+        limit: int = 5,
         with_vector: bool = False,
     ):
         if query_text is None and query_vector is None:

@@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface):
         results = await collection.vector_search(query_vector).limit(limit).to_pandas()

+        result_values = list(results.to_dict("index").values())
+
+        min_value = 100
+        max_value = 0
+
+        for result in result_values:
+            value = float(result["_distance"])
+            if value > max_value:
+                max_value = value
+            if value < min_value:
+                min_value = value
+
+        # Guard against division by zero when all distances are equal.
+        normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) if max_value > min_value else 0 for result in result_values]
+
         return [ScoredResult(
             id = str(result["id"]),
             payload = result["payload"],
-            score = float(result["_distance"]),
-        ) for result in results.to_dict("index").values()]
+            score = normalized_values[value_index],
+        ) for value_index, result in enumerate(result_values)]

     async def batch_search(
         self,

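The new search path min-max normalizes LanceDB's raw _distance values so that 0 is the best match and 1 the worst, which is what makes the exact-retrieval score = 0 above and the result.score < 0.5 filters in the search modules consistent. The arithmetic, on invented distances:

    distances = [0.12, 0.48, 0.60]  # hypothetical _distance values from LanceDB

    min_value, max_value = min(distances), max(distances)
    value_range = max_value - min_value or 1  # guard: equal distances would divide by zero

    normalized = [(distance - min_value) / value_range for distance in distances]
    print(normalized)  # [0.0, 0.75, 1.0]
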
View file

@@ -1,8 +1,7 @@
-from uuid import UUID
 from typing import Any, Dict
 from pydantic import BaseModel

 class ScoredResult(BaseModel):
     id: str
-    score: float
+    score: float # Lower score is better
     payload: Dict[str, Any]

View file

@@ -1,9 +1,12 @@
+import logging
 from typing import List, Dict, Optional
 from qdrant_client import AsyncQdrantClient, models
 from ..vector_db_interface import VectorDBInterface
 from ..models.DataPoint import DataPoint
 from ..embeddings.EmbeddingEngine import EmbeddingEngine

+logger = logging.getLogger("QDrantAdapter")
+
 # class CollectionConfig(BaseModel, extra = "forbid"):
 #     vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
 #     hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")

@@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface):
         points = [convert_to_qdrant_point(point) for point in data_points]

-        result = await client.upload_points(
-            collection_name = collection_name,
-            points = points
-        )
-
-        await client.close()
-
-        return result
+        try:
+            result = await client.upload_points(
+                collection_name = collection_name,
+                points = points
+            )
+            return result
+        except Exception as error:
+            logger.error("Error uploading data points to Qdrant: %s", str(error))
+            raise error
+        finally:
+            await client.close()

     async def retrieve(self, collection_name: str, data_point_ids: list[str]):
         client = self.get_qdrant_client()

@@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface):
         collection_name: str,
         query_text: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
-        limit: int = None,
+        limit: int = 5,
         with_vector: bool = False
     ):
         if query_text is None and query_vector is None:

View file

@@ -1,10 +1,12 @@
 import asyncio
+import logging
 from typing import List, Optional
 from ..vector_db_interface import VectorDBInterface
 from ..models.DataPoint import DataPoint
 from ..models.ScoredResult import ScoredResult
 from ..embeddings.EmbeddingEngine import EmbeddingEngine

+logger = logging.getLogger("WeaviateAdapter")
+
 class WeaviateAdapter(VectorDBInterface):
     name = "Weaviate"

@@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface):
                 vector = vector
             )

-        objects = list(map(convert_to_weaviate_data_points, data_points))
+        data_points = list(map(convert_to_weaviate_data_points, data_points))

         collection = self.get_collection(collection_name)

-        with collection.batch.dynamic() as batch:
-            for data_row in objects:
-                batch.add_object(
-                    properties = data_row.properties,
-                    vector = data_row.vector
-                )
-
-        return
-        # return self.get_collection(collection_name).data.insert_many(objects)
+        try:
+            if len(data_points) > 1:
+                return collection.data.insert_many(data_points)
+            else:
+                return collection.data.insert(data_points[0])
+            # with collection.batch.dynamic() as batch:
+            #     for point in data_points:
+            #         batch.add_object(
+            #             uuid = point.uuid,
+            #             properties = point.properties,
+            #             vector = point.vector
+            #         )
+        except Exception as error:
+            logger.error("Error creating data points: %s", str(error))
+            raise error

     async def retrieve(self, collection_name: str, data_point_ids: list[str]):
         from weaviate.classes.query import Filter

View file

@@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f
 **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
 **Edges** represent relationships between concepts. They're akin to Wikipedia links.

-The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
+The aim is to achieve simplicity and clarity in the knowledge graph.

 # 1. Labeling Nodes
 **Consistency**: Ensure you use basic or elementary types for node labels.
 - For example, when you identify an entity representing a person, always label it as **"Person"**.
-- Avoid using more specific terms like "Mathematician" or "Scientist".
+- Avoid using more specific terms like "Mathematician" or "Scientist"; keep those as a "profession" property.
 - Don't use too generic terms like "Entity".
 **Node IDs**: Never utilize integers as node IDs.
 - Node IDs should be names or human-readable identifiers found in the text.
 # 2. Handling Numerical Data and Dates
-- For example, when you identify an entity representing a date, always label it as **"Date"**.
+- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
 - Extract the date in the format "YYYY-MM-DD"
 - If not possible to extract the whole date, extract month or year, or both if available.
 - **Property Format**: Properties must be in a key-value format.

@@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a
 always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
 Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
 # 4. Strict Compliance
-Adhere to the rules strictly. Non-compliance will result in termination"""
+Adhere to the rules strictly. Non-compliance will result in termination

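Under these rules, a sentence like "Marie Curie, a physicist, was born on 1867-11-07" should produce a "Person" node carrying the profession as a property, plus a "Date" node, rather than a "Physicist" node. A hypothetical rendering, shaped like the Node/Edge models this commit extends with properties:

    example_graph = {
        "nodes": [
            {"id": "Marie Curie", "name": "Marie Curie", "type": "Person",
             "description": "Physicist mentioned in the text.",
             "properties": {"profession": "physicist"}},
            {"id": "1867-11-07", "name": "1867-11-07", "type": "Date",
             "description": "Marie Curie's birth date."},
        ],
        "edges": [
            {"source_node_id": "Marie Curie", "target_node_id": "1867-11-07",
             "relationship_name": "born_on"},
        ],
    }
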
View file

@@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
     vector_engine = get_vector_engine()

     class Keyword(BaseModel):
-        id: str
+        uuid: str
         text: str
         chunk_id: str
         document_id: str

@@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
             DataPoint[Keyword](
                 id = str(classification_type_id),
                 payload = Keyword.parse_obj({
-                    "id": str(classification_type_id),
+                    "uuid": str(classification_type_id),
                     "text": classification_type_label,
                     "chunk_id": str(data_chunk.chunk_id),
                     "document_id": str(data_chunk.document_id),

@@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
             DataPoint[Keyword](
                 id = str(classification_subtype_id),
                 payload = Keyword.parse_obj({
-                    "id": str(classification_subtype_id),
+                    "uuid": str(classification_subtype_id),
                     "text": classification_subtype_label,
                     "chunk_id": str(data_chunk.chunk_id),
                     "document_id": str(data_chunk.document_id),

@@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
             )
         ))

         edges.append((
-            str(classification_type_id),
-            str(classification_subtype_id),
-            "contains",
+            str(classification_subtype_id),
+            str(classification_type_id),
+            "is_subtype_of",
             dict(
                 relationship_name = "contains",
                 source_node_id = str(classification_type_id),

View file

@@ -1,25 +1,77 @@
+import json
 import asyncio
+from uuid import uuid5, NAMESPACE_OID
 from datetime import datetime
 from typing import Type
 from pydantic import BaseModel
 from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
 from ...processing.chunk_types.DocumentChunk import DocumentChunk
 from .extract_knowledge_graph import extract_content_graph

-async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
+class EntityNode(BaseModel):
+    uuid: str
+    name: str
+    type: str
+    description: str
+    created_at: datetime
+    updated_at: datetime
+
+async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
     chunk_graphs = await asyncio.gather(
         *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
     )

+    vector_engine = get_vector_engine()
     graph_engine = await get_graph_engine()

-    type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
-    graph_type_node_ids = list(set(type_ids))
-    graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
-    existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
+    has_collection = await vector_engine.has_collection(collection_name)
+
+    if not has_collection:
+        await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
+
+    processed_nodes = {}
+    type_node_edges = []
+    entity_node_edges = []
+    type_entity_edges = []
+
+    for (chunk_index, chunk) in enumerate(data_chunks):
+        chunk_graph = chunk_graphs[chunk_index]
+        for node in chunk_graph.nodes:
+            type_node_id = generate_node_id(node.type)
+            entity_node_id = generate_node_id(node.id)
+
+            if type_node_id not in processed_nodes:
+                type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
+                processed_nodes[type_node_id] = True
+
+            if entity_node_id not in processed_nodes:
+                entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
+                type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
+                processed_nodes[entity_node_id] = True
+
+        graph_node_edges = [
+            (edge.source_node_id, edge.target_node_id, edge.relationship_name) \
+                for edge in chunk_graph.edges
+        ]
+
+    existing_edges = await graph_engine.has_edges([
+        *type_node_edges,
+        *entity_node_edges,
+        *type_entity_edges,
+        *graph_node_edges,
+    ])
+
+    existing_edges_map = {}
+    existing_nodes_map = {}
+
+    for edge in existing_edges:
+        existing_edges_map[edge[0] + edge[1] + edge[2]] = True
+        existing_nodes_map[edge[0]] = True

     graph_nodes = []
     graph_edges = []
+    data_points = []

     for (chunk_index, chunk) in enumerate(data_chunks):
         graph = chunk_graphs[chunk_index]

@@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model:
         for node in graph.nodes:
             node_id = generate_node_id(node.id)
+            node_name = generate_name(node.name)

-            graph_nodes.append((
-                node_id,
-                dict(
-                    id = node_id,
-                    chunk_id = str(chunk.chunk_id),
-                    document_id = str(chunk.document_id),
-                    name = node.name,
-                    type = node.type.lower().capitalize(),
-                    description = node.description,
-                    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                    updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                )
-            ))
-
-            graph_edges.append((
-                str(chunk.chunk_id),
-                node_id,
-                "contains",
-                dict(
-                    relationship_name = "contains",
-                    source_node_id = str(chunk.chunk_id),
-                    target_node_id = node_id,
-                ),
-            ))
-
             type_node_id = generate_node_id(node.type)
+            type_node_name = generate_name(node.type)

-            if type_node_id not in existing_type_nodes_map:
-                node_name = node.type.lower().capitalize()
-
-                type_node = dict(
-                    id = type_node_id,
-                    name = node_name,
-                    type = node_name,
-                    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                    updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                )
-
-                graph_nodes.append((type_node_id, type_node))
-                existing_type_nodes_map[type_node_id] = type_node
-
-            graph_edges.append((
-                str(chunk.chunk_id),
-                type_node_id,
-                "contains_entity_type",
-                dict(
-                    relationship_name = "contains_entity_type",
-                    source_node_id = str(chunk.chunk_id),
-                    target_node_id = type_node_id,
-                ),
-            ))
-
-            # Add relationship between entity type and entity itself: "Jake is Person"
-            graph_edges.append((
-                type_node_id,
-                node_id,
-                "is_entity_type",
-                dict(
-                    relationship_name = "is_entity_type",
-                    source_node_id = type_node_id,
-                    target_node_id = node_id,
-                ),
-            ))
+            if node_id not in existing_nodes_map:
+                node_data = dict(
+                    uuid = node_id,
+                    name = node_name,
+                    type = node_name,
+                    description = node.description,
+                    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                    updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                )
+
+                graph_nodes.append((
+                    node_id,
+                    dict(
+                        **node_data,
+                        properties = json.dumps(node.properties),
+                    )
+                ))
+
+                data_points.append(DataPoint[EntityNode](
+                    id = str(uuid5(NAMESPACE_OID, node_id)),
+                    payload = node_data,
+                    embed_field = "name",
+                ))
+
+                existing_nodes_map[node_id] = True
+
+            edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
+
+            if edge_key not in existing_edges_map:
+                graph_edges.append((
+                    str(chunk.chunk_id),
+                    node_id,
+                    "contains_entity",
+                    dict(
+                        relationship_name = "contains_entity",
+                        source_node_id = str(chunk.chunk_id),
+                        target_node_id = node_id,
+                    ),
+                ))
+
+                # Add relationship between entity type and entity itself: "Jake is Person"
+                graph_edges.append((
+                    node_id,
+                    type_node_id,
+                    "is_entity_type",
+                    dict(
+                        relationship_name = "is_entity_type",
+                        source_node_id = type_node_id,
+                        target_node_id = node_id,
+                    ),
+                ))
+
+                existing_edges_map[edge_key] = True
+
+            if type_node_id not in existing_nodes_map:
+                type_node_data = dict(
+                    uuid = type_node_id,
+                    name = type_node_name,
+                    type = type_node_id,
+                    description = type_node_name,
+                    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                    updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                )
+
+                graph_nodes.append((type_node_id, dict(
+                    **type_node_data,
+                    properties = json.dumps(node.properties)
+                )))
+
+                data_points.append(DataPoint[EntityNode](
+                    id = str(uuid5(NAMESPACE_OID, type_node_id)),
+                    payload = type_node_data,
+                    embed_field = "name",
+                ))
+
+                existing_nodes_map[type_node_id] = True
+
+            edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
+
+            if edge_key not in existing_edges_map:
+                graph_edges.append((
+                    str(chunk.chunk_id),
+                    type_node_id,
+                    "contains_entity_type",
+                    dict(
+                        relationship_name = "contains_entity_type",
+                        source_node_id = str(chunk.chunk_id),
+                        target_node_id = type_node_id,
+                    ),
+                ))
+
+                existing_edges_map[edge_key] = True

         # Add relationship that came from graphs.
         for edge in graph.edges:
-            graph_edges.append((
-                generate_node_id(edge.source_node_id),
-                generate_node_id(edge.target_node_id),
-                edge.relationship_name,
-                dict(
-                    relationship_name = edge.relationship_name,
-                    source_node_id = generate_node_id(edge.source_node_id),
-                    target_node_id = generate_node_id(edge.target_node_id),
-                ),
-            ))
+            source_node_id = generate_node_id(edge.source_node_id)
+            target_node_id = generate_node_id(edge.target_node_id)
+            relationship_name = generate_name(edge.relationship_name)
+            edge_key = source_node_id + target_node_id + relationship_name
+
+            if edge_key not in existing_edges_map:
+                graph_edges.append((
+                    generate_node_id(edge.source_node_id),
+                    generate_node_id(edge.target_node_id),
+                    edge.relationship_name,
+                    dict(
+                        relationship_name = generate_name(edge.relationship_name),
+                        source_node_id = generate_node_id(edge.source_node_id),
+                        target_node_id = generate_node_id(edge.target_node_id),
+                        properties = json.dumps(edge.properties),
+                    ),
+                ))
+
+                existing_edges_map[edge_key] = True

-    await graph_engine.add_nodes(graph_nodes)
-    await graph_engine.add_edges(graph_edges)
+    if len(data_points) > 0:
+        await vector_engine.create_data_points(collection_name, data_points)
+
+    if len(graph_nodes) > 0:
+        await graph_engine.add_nodes(graph_nodes)
+
+    if len(graph_edges) > 0:
+        await graph_engine.add_edges(graph_edges)

     return data_chunks

+def generate_name(name: str) -> str:
+    return name.lower().replace(" ", "_").replace("'", "")
+
 def generate_node_id(node_id: str) -> str:
-    return node_id.upper().replace(" ", "_").replace("'", "")
+    return node_id.lower().replace(" ", "_").replace("'", "")

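Two properties of the id scheme above are worth spelling out: generate_node_id now lower-cases instead of upper-casing (so graph ids line up with the lower-cased names produced by generate_name), and the vector-store id is derived deterministically from the graph id via uuid5, so re-running cognify over the same text maps an entity back to the same DataPoint instead of creating duplicates. A quick check of the determinism:

    from uuid import uuid5, NAMESPACE_OID

    def generate_node_id(node_id: str) -> str:
        return node_id.lower().replace(" ", "_").replace("'", "")

    node_id = generate_node_id("Marie Curie")
    print(node_id)  # marie_curie
    assert uuid5(NAMESPACE_OID, node_id) == uuid5(NAMESPACE_OID, "marie_curie")  # stable across runs
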
View file

@@ -1,2 +1,3 @@
+from .tasks.Task import Task
 from .operations.run_tasks import run_tasks
 from .operations.run_parallel import run_tasks_parallel

View file

@@ -8,27 +8,29 @@ async def main():
         for i in range(num):
             yield i + 1

-    async def add_one(num):
-        yield num + 1
-
-    async def multiply_by_two(nums):
+    async def add_one(nums):
         for num in nums:
-            yield num * 2
+            yield num + 1

-    async def add_one_to_batched_data(num):
+    async def multiply_by_two(num):
+        yield num * 2
+
+    async def add_one_single(num):
         yield num + 1

     pipeline = run_tasks([
-        Task(number_generator, task_config = {"batch_size": 1}),
+        Task(number_generator),
         Task(add_one, task_config = {"batch_size": 5}),
         Task(multiply_by_two, task_config = {"batch_size": 1}),
-        Task(add_one_to_batched_data),
+        Task(add_one_single),
     ], 10)

+    results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
+    index = 0
+
     async for result in pipeline:
-        print("\n")
         print(result)
-        print("\n")
+        assert result == results[index]
+        index += 1

 if __name__ == "__main__":
     asyncio.run(main())

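The asserted values follow directly from the pipeline: each n in 1..10 passes through add_one, multiply_by_two, then the final add_one, i.e. (n + 1) * 2 + 1:

    expected = [(n + 1) * 2 + 1 for n in range(1, 11)]
    assert expected == [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
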
View file

@@ -0,0 +1,46 @@
+import asyncio
+from queue import Queue
+
+from cognee.modules.pipelines.operations.run_tasks import run_tasks
+from cognee.modules.pipelines.tasks.Task import Task
+
+async def pipeline(data_queue):
+    async def queue_consumer():
+        while not data_queue.is_closed:
+            if not data_queue.empty():
+                yield data_queue.get()
+            else:
+                await asyncio.sleep(0.3)
+
+    async def add_one(num):
+        yield num + 1
+
+    async def multiply_by_two(num):
+        yield num * 2
+
+    tasks_run = run_tasks([
+        Task(queue_consumer),
+        Task(add_one),
+        Task(multiply_by_two),
+    ])
+
+    results = [2, 4, 6, 8, 10, 12, 14, 16, 18]
+    index = 0
+
+    async for result in tasks_run:
+        print(result)
+        assert result == results[index]
+        index += 1
+
+async def main():
+    data_queue = Queue()
+    data_queue.is_closed = False
+
+    async def queue_producer():
+        for i in range(0, 10):
+            data_queue.put(i)
+            await asyncio.sleep(0.1)
+
+        data_queue.is_closed = True
+
+    await asyncio.gather(pipeline(data_queue), queue_producer())
+
+if __name__ == "__main__":
+    asyncio.run(main())

View file

@@ -4,29 +4,30 @@ from ..tasks.Task import Task

 logger = logging.getLogger("run_tasks(tasks: [Task], data)")

-async def run_tasks(tasks: [Task], data):
+async def run_tasks(tasks: [Task], data = None):
     if len(tasks) == 0:
         yield data
         return

+    args = [data] if data is not None else []
+
     running_task = tasks[0]
-    batch_size = running_task.task_config["batch_size"]
     leftover_tasks = tasks[1:]
     next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
-    # next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
+    next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1

     if inspect.isasyncgenfunction(running_task.executable):
-        logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
+        logger.info("Running async generator task: `%s`", running_task.executable.__name__)
         try:
             results = []

-            async_iterator = running_task.run(data)
+            async_iterator = running_task.run(*args)

             async for partial_result in async_iterator:
                 results.append(partial_result)

-                if len(results) == batch_size:
-                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
+                if len(results) == next_task_batch_size:
+                    async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
                         yield result

                     results = []

@@ -37,7 +38,7 @@ async def run_tasks(tasks: [Task], data = None):
                 results = []

-            logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
+            logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
         except Exception as error:
             logger.error(
                 "Error occurred while running async generator task: `%s`\n%s\n",

@@ -48,15 +49,15 @@ async def run_tasks(tasks: [Task], data = None):
         raise error
     elif inspect.isgeneratorfunction(running_task.executable):
-        logger.info(f"Running generator task: `{running_task.executable.__name__}`")
+        logger.info("Running generator task: `%s`", running_task.executable.__name__)
         try:
             results = []

-            for partial_result in running_task.run(data):
+            for partial_result in running_task.run(*args):
                 results.append(partial_result)

-                if len(results) == batch_size:
-                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
+                if len(results) == next_task_batch_size:
+                    async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
                         yield result

                     results = []

@@ -67,7 +68,7 @@ async def run_tasks(tasks: [Task], data = None):
                 results = []

-            logger.info(f"Running generator task: `{running_task.executable.__name__}`")
+            logger.info("Finished generator task: `%s`", running_task.executable.__name__)
         except Exception as error:
             logger.error(
                 "Error occurred while running generator task: `%s`\n%s\n",

@@ -78,13 +79,35 @@ async def run_tasks(tasks: [Task], data = None):
         raise error
     elif inspect.iscoroutinefunction(running_task.executable):
-        task_result = await running_task.run(data)
-
-        async for result in run_tasks(leftover_tasks, task_result):
-            yield result
+        logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
+        try:
+            task_result = await running_task.run(*args)
+
+            async for result in run_tasks(leftover_tasks, task_result):
+                yield result
+
+            logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
+        except Exception as error:
+            logger.error(
+                "Error occurred while running coroutine task: `%s`\n%s\n",
+                running_task.executable.__name__,
+                str(error),
+                exc_info = True,
+            )
     elif inspect.isfunction(running_task.executable):
-        task_result = running_task.run(data)
-
-        async for result in run_tasks(leftover_tasks, task_result):
-            yield result
+        logger.info("Running function task: `%s`", running_task.executable.__name__)
+        try:
+            task_result = running_task.run(*args)
+
+            async for result in run_tasks(leftover_tasks, task_result):
+                yield result
+
+            logger.info("Finished function task: `%s`", running_task.executable.__name__)
+        except Exception as error:
+            logger.error(
+                "Error occurred while running function task: `%s`\n%s\n",
+                running_task.executable.__name__,
+                str(error),
+                exc_info = True,
+            )

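The four branches above dispatch on the shape of the task's executable, using inspect predicates. A standalone illustration of what each branch matches:

    import inspect

    async def async_gen():
        yield 1

    def sync_gen():
        yield 1

    async def coro():
        return 1

    def plain():
        return 1

    assert inspect.isasyncgenfunction(async_gen)
    assert inspect.isgeneratorfunction(sync_gen)
    assert inspect.iscoroutinefunction(coro)
    assert inspect.isfunction(plain)
    # Order matters: a generator function is also a plain function,
    # so run_tasks tests the more specific predicates first.
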
View file

@@ -1,21 +1,18 @@
-from typing import Union, Dict
-import networkx as nx
+import asyncio
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine

-async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
+async def search_adjacent(query: str) -> list[(str, str)]:
     """
     Find the neighbours of a given node in the graph and return their ids and descriptions.

     Parameters:
-    - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
-    - query (str): Unused in this implementation but could be used for future enhancements.
-    - other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
+    - query (str): The query string to filter nodes by.

     Returns:
-    - Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
+    - list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
     """
-    node_id = other_param.get("node_id") if other_param else query
+    node_id = query

     if node_id is None:
         return {}

@@ -23,16 +20,24 @@ async def search_adjacent(query: str) -> list[(str, str)]:
     graph_engine = await get_graph_engine()
     exact_node = await graph_engine.extract_node(node_id)

-    if exact_node is not None and "id" in exact_node:
-        neighbours = await graph_engine.get_neighbours(exact_node["id"])
+    if exact_node is not None and "uuid" in exact_node:
+        neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
     else:
         vector_engine = get_vector_engine()
-        collection_name = "classification"
-        data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5)
+        results = await asyncio.gather(
+            vector_engine.search("entities", query_text = query, limit = 10),
+            vector_engine.search("classification", query_text = query, limit = 10),
+        )
+        results = [*results[0], *results[1]]
+        relevant_results = [result for result in results if result.score < 0.5][:5]

-        if len(data_points) == 0:
+        if len(relevant_results) == 0:
             return []

-        neighbours = await graph_engine.get_neighbours(data_points[0].id)
+        node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
+
+        neighbours = []
+        for neighbour_ids in node_neighbours:
+            neighbours.extend(neighbour_ids)

-    return [node["name"] for node in neighbours]
+    return neighbours

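Both search_adjacent and search_traverse now merge hits from the "entities" and "classification" collections and keep at most five results with score < 0.5; with the normalized LanceDB scores (0 best, 1 worst) that keeps only the closer half of the range. The filter in isolation, using a local stand-in for cognee's ScoredResult model:

    from pydantic import BaseModel

    class ScoredResult(BaseModel):  # stand-in mirroring cognee's model; lower score is better
        id: str
        score: float
        payload: dict

    results = [
        ScoredResult(id = "a", score = 0.1, payload = {}),
        ScoredResult(id = "b", score = 0.7, payload = {}),
        ScoredResult(id = "c", score = 0.4, payload = {}),
    ]

    relevant_results = [result for result in results if result.score < 0.5][:5]
    print([result.id for result in relevant_results])  # ['a', 'c']
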
View file

@@ -1,18 +1,15 @@
-import networkx as nx
-from typing import Union
-from cognee.shared.data_models import GraphDBType
-from cognee.infrastructure.databases.graph.config import get_graph_config
+from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config

-async def search_cypher(query:str, graph: Union[nx.Graph, any]):
+async def search_cypher(query: str):
     """
     Use a Cypher query to search the graph and return the results.
     """
     graph_config = get_graph_config()

     if graph_config.graph_database_provider == "neo4j":
-        result = await graph.run(query)
+        graph_engine = await get_graph_engine()
+        result = await graph_engine.graph().run(query)
         return result
     else:
-        raise ValueError("Unsupported graph engine type.")
+        raise ValueError("Unsupported search type for the used graph engine.")

View file

@@ -1,22 +1,17 @@
-from typing import Union, Dict
-import networkx as nx
 from cognee.infrastructure.databases.vector import get_vector_engine

-async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
+async def search_similarity(query: str) -> list[str, str]:
     """
     Parameters:
-    - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
-    - query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
+    - query (str): The query string to filter nodes by.

     Returns:
-    - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes.
+    - list(chunk): A list of objects providing information about the chunks related to query.
     """
     vector_engine = get_vector_engine()
     similar_results = await vector_engine.search("chunks", query, limit = 5)

-    results = [{
-        "text": result.payload["text"],
-        "chunk_id": result.payload["chunk_id"],
-    } for result in similar_results]
+    results = [result.payload for result in similar_results]

     return results

View file

@@ -1,24 +1,17 @@
-from typing import Union, Dict
-import networkx as nx
-from cognee.shared.data_models import ChunkSummaries
 from cognee.infrastructure.databases.vector import get_vector_engine

-async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
+async def search_summary(query: str) -> list:
     """
     Parameters:
-    - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
-    - query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
-    - other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements.
+    - query (str): The query string to filter summaries by.

     Returns:
-    - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
+    - list[str, UUID]: A list of objects providing information about the summaries related to query.
     """
     vector_engine = get_vector_engine()
     summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)

-    summaries = [{
-        "text": summary.payload["text"],
-        "chunk_id": summary.payload["chunk_id"],
-    } for summary in summaries_results]
+    summaries = [summary.payload for summary in summaries_results]

     return summaries

View file

@@ -1,21 +1,36 @@
+import asyncio
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine

-async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call
+async def search_traverse(query: str):
+    node_id = query
+    rules = set()
+
     graph_engine = await get_graph_engine()
     vector_engine = get_vector_engine()

-    results = await vector_engine.search("classification", query_text = query, limit = 10)
-
-    rules = []
-
-    if len(results) > 0:
-        for result in results:
-            graph_node_id = result.id
-
-            edges = await graph_engine.get_edges(graph_node_id)
-
-            for edge in edges:
-                rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
-
-    return rules
+    exact_node = await graph_engine.extract_node(node_id)
+
+    if exact_node is not None and "uuid" in exact_node:
+        edges = await graph_engine.get_edges(exact_node["uuid"])
+
+        for edge in edges:
+            rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
+    else:
+        results = await asyncio.gather(
+            vector_engine.search("entities", query_text = query, limit = 10),
+            vector_engine.search("classification", query_text = query, limit = 10),
+        )
+        results = [*results[0], *results[1]]
+        relevant_results = [result for result in results if result.score < 0.5][:5]
+
+        if len(relevant_results) > 0:
+            for result in relevant_results:
+                graph_node_id = result.id
+
+                edges = await graph_engine.get_edges(graph_node_id)
+
+                for edge in edges:
+                    rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
+
+    return list(rules)

cognee/pipelines.py Normal file
View file

@@ -0,0 +1,5 @@
+# Don't add any more code here, this file is used only for the purpose
+# of enabling imports from `cognee.pipelines` module.
+# `from cognee.pipelines import Task` for example.
+
+from .modules.pipelines import *

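Since this file star-imports cognee.modules.pipelines, which now also exports Task, the shortened path advertised in the comment works:

    from cognee.pipelines import Task, run_tasks, run_tasks_parallel
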
View file

@@ -10,12 +10,14 @@ class Node(BaseModel):
     name: str
     type: str
     description: str
+    properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.")

 class Edge(BaseModel):
     """Edge in a knowledge graph."""
     source_node_id: str
     target_node_id: str
     relationship_name: str
+    properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.")

 class KnowledgeGraph(BaseModel):
     """Knowledge graph."""

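Because properties is optional on both models, existing extraction output without it still validates; when present, expand_knowledge_graph serializes it with json.dumps onto the stored node or edge. A construction sketch, assuming Node and Edge live in cognee.shared.data_models alongside KnowledgeGraph (field values invented):

    from cognee.shared.data_models import Node, Edge

    node = Node(
        id = "marie_curie",
        name = "Marie Curie",
        type = "Person",
        description = "Polish-French physicist",
        properties = { "profession": "physicist" },
    )

    edge = Edge(
        source_node_id = "marie_curie",
        target_node_id = "1867-11-07",
        relationship_name = "born_on",
    )  # properties defaults to None
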
View file

@@ -1,4 +1,3 @@
 import os
 import logging
 import pathlib

@@ -38,21 +37,32 @@ async def main():
     await cognee.cognify([dataset_name], root_node_id = "ROOT")

-    search_results = await cognee.search("TRAVERSE", { "query": "Text" })
+    from cognee.infrastructure.databases.vector import get_vector_engine
+    vector_engine = get_vector_engine()
+    random_node = (await vector_engine.search("entities", "AI"))[0]
+    random_node_name = random_node.payload["name"]
+
+    search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
     assert len(search_results) != 0, "The search results list is empty."
     print("\n\nExtracted sentences are:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
+    search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted sentences are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search("SUMMARY", { "query": random_node_name })
     assert len(search_results) != 0, "Query related summaries don't exist."
     print("\n\nQuery related summaries exist:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("ADJACENT", { "query": "Articles" })
-    assert len(search_results) != 0, "ROOT node has no neighbours."
-    print("\n\nROOT node has neighbours.\n")
+    search_results = await cognee.search("ADJACENT", { "query": random_node_name })
+    assert len(search_results) != 0, "Large language model query found no neighbours."
+    print("\n\nLarge language model query found neighbours.\n")
     for result in search_results:
         print(f"{result}\n")

View file

@@ -33,21 +33,32 @@ async def main():
     await cognee.cognify([dataset_name])

-    search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
+    from cognee.infrastructure.databases.vector import get_vector_engine
+    vector_engine = get_vector_engine()
+    random_node = (await vector_engine.search("entities", "AI"))[0]
+    random_node_name = random_node.payload["name"]
+
+    search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
     assert len(search_results) != 0, "The search results list is empty."
     print("\n\nExtracted sentences are:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
+    search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted sentences are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search("SUMMARY", { "query": random_node_name })
     assert len(search_results) != 0, "Query related summaries don't exist."
     print("\n\nQuery related summaries exist:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
-    assert len(search_results) != 0, "ROOT node has no neighbours."
-    print("\n\nROOT node has neighbours.\n")
+    search_results = await cognee.search("ADJACENT", { "query": random_node_name })
+    assert len(search_results) != 0, "Large language model query found no neighbours."
+    print("\n\nLarge language model query found neighbours.\n")
     for result in search_results:
         print(f"{result}\n")

View file

@@ -34,21 +34,32 @@ async def main():
     await cognee.cognify([dataset_name])

-    search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
+    from cognee.infrastructure.databases.vector import get_vector_engine
+    vector_engine = get_vector_engine()
+    random_node = (await vector_engine.search("entities", "AI"))[0]
+    random_node_name = random_node.payload["name"]
+
+    search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
     assert len(search_results) != 0, "The search results list is empty."
     print("\n\nExtracted sentences are:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
+    search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted sentences are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search("SUMMARY", { "query": random_node_name })
     assert len(search_results) != 0, "Query related summaries don't exist."
     print("\n\nQuery related summaries exist:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
-    assert len(search_results) != 0, "ROOT node has no neighbours."
-    print("\n\nROOT node has neighbours.\n")
+    search_results = await cognee.search("ADJACENT", { "query": random_node_name })
+    assert len(search_results) != 0, "Large language model query found no neighbours."
+    print("\n\nLarge language model query found neighbours.\n")
     for result in search_results:
         print(f"{result}\n")

View file

@@ -32,21 +32,32 @@ async def main():
     await cognee.cognify([dataset_name])

-    search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
+    from cognee.infrastructure.databases.vector import get_vector_engine
+    vector_engine = get_vector_engine()
+    random_node = (await vector_engine.search("entities", "AI"))[0]
+    random_node_name = random_node.payload["name"]
+
+    search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
     assert len(search_results) != 0, "The search results list is empty."
     print("\n\nExtracted sentences are:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
+    search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted sentences are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search("SUMMARY", { "query": random_node_name })
     assert len(search_results) != 0, "Query related summaries don't exist."
     print("\n\nQuery related summaries exist:\n")
     for result in search_results:
         print(f"{result}\n")

-    search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
-    assert len(search_results) != 0, "ROOT node has no neighbours."
-    print("\n\nROOT node has neighbours.\n")
+    search_results = await cognee.search("ADJACENT", { "query": random_node_name })
+    assert len(search_results) != 0, "Large language model query found no neighbours."
+    print("\n\nLarge language model query found neighbours.\n")
     for result in search_results:
         print(f"{result}\n")