diff --git a/cognee/__init__.py b/cognee/__init__.py index 49a5d7f21..04c13eae1 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify from .api.v1.datasets.datasets import datasets from .api.v1.search.search import search, SearchType from .api.v1.prune import prune + +# Pipelines +from .modules import pipelines diff --git a/cognee/api/client.py b/cognee/api/client.py index fab464801..b9d8cd27f 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -8,7 +8,7 @@ import logging import sentry_sdk from typing import Dict, Any, List, Union, Optional, Literal from typing_extensions import Annotated -from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query +from fastapi import FastAPI, HTTPException, Form, UploadFile, Query from fastapi.responses import JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index bcfd43273..589c5331d 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -1,18 +1,11 @@ import asyncio -import hashlib import logging import uuid from typing import Union -from fastapi_users import fastapi_users -from sqlalchemy.ext.asyncio import AsyncSession - from cognee.infrastructure.databases.graph import get_graph_config -from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \ - get_user_permissions, get_async_session_context, fast_api_users_init -# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker -# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users + get_async_session_context, fast_api_users_init from cognee.modules.cognify.config import get_cognify_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument @@ -62,8 +55,6 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session) if out: - - async with update_status_lock: task_status = get_task_status([dataset_name]) @@ -89,9 +80,9 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No root_node_id = "ROOT" tasks = [ - Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type - Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data - Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes + Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type + Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data + Task(expand_knowledge_graph, graph_model = 
KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
         Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
         Task(
             save_data_chunks,
diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index 2d0bdeb75..b0c64f9ed 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -11,7 +11,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
 from cognee.modules.search.vector.search_traverse import search_traverse
 from cognee.modules.search.graph.search_summary import search_summary
 from cognee.modules.search.graph.search_similarity import search_similarity
-from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.shared.utils import send_telemetry
 
 class SearchType(Enum):
@@ -63,9 +62,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
 
 
 async def specific_search(query_params: List[SearchParameters]) -> List:
-    graph_client = await get_graph_engine()
-    graph = graph_client.graph
-
     search_functions: Dict[SearchType, Callable] = {
         SearchType.ADJACENT: search_adjacent,
         SearchType.SUMMARY: search_summary,
@@ -81,7 +77,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
         search_func = search_functions.get(search_param.search_type)
         if search_func:
             # Schedule the coroutine for execution and store the task
-            task = search_func(**search_param.params, graph = graph)
+            task = search_func(**search_param.params)
             search_tasks.append(task)
 
     # Use asyncio.gather to run all scheduled tasks concurrently
@@ -92,7 +88,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
 
     send_telemetry("cognee.search")
 
-    return results
+    return results[0] if len(results) == 1 else results
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 51568973a..1c55f1e39 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
     async def graph(self):
         return await self.get_session()
 
+    async def has_node(self, node_id: str) -> bool:
+        results = await self.query(
+            """
+                MATCH (n)
+                WHERE n.id = $node_id
+                RETURN COUNT(n) > 0 AS node_exists
+            """,
+            {"node_id": node_id}
+        )
+        return results[0]["node_exists"] if len(results) > 0 else False
+
     async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
         node_id = node_id.replace(":", "_")
 
@@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):
 
         return await self.query(query, params)
 
+    async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
+        query = f"""
+            MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
+            RETURN COUNT(relationship) > 0 AS edge_exists
+        """
+
+        edge_exists = await self.query(query)
+        return edge_exists
+
+    async def has_edges(self, edges):
+        query = """
+            UNWIND $edges AS edge
+            MATCH (a)-[r]->(b)
+            WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
+            RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
+        """
+
+        try:
+            params = {
+                "edges": [{
+                    "from_node": edge[0],
+                    "to_node": edge[1],
+                    "relationship_name": edge[2],
+                } for edge in edges],
} + + results = await self.query(query, params) + return [result["edge_exists"] for result in results] + except Neo4jError as error: + logger.error("Neo4j query error: %s", error, exc_info = True) + raise error + + async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): serialized_properties = self.serialize_properties(edge_properties) from_node = from_node.replace(":", "_") @@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface): }, } for edge in edges] - results = await self.query(query, dict(edges = edges)) - return results + try: + results = await self.query(query, dict(edges = edges)) + return results + except Neo4jError as error: + logger.error("Neo4j query error: %s", error, exc_info = True) + raise error async def get_edges(self, node_id: str): query = """ @@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface): async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]: if edge_label is not None: query = """ - MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor) - RETURN predecessor.id AS id + MATCH (node)<-[r]-(predecessor) + WHERE node.id = $node_id AND type(r) = $edge_label + RETURN predecessor.id AS predecessor_id """ results = await self.query( @@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface): ) ) - return [result["id"] for result in results] + return [result["predecessor_id"] for result in results] else: query = """ - MATCH (node:`{node_id}`)-[r]->(predecessor) - RETURN predecessor.id AS id + MATCH (node)<-[r]-(predecessor) + WHERE node.id = $node_id + RETURN predecessor.id AS predecessor_id """ results = await self.query( @@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface): ) ) - return [result["id"] for result in results] + return [result["predecessor_id"] for result in results] async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]: if edge_label is not None: query = """ - MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor) - RETURN successor.id AS id + MATCH (node)-[r]->(successor) + WHERE node.id = $node_id AND type(r) = $edge_label + RETURN successor.id AS successor_id """ results = await self.query( @@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface): ), ) - return [result["id"] for result in results] + return [result["successor_id"] for result in results] else: query = """ - MATCH (node:`{node_id}`)<-[r]-(successor) - RETURN successor.id AS id + MATCH (node)-[r]->(successor) + WHERE node.id = $node_id + RETURN successor.id AS successor_id """ results = await self.query( @@ -318,12 +370,12 @@ class Neo4jAdapter(GraphDBInterface): ) ) - return [result["id"] for result in results] + return [result["successor_id"] for result in results] async def get_neighbours(self, node_id: str) -> list[str]: - results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id)) + predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id)) - return [*results[0], *results[1]] + return [*predecessor_ids, *successor_ids] async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None: query = f""" diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index 848758867..1ee84eeb8 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -2,6 
+2,7 @@ import os import json +import asyncio import logging from typing import Dict, Any, List import aiofiles @@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface): self.filename = filename + async def has_node(self, node_id: str) -> bool: + return self.graph.has_node(node_id) async def add_node( self, @@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface): async def get_graph(self): return self.graph + async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool: + return self.graph.has_edge(from_node, to_node, key = edge_label) + + async def has_edges(self, edges): + result = [] + + for (from_node, to_node, edge_label) in edges: + if await self.has_edge(from_node, to_node, edge_label): + result.append((from_node, to_node, edge_label)) + + return result + async def add_edge( self, from_node: str, @@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface): if not self.graph.has_node(node_id): return [] - neighbour_ids = list(self.graph.neighbors(node_id)) + predecessor_ids, successor_ids = await asyncio.gather( + self.get_predecessor_ids(node_id), + self.get_successor_ids(node_id), + ) + + neighbour_ids = predecessor_ids + successor_ids if len(neighbour_ids) == 0: return [] diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 9381639f6..3bb47fcc0 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface): return [ScoredResult( id = result["id"], payload = result["payload"], - score = 1, + score = 0, ) for result in results.to_dict("index").values()] async def search( @@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface): collection_name: str, query_text: str = None, query_vector: List[float] = None, - limit: int = 10, + limit: int = 5, with_vector: bool = False, ): if query_text is None and query_vector is None: @@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface): results = await collection.vector_search(query_vector).limit(limit).to_pandas() + result_values = list(results.to_dict("index").values()) + + min_value = 100 + max_value = 0 + + for result in result_values: + value = float(result["_distance"]) + if value > max_value: + max_value = value + if value < min_value: + min_value = value + + normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values] + return [ScoredResult( id = str(result["id"]), payload = result["payload"], - score = float(result["_distance"]), - ) for result in results.to_dict("index").values()] + score = normalized_values[value_index], + ) for value_index, result in enumerate(result_values)] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/models/ScoredResult.py b/cognee/infrastructure/databases/vector/models/ScoredResult.py index 4ff287cd6..fcecbbe79 100644 --- a/cognee/infrastructure/databases/vector/models/ScoredResult.py +++ b/cognee/infrastructure/databases/vector/models/ScoredResult.py @@ -1,8 +1,7 @@ -from uuid import UUID from typing import Any, Dict from pydantic import BaseModel class ScoredResult(BaseModel): id: str - score: float + score: float # Lower score is better payload: Dict[str, Any] diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index d2c02a53b..7e4b39f10 100644 --- 
a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -1,9 +1,12 @@ +import logging from typing import List, Dict, Optional from qdrant_client import AsyncQdrantClient, models from ..vector_db_interface import VectorDBInterface from ..models.DataPoint import DataPoint from ..embeddings.EmbeddingEngine import EmbeddingEngine +logger = logging.getLogger("QDrantAdapter") + # class CollectionConfig(BaseModel, extra = "forbid"): # vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" ) # hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration") @@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface): points = [convert_to_qdrant_point(point) for point in data_points] - result = await client.upload_points( - collection_name = collection_name, - points = points - ) - - await client.close() - - return result + try: + result = await client.upload_points( + collection_name = collection_name, + points = points + ) + return result + except Exception as error: + logger.error("Error uploading data points to Qdrant: %s", str(error)) + raise error + finally: + await client.close() async def retrieve(self, collection_name: str, data_point_ids: list[str]): client = self.get_qdrant_client() @@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface): collection_name: str, query_text: Optional[str] = None, query_vector: Optional[List[float]] = None, - limit: int = None, + limit: int = 5, with_vector: bool = False ): if query_text is None and query_vector is None: diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 9a3ed9dd0..e569c5a49 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -1,10 +1,12 @@ import asyncio +import logging from typing import List, Optional from ..vector_db_interface import VectorDBInterface from ..models.DataPoint import DataPoint from ..models.ScoredResult import ScoredResult from ..embeddings.EmbeddingEngine import EmbeddingEngine +logger = logging.getLogger("WeaviateAdapter") class WeaviateAdapter(VectorDBInterface): name = "Weaviate" @@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface): vector = vector ) - - objects = list(map(convert_to_weaviate_data_points, data_points)) + data_points = list(map(convert_to_weaviate_data_points, data_points)) collection = self.get_collection(collection_name) - with collection.batch.dynamic() as batch: - for data_row in objects: - batch.add_object( - properties = data_row.properties, - vector = data_row.vector - ) - - return - # return self.get_collection(collection_name).data.insert_many(objects) + try: + if len(data_points) > 1: + return collection.data.insert_many(data_points) + else: + return collection.data.insert(data_points[0]) + # with collection.batch.dynamic() as batch: + # for point in data_points: + # batch.add_object( + # uuid = point.uuid, + # properties = point.properties, + # vector = point.vector + # ) + except Exception as error: + logger.error("Error creating data points: %s", str(error)) + raise error async def retrieve(self, collection_name: str, data_point_ids: list[str]): from weaviate.classes.query import Filter diff --git a/cognee/infrastructure/llm/prompts/generate_graph_prompt.txt 
b/cognee/infrastructure/llm/prompts/generate_graph_prompt.txt index a1113bee6..6392cdc33 100644 --- a/cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +++ b/cognee/infrastructure/llm/prompts/generate_graph_prompt.txt @@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f **Nodes** represent entities and concepts. They're akin to Wikipedia nodes. **Edges** represent relationships between concepts. They're akin to Wikipedia links. -The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience. +The aim is to achieve simplicity and clarity in the knowledge graph. # 1. Labeling Nodes **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **"Person"**. - - Avoid using more specific terms like "Mathematician" or "Scientist". + - Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property. - Don't use too generic terms like "Entity". **Node IDs**: Never utilize integers as node IDs. - Node IDs should be names or human-readable identifiers found in the text. # 2. Handling Numerical Data and Dates - - For example, when you identify an entity representing a date, always label it as **"Date"**. + - For example, when you identify an entity representing a date, make sure it has type **"Date"**. - Extract the date in the format "YYYY-MM-DD" - If not possible to extract the whole date, extract month or year, or both if available. - **Property Format**: Properties must be in a key-value format. @@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. # 4. Strict Compliance -Adhere to the rules strictly. Non-compliance will result in termination""" +Adhere to the rules strictly. 
Non-compliance will result in termination diff --git a/cognee/modules/classification/classify_text_chunks.py b/cognee/modules/classification/classify_text_chunks.py index 97dd90e56..5546b41ea 100644 --- a/cognee/modules/classification/classify_text_chunks.py +++ b/cognee/modules/classification/classify_text_chunks.py @@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_ vector_engine = get_vector_engine() class Keyword(BaseModel): - id: str + uuid: str text: str chunk_id: str document_id: str @@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_ DataPoint[Keyword]( id = str(classification_type_id), payload = Keyword.parse_obj({ - "id": str(classification_type_id), + "uuid": str(classification_type_id), "text": classification_type_label, "chunk_id": str(data_chunk.chunk_id), "document_id": str(data_chunk.document_id), @@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_ DataPoint[Keyword]( id = str(classification_subtype_id), payload = Keyword.parse_obj({ - "id": str(classification_subtype_id), + "uuid": str(classification_subtype_id), "text": classification_subtype_label, "chunk_id": str(data_chunk.chunk_id), "document_id": str(data_chunk.document_id), @@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_ ) )) edges.append(( - str(classification_type_id), str(classification_subtype_id), - "contains", + str(classification_type_id), + "is_subtype_of", dict( relationship_name = "contains", source_node_id = str(classification_type_id), diff --git a/cognee/modules/data/extraction/knowledge_graph/expand_knowledge_graph.py b/cognee/modules/data/extraction/knowledge_graph/expand_knowledge_graph.py index d6475512d..3735b41b9 100644 --- a/cognee/modules/data/extraction/knowledge_graph/expand_knowledge_graph.py +++ b/cognee/modules/data/extraction/knowledge_graph/expand_knowledge_graph.py @@ -1,25 +1,77 @@ +import json import asyncio +from uuid import uuid5, NAMESPACE_OID from datetime import datetime from typing import Type from pydantic import BaseModel from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine from ...processing.chunk_types.DocumentChunk import DocumentChunk from .extract_knowledge_graph import extract_content_graph -async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]): +class EntityNode(BaseModel): + uuid: str + name: str + type: str + description: str + created_at: datetime + updated_at: datetime + +async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str): chunk_graphs = await asyncio.gather( *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] ) + vector_engine = get_vector_engine() graph_engine = await get_graph_engine() - type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes] - graph_type_node_ids = list(set(type_ids)) - graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids) - existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes} + has_collection = await vector_engine.has_collection(collection_name) + + if not has_collection: + await vector_engine.create_collection(collection_name, payload_schema = EntityNode) + + processed_nodes = {} + type_node_edges = [] + entity_node_edges = [] + 
type_entity_edges = [] + + for (chunk_index, chunk) in enumerate(data_chunks): + chunk_graph = chunk_graphs[chunk_index] + for node in chunk_graph.nodes: + type_node_id = generate_node_id(node.type) + entity_node_id = generate_node_id(node.id) + + if type_node_id not in processed_nodes: + type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type")) + processed_nodes[type_node_id] = True + + if entity_node_id not in processed_nodes: + entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity")) + type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type")) + processed_nodes[entity_node_id] = True + + graph_node_edges = [ + (edge.source_node_id, edge.target_node_id, edge.relationship_name) \ + for edge in chunk_graph.edges + ] + + existing_edges = await graph_engine.has_edges([ + *type_node_edges, + *entity_node_edges, + *type_entity_edges, + *graph_node_edges, + ]) + + existing_edges_map = {} + existing_nodes_map = {} + + for edge in existing_edges: + existing_edges_map[edge[0] + edge[1] + edge[2]] = True + existing_nodes_map[edge[0]] = True graph_nodes = [] graph_edges = [] + data_points = [] for (chunk_index, chunk) in enumerate(data_chunks): graph = chunk_graphs[chunk_index] @@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: for node in graph.nodes: node_id = generate_node_id(node.id) + node_name = generate_name(node.name) - graph_nodes.append(( - node_id, - dict( - id = node_id, - chunk_id = str(chunk.chunk_id), - document_id = str(chunk.document_id), - name = node.name, - type = node.type.lower().capitalize(), + type_node_id = generate_node_id(node.type) + type_node_name = generate_name(node.type) + + if node_id not in existing_nodes_map: + node_data = dict( + uuid = node_id, + name = node_name, + type = node_name, description = node.description, created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) - )) - graph_edges.append(( - str(chunk.chunk_id), - node_id, - "contains", - dict( - relationship_name = "contains", - source_node_id = str(chunk.chunk_id), - target_node_id = node_id, - ), - )) + graph_nodes.append(( + node_id, + dict( + **node_data, + properties = json.dumps(node.properties), + ) + )) - type_node_id = generate_node_id(node.type) + data_points.append(DataPoint[EntityNode]( + id = str(uuid5(NAMESPACE_OID, node_id)), + payload = node_data, + embed_field = "name", + )) - if type_node_id not in existing_type_nodes_map: - node_name = node.type.lower().capitalize() + existing_nodes_map[node_id] = True - type_node = dict( - id = type_node_id, - name = node_name, - type = node_name, + edge_key = str(chunk.chunk_id) + node_id + "contains_entity" + + if edge_key not in existing_edges_map: + graph_edges.append(( + str(chunk.chunk_id), + node_id, + "contains_entity", + dict( + relationship_name = "contains_entity", + source_node_id = str(chunk.chunk_id), + target_node_id = node_id, + ), + )) + + # Add relationship between entity type and entity itself: "Jake is Person" + graph_edges.append(( + node_id, + type_node_id, + "is_entity_type", + dict( + relationship_name = "is_entity_type", + source_node_id = type_node_id, + target_node_id = node_id, + ), + )) + + existing_edges_map[edge_key] = True + + if type_node_id not in existing_nodes_map: + type_node_data = dict( + uuid = type_node_id, + name = type_node_name, + type = type_node_id, + description = type_node_name, created_at = datetime.now().strftime("%Y-%m-%d 
%H:%M:%S"), updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) - graph_nodes.append((type_node_id, type_node)) - existing_type_nodes_map[type_node_id] = type_node + graph_nodes.append((type_node_id, dict( + **type_node_data, + properties = json.dumps(node.properties) + ))) - graph_edges.append(( - str(chunk.chunk_id), - type_node_id, - "contains_entity_type", - dict( - relationship_name = "contains_entity_type", - source_node_id = str(chunk.chunk_id), - target_node_id = type_node_id, - ), - )) + data_points.append(DataPoint[EntityNode]( + id = str(uuid5(NAMESPACE_OID, type_node_id)), + payload = type_node_data, + embed_field = "name", + )) - # Add relationship between entity type and entity itself: "Jake is Person" - graph_edges.append(( - type_node_id, - node_id, - "is_entity_type", - dict( - relationship_name = "is_entity_type", - source_node_id = type_node_id, - target_node_id = node_id, - ), - )) + existing_nodes_map[type_node_id] = True - # Add relationship that came from graphs. - for edge in graph.edges: + edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type" + + if edge_key not in existing_edges_map: graph_edges.append(( - generate_node_id(edge.source_node_id), - generate_node_id(edge.target_node_id), - edge.relationship_name, + str(chunk.chunk_id), + type_node_id, + "contains_entity_type", dict( - relationship_name = edge.relationship_name, - source_node_id = generate_node_id(edge.source_node_id), - target_node_id = generate_node_id(edge.target_node_id), + relationship_name = "contains_entity_type", + source_node_id = str(chunk.chunk_id), + target_node_id = type_node_id, ), )) - await graph_engine.add_nodes(graph_nodes) + existing_edges_map[edge_key] = True - await graph_engine.add_edges(graph_edges) + # Add relationship that came from graphs. 
+ for edge in graph.edges: + source_node_id = generate_node_id(edge.source_node_id) + target_node_id = generate_node_id(edge.target_node_id) + relationship_name = generate_name(edge.relationship_name) + edge_key = source_node_id + target_node_id + relationship_name + + if edge_key not in existing_edges_map: + graph_edges.append(( + generate_node_id(edge.source_node_id), + generate_node_id(edge.target_node_id), + edge.relationship_name, + dict( + relationship_name = generate_name(edge.relationship_name), + source_node_id = generate_node_id(edge.source_node_id), + target_node_id = generate_node_id(edge.target_node_id), + properties = json.dumps(edge.properties), + ), + )) + existing_edges_map[edge_key] = True + + if len(data_points) > 0: + await vector_engine.create_data_points(collection_name, data_points) + + if len(graph_nodes) > 0: + await graph_engine.add_nodes(graph_nodes) + + if len(graph_edges) > 0: + await graph_engine.add_edges(graph_edges) return data_chunks +def generate_name(name: str) -> str: + return name.lower().replace(" ", "_").replace("'", "") + def generate_node_id(node_id: str) -> str: - return node_id.upper().replace(" ", "_").replace("'", "") + return node_id.lower().replace(" ", "_").replace("'", "") diff --git a/cognee/modules/pipelines/__init__.py b/cognee/modules/pipelines/__init__.py index cba929c36..5005c25f0 100644 --- a/cognee/modules/pipelines/__init__.py +++ b/cognee/modules/pipelines/__init__.py @@ -1,2 +1,3 @@ +from .tasks.Task import Task from .operations.run_tasks import run_tasks from .operations.run_parallel import run_tasks_parallel diff --git a/cognee/modules/pipelines/operations/__tests__/__index__.py b/cognee/modules/pipelines/operations/__tests__/__init__.py similarity index 100% rename from cognee/modules/pipelines/operations/__tests__/__index__.py rename to cognee/modules/pipelines/operations/__tests__/__init__.py diff --git a/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py b/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py index 387b97274..2fef802fd 100644 --- a/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py +++ b/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py @@ -8,27 +8,29 @@ async def main(): for i in range(num): yield i + 1 - async def add_one(num): - yield num + 1 - - async def multiply_by_two(nums): + async def add_one(nums): for num in nums: - yield num * 2 + yield num + 1 - async def add_one_to_batched_data(num): + async def multiply_by_two(num): + yield num * 2 + + async def add_one_single(num): yield num + 1 pipeline = run_tasks([ - Task(number_generator, task_config = {"batch_size": 1}), + Task(number_generator), Task(add_one, task_config = {"batch_size": 5}), Task(multiply_by_two, task_config = {"batch_size": 1}), - Task(add_one_to_batched_data), + Task(add_one_single), ], 10) + results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23] + index = 0 async for result in pipeline: - print("\n") print(result) - print("\n") + assert result == results[index] + index += 1 if __name__ == "__main__": asyncio.run(main()) diff --git a/cognee/modules/pipelines/operations/__tests__/run_tasks_from_queue.test.py b/cognee/modules/pipelines/operations/__tests__/run_tasks_from_queue.test.py new file mode 100644 index 000000000..387d22ce6 --- /dev/null +++ b/cognee/modules/pipelines/operations/__tests__/run_tasks_from_queue.test.py @@ -0,0 +1,46 @@ +import asyncio +from queue import Queue +from cognee.modules.pipelines.operations.run_tasks import run_tasks +from cognee.modules.pipelines.tasks.Task 
import Task + +async def pipeline(data_queue): + async def queue_consumer(): + while not data_queue.is_closed: + if not data_queue.empty(): + yield data_queue.get() + else: + await asyncio.sleep(0.3) + + async def add_one(num): + yield num + 1 + + async def multiply_by_two(num): + yield num * 2 + + tasks_run = run_tasks([ + Task(queue_consumer), + Task(add_one), + Task(multiply_by_two), + ]) + + results = [2, 4, 6, 8, 10, 12, 14, 16, 18] + index = 0 + async for result in tasks_run: + print(result) + assert result == results[index] + index += 1 + +async def main(): + data_queue = Queue() + data_queue.is_closed = False + + async def queue_producer(): + for i in range(0, 10): + data_queue.put(i) + await asyncio.sleep(0.1) + data_queue.is_closed = True + + await asyncio.gather(pipeline(data_queue), queue_producer()) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 1000743da..f90eece41 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -4,29 +4,30 @@ from ..tasks.Task import Task logger = logging.getLogger("run_tasks(tasks: [Task], data)") -async def run_tasks(tasks: [Task], data): +async def run_tasks(tasks: [Task], data = None): if len(tasks) == 0: yield data return + args = [data] if data is not None else [] + running_task = tasks[0] - batch_size = running_task.task_config["batch_size"] leftover_tasks = tasks[1:] next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None - # next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 + next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 if inspect.isasyncgenfunction(running_task.executable): - logger.info(f"Running async generator task: `{running_task.executable.__name__}`") + logger.info("Running async generator task: `%s`", running_task.executable.__name__) try: results = [] - async_iterator = running_task.run(data) + async_iterator = running_task.run(*args) async for partial_result in async_iterator: results.append(partial_result) - if len(results) == batch_size: - async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results): + if len(results) == next_task_batch_size: + async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results): yield result results = [] @@ -37,7 +38,7 @@ async def run_tasks(tasks: [Task], data): results = [] - logger.info(f"Finished async generator task: `{running_task.executable.__name__}`") + logger.info("Finished async generator task: `%s`", running_task.executable.__name__) except Exception as error: logger.error( "Error occurred while running async generator task: `%s`\n%s\n", @@ -48,15 +49,15 @@ async def run_tasks(tasks: [Task], data): raise error elif inspect.isgeneratorfunction(running_task.executable): - logger.info(f"Running generator task: `{running_task.executable.__name__}`") + logger.info("Running generator task: `%s`", running_task.executable.__name__) try: results = [] - for partial_result in running_task.run(data): + for partial_result in running_task.run(*args): results.append(partial_result) - if len(results) == batch_size: - async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results): + if len(results) == next_task_batch_size: + async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results): yield result results = [] @@ -67,7 
+68,7 @@ async def run_tasks(tasks: [Task], data): results = [] - logger.info(f"Running generator task: `{running_task.executable.__name__}`") + logger.info("Finished generator task: `%s`", running_task.executable.__name__) except Exception as error: logger.error( "Error occurred while running generator task: `%s`\n%s\n", @@ -78,13 +79,35 @@ async def run_tasks(tasks: [Task], data): raise error elif inspect.iscoroutinefunction(running_task.executable): - task_result = await running_task.run(data) + logger.info("Running coroutine task: `%s`", running_task.executable.__name__) + try: + task_result = await running_task.run(*args) - async for result in run_tasks(leftover_tasks, task_result): - yield result + async for result in run_tasks(leftover_tasks, task_result): + yield result + logger.info("Finished coroutine task: `%s`", running_task.executable.__name__) + except Exception as error: + logger.error( + "Error occurred while running coroutine task: `%s`\n%s\n", + running_task.executable.__name__, + str(error), + exc_info = True, + ) + elif inspect.isfunction(running_task.executable): - task_result = running_task.run(data) + logger.info("Running function task: `%s`", running_task.executable.__name__) + try: + task_result = running_task.run(*args) - async for result in run_tasks(leftover_tasks, task_result): - yield result + async for result in run_tasks(leftover_tasks, task_result): + yield result + + logger.info("Finished function task: `%s`", running_task.executable.__name__) + except Exception as error: + logger.error( + "Error occurred while running function task: `%s`\n%s\n", + running_task.executable.__name__, + str(error), + exc_info = True, + ) diff --git a/cognee/modules/search/graph/search_adjacent.py b/cognee/modules/search/graph/search_adjacent.py index a8eb82355..7295ebe76 100644 --- a/cognee/modules/search/graph/search_adjacent.py +++ b/cognee/modules/search/graph/search_adjacent.py @@ -1,21 +1,18 @@ -from typing import Union, Dict -import networkx as nx +import asyncio from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine -async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]: +async def search_adjacent(query: str) -> list[(str, str)]: """ Find the neighbours of a given node in the graph and return their ids and descriptions. Parameters: - - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. - - query (str): Unused in this implementation but could be used for future enhancements. - - other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node. + - query (str): The query string to filter nodes by. Returns: - - Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node. + - list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node. 
""" - node_id = other_param.get("node_id") if other_param else query + node_id = query if node_id is None: return {} @@ -23,16 +20,24 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: graph_engine = await get_graph_engine() exact_node = await graph_engine.extract_node(node_id) - if exact_node is not None and "id" in exact_node: - neighbours = await graph_engine.get_neighbours(exact_node["id"]) + + if exact_node is not None and "uuid" in exact_node: + neighbours = await graph_engine.get_neighbours(exact_node["uuid"]) else: vector_engine = get_vector_engine() - collection_name = "classification" - data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5) + results = await asyncio.gather( + vector_engine.search("entities", query_text = query, limit = 10), + vector_engine.search("classification", query_text = query, limit = 10), + ) + results = [*results[0], *results[1]] + relevant_results = [result for result in results if result.score < 0.5][:5] - if len(data_points) == 0: + if len(relevant_results) == 0: return [] - neighbours = await graph_engine.get_neighbours(data_points[0].id) + node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results]) + neighbours = [] + for neighbour_ids in node_neighbours: + neighbours.extend(neighbour_ids) - return [node["name"] for node in neighbours] + return neighbours diff --git a/cognee/modules/search/graph/search_cypher.py b/cognee/modules/search/graph/search_cypher.py index 10078db3e..39a09542a 100644 --- a/cognee/modules/search/graph/search_cypher.py +++ b/cognee/modules/search/graph/search_cypher.py @@ -1,18 +1,15 @@ -import networkx as nx -from typing import Union -from cognee.shared.data_models import GraphDBType -from cognee.infrastructure.databases.graph.config import get_graph_config +from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config -async def search_cypher(query:str, graph: Union[nx.Graph, any]): +async def search_cypher(query: str): """ Use a Cypher query to search the graph and return the results. """ graph_config = get_graph_config() if graph_config.graph_database_provider == "neo4j": - result = await graph.run(query) + graph_engine = await get_graph_engine() + result = await graph_engine.graph().run(query) return result - else: - raise ValueError("Unsupported graph engine type.") + raise ValueError("Unsupported search type for the used graph engine.") diff --git a/cognee/modules/search/graph/search_similarity.py b/cognee/modules/search/graph/search_similarity.py index d06bc897e..24bb17ba4 100644 --- a/cognee/modules/search/graph/search_similarity.py +++ b/cognee/modules/search/graph/search_similarity.py @@ -1,22 +1,17 @@ -from typing import Union, Dict -import networkx as nx from cognee.infrastructure.databases.vector import get_vector_engine -async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]: +async def search_similarity(query: str) -> list[str, str]: """ Parameters: - - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. - - query (str): The query string to filter nodes by, e.g., 'SUMMARY'. + - query (str): The query string to filter nodes by. Returns: - - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes. + - list(chunk): A list of objects providing information about the chunks related to query. 
""" vector_engine = get_vector_engine() similar_results = await vector_engine.search("chunks", query, limit = 5) - results = [{ - "text": result.payload["text"], - "chunk_id": result.payload["chunk_id"], - } for result in similar_results] + + results = [result.payload for result in similar_results] return results diff --git a/cognee/modules/search/graph/search_summary.py b/cognee/modules/search/graph/search_summary.py index a46e35968..b576c0bd4 100644 --- a/cognee/modules/search/graph/search_summary.py +++ b/cognee/modules/search/graph/search_summary.py @@ -1,24 +1,17 @@ -from typing import Union, Dict -import networkx as nx -from cognee.shared.data_models import ChunkSummaries from cognee.infrastructure.databases.vector import get_vector_engine -async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]: +async def search_summary(query: str) -> list: """ Parameters: - - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. - - query (str): The query string to filter nodes by, e.g., 'SUMMARY'. - - other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements. + - query (str): The query string to filter summaries by. Returns: - - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes. + - list[str, UUID]: A list of objects providing information about the summaries related to query. """ vector_engine = get_vector_engine() summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5) - summaries = [{ - "text": summary.payload["text"], - "chunk_id": summary.payload["chunk_id"], - } for summary in summaries_results] + + summaries = [summary.payload for summary in summaries_results] return summaries diff --git a/cognee/modules/search/vector/search_traverse.py b/cognee/modules/search/vector/search_traverse.py index 4d64771f2..5c1d07924 100644 --- a/cognee/modules/search/vector/search_traverse.py +++ b/cognee/modules/search/vector/search_traverse.py @@ -1,21 +1,36 @@ +import asyncio from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine -async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call +async def search_traverse(query: str): + node_id = query + rules = set() + graph_engine = await get_graph_engine() vector_engine = get_vector_engine() - results = await vector_engine.search("classification", query_text = query, limit = 10) + exact_node = await graph_engine.extract_node(node_id) - rules = [] + if exact_node is not None and "uuid" in exact_node: + edges = await graph_engine.get_edges(exact_node["uuid"]) - if len(results) > 0: - for result in results: - graph_node_id = result.id + for edge in edges: + rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}") + else: + results = await asyncio.gather( + vector_engine.search("entities", query_text = query, limit = 10), + vector_engine.search("classification", query_text = query, limit = 10), + ) + results = [*results[0], *results[1]] + relevant_results = [result for result in results if result.score < 0.5][:5] - edges = await graph_engine.get_edges(graph_node_id) + if len(relevant_results) > 0: + for result in relevant_results: + graph_node_id = result.id - for edge in edges: - rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}") + edges = await 
graph_engine.get_edges(graph_node_id) - return rules + for edge in edges: + rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}") + + return list(rules) diff --git a/cognee/pipelines.py b/cognee/pipelines.py new file mode 100644 index 000000000..7e5e791d4 --- /dev/null +++ b/cognee/pipelines.py @@ -0,0 +1,5 @@ +# Don't add any more code here, this file is used only for the purpose +# of enabling imports from `cognee.pipelines` module. +# `from cognee.pipelines import Task` for example. + +from .modules.pipelines import * diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index 98ea34290..0c54436d9 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -10,12 +10,14 @@ class Node(BaseModel): name: str type: str description: str + properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.") class Edge(BaseModel): """Edge in a knowledge graph.""" source_node_id: str target_node_id: str relationship_name: str + properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.") class KnowledgeGraph(BaseModel): """Knowledge graph.""" diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 2cb32db14..9580b5f5d 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -1,4 +1,3 @@ - import os import logging import pathlib @@ -38,21 +37,32 @@ async def main(): await cognee.cognify([dataset_name], root_node_id = "ROOT") - search_results = await cognee.search("TRAVERSE", { "query": "Text" }) + from cognee.infrastructure.databases.vector import get_vector_engine + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("entities", "AI"))[0] + random_node_name = random_node.payload["name"] + + search_results = await cognee.search("SIMILARITY", { "query": random_node_name }) assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("SUMMARY", { "query": "Work and computers" }) + search_results = await cognee.search("TRAVERSE", { "query": random_node_name }) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search("SUMMARY", { "query": random_node_name }) assert len(search_results) != 0, "Query related summaries don't exist." print("\n\nQuery related summaries exist:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("ADJACENT", { "query": "Articles" }) - assert len(search_results) != 0, "ROOT node has no neighbours." - print("\n\nROOT node has neighbours.\n") + search_results = await cognee.search("ADJACENT", { "query": random_node_name }) + assert len(search_results) != 0, "Large language model query found no neighbours." 
+ print("\n\Large language model query found neighbours.\n") for result in search_results: print(f"{result}\n") diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index c4a871abe..ddf06e8e3 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -33,21 +33,32 @@ async def main(): await cognee.cognify([dataset_name]) - search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" }) + from cognee.infrastructure.databases.vector import get_vector_engine + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("entities", "AI"))[0] + random_node_name = random_node.payload["name"] + + search_results = await cognee.search("SIMILARITY", { "query": random_node_name }) assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("SUMMARY", { "query": "Work and computers" }) + search_results = await cognee.search("TRAVERSE", { "query": random_node_name }) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search("SUMMARY", { "query": random_node_name }) assert len(search_results) != 0, "Query related summaries don't exist." print("\n\nQuery related summaries exist:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("ADJACENT", { "query": "ROOT" }) - assert len(search_results) != 0, "ROOT node has no neighbours." - print("\n\nROOT node has neighbours.\n") + search_results = await cognee.search("ADJACENT", { "query": random_node_name }) + assert len(search_results) != 0, "Large language model query found no neighbours." + print("\n\Large language model query found neighbours.\n") for result in search_results: print(f"{result}\n") diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 9ee7a0d7d..4ae2ba9aa 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -34,21 +34,32 @@ async def main(): await cognee.cognify([dataset_name]) - search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" }) + from cognee.infrastructure.databases.vector import get_vector_engine + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("entities", "AI"))[0] + random_node_name = random_node.payload["name"] + + search_results = await cognee.search("SIMILARITY", { "query": random_node_name }) assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("SUMMARY", { "query": "Work and computers" }) + search_results = await cognee.search("TRAVERSE", { "query": random_node_name }) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search("SUMMARY", { "query": random_node_name }) assert len(search_results) != 0, "Query related summaries don't exist." print("\n\nQuery related summaries exist:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("ADJACENT", { "query": "ROOT" }) - assert len(search_results) != 0, "ROOT node has no neighbours." 
- print("\n\nROOT node has neighbours.\n") + search_results = await cognee.search("ADJACENT", { "query": random_node_name }) + assert len(search_results) != 0, "Large language model query found no neighbours." + print("\n\Large language model query found neighbours.\n") for result in search_results: print(f"{result}\n") diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index 09d8ee62b..1b32b0085 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -32,21 +32,32 @@ async def main(): await cognee.cognify([dataset_name]) - search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" }) + from cognee.infrastructure.databases.vector import get_vector_engine + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("entities", "AI"))[0] + random_node_name = random_node.payload["name"] + + search_results = await cognee.search("SIMILARITY", { "query": random_node_name }) assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("SUMMARY", { "query": "Work and computers" }) + search_results = await cognee.search("TRAVERSE", { "query": random_node_name }) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search("SUMMARY", { "query": random_node_name }) assert len(search_results) != 0, "Query related summaries don't exist." print("\n\nQuery related summaries exist:\n") for result in search_results: print(f"{result}\n") - search_results = await cognee.search("ADJACENT", { "query": "ROOT" }) - assert len(search_results) != 0, "ROOT node has no neighbours." - print("\n\nROOT node has neighbours.\n") + search_results = await cognee.search("ADJACENT", { "query": random_node_name }) + assert len(search_results) != 0, "Large language model query found no neighbours." + print("\n\Large language model query found neighbours.\n") for result in search_results: print(f"{result}\n")