Merge remote-tracking branch 'origin/main' into COG-206
This commit is contained in:
commit
2717272403
29 changed files with 558 additions and 242 deletions
|
|
@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
|
||||||
from .api.v1.datasets.datasets import datasets
|
from .api.v1.datasets.datasets import datasets
|
||||||
from .api.v1.search.search import search, SearchType
|
from .api.v1.search.search import search, SearchType
|
||||||
from .api.v1.prune import prune
|
from .api.v1.prune import prune
|
||||||
|
|
||||||
|
# Pipelines
|
||||||
|
from .modules import pipelines
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import logging
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from typing import Dict, Any, List, Union, Optional, Literal
|
from typing import Dict, Any, List, Union, Optional, Literal
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
|
||||||
from fastapi.responses import JSONResponse, FileResponse
|
from fastapi.responses import JSONResponse, FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from fastapi_users import fastapi_users
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
|
|
||||||
from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
|
from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
|
||||||
get_user_permissions, get_async_session_context, fast_api_users_init
|
get_async_session_context, fast_api_users_init
|
||||||
# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
|
|
||||||
# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users
|
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
|
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
|
||||||
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
||||||
|
|
@ -62,8 +55,6 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
|
||||||
out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)
|
out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)
|
||||||
|
|
||||||
if out:
|
if out:
|
||||||
|
|
||||||
|
|
||||||
async with update_status_lock:
|
async with update_status_lock:
|
||||||
task_status = get_task_status([dataset_name])
|
task_status = get_task_status([dataset_name])
|
||||||
|
|
||||||
|
|
@ -89,9 +80,9 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
|
||||||
root_node_id = "ROOT"
|
root_node_id = "ROOT"
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||||
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
|
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
|
||||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||||
Task(
|
Task(
|
||||||
save_data_chunks,
|
save_data_chunks,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
|
||||||
from cognee.modules.search.vector.search_traverse import search_traverse
|
from cognee.modules.search.vector.search_traverse import search_traverse
|
||||||
from cognee.modules.search.graph.search_summary import search_summary
|
from cognee.modules.search.graph.search_summary import search_summary
|
||||||
from cognee.modules.search.graph.search_similarity import search_similarity
|
from cognee.modules.search.graph.search_similarity import search_similarity
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
|
|
||||||
class SearchType(Enum):
|
class SearchType(Enum):
|
||||||
|
|
@ -63,9 +62,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
|
||||||
|
|
||||||
|
|
||||||
async def specific_search(query_params: List[SearchParameters]) -> List:
|
async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
graph_client = await get_graph_engine()
|
|
||||||
graph = graph_client.graph
|
|
||||||
|
|
||||||
search_functions: Dict[SearchType, Callable] = {
|
search_functions: Dict[SearchType, Callable] = {
|
||||||
SearchType.ADJACENT: search_adjacent,
|
SearchType.ADJACENT: search_adjacent,
|
||||||
SearchType.SUMMARY: search_summary,
|
SearchType.SUMMARY: search_summary,
|
||||||
|
|
@ -81,7 +77,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
search_func = search_functions.get(search_param.search_type)
|
search_func = search_functions.get(search_param.search_type)
|
||||||
if search_func:
|
if search_func:
|
||||||
# Schedule the coroutine for execution and store the task
|
# Schedule the coroutine for execution and store the task
|
||||||
task = search_func(**search_param.params, graph = graph)
|
task = search_func(**search_param.params)
|
||||||
search_tasks.append(task)
|
search_tasks.append(task)
|
||||||
|
|
||||||
# Use asyncio.gather to run all scheduled tasks concurrently
|
# Use asyncio.gather to run all scheduled tasks concurrently
|
||||||
|
|
@ -92,7 +88,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
|
|
||||||
send_telemetry("cognee.search")
|
send_telemetry("cognee.search")
|
||||||
|
|
||||||
return results
|
return results[0] if len(results) == 1 else results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
async def graph(self):
|
async def graph(self):
|
||||||
return await self.get_session()
|
return await self.get_session()
|
||||||
|
|
||||||
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
results = self.query(
|
||||||
|
"""
|
||||||
|
MATCH (n)
|
||||||
|
WHERE n.id = $node_id
|
||||||
|
RETURN COUNT(n) > 0 AS node_exists
|
||||||
|
""",
|
||||||
|
{"node_id": node_id}
|
||||||
|
)
|
||||||
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
||||||
node_id = node_id.replace(":", "_")
|
node_id = node_id.replace(":", "_")
|
||||||
|
|
||||||
|
|
@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
|
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||||
|
query = f"""
|
||||||
|
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
|
||||||
|
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||||
|
"""
|
||||||
|
|
||||||
|
edge_exists = await self.query(query)
|
||||||
|
return edge_exists
|
||||||
|
|
||||||
|
async def has_edges(self, edges):
|
||||||
|
query = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (a)-[r]->(b)
|
||||||
|
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||||
|
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = {
|
||||||
|
"edges": [{
|
||||||
|
"from_node": edge[0],
|
||||||
|
"to_node": edge[1],
|
||||||
|
"relationship_name": edge[2],
|
||||||
|
} for edge in edges],
|
||||||
|
}
|
||||||
|
|
||||||
|
results = await self.query(query, params)
|
||||||
|
return [result["edge_exists"] for result in results]
|
||||||
|
except Neo4jError as error:
|
||||||
|
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
|
||||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
serialized_properties = self.serialize_properties(edge_properties)
|
||||||
from_node = from_node.replace(":", "_")
|
from_node = from_node.replace(":", "_")
|
||||||
|
|
@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
},
|
},
|
||||||
} for edge in edges]
|
} for edge in edges]
|
||||||
|
|
||||||
results = await self.query(query, dict(edges = edges))
|
try:
|
||||||
return results
|
results = await self.query(query, dict(edges = edges))
|
||||||
|
return results
|
||||||
|
except Neo4jError as error:
|
||||||
|
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||||
|
raise error
|
||||||
|
|
||||||
async def get_edges(self, node_id: str):
|
async def get_edges(self, node_id: str):
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||||
if edge_label is not None:
|
if edge_label is not None:
|
||||||
query = """
|
query = """
|
||||||
MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
|
MATCH (node)<-[r]-(predecessor)
|
||||||
RETURN predecessor.id AS id
|
WHERE node.id = $node_id AND type(r) = $edge_label
|
||||||
|
RETURN predecessor.id AS predecessor_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(
|
||||||
|
|
@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["id"] for result in results]
|
return [result["predecessor_id"] for result in results]
|
||||||
else:
|
else:
|
||||||
query = """
|
query = """
|
||||||
MATCH (node:`{node_id}`)-[r]->(predecessor)
|
MATCH (node)<-[r]-(predecessor)
|
||||||
RETURN predecessor.id AS id
|
WHERE node.id = $node_id
|
||||||
|
RETURN predecessor.id AS predecessor_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(
|
||||||
|
|
@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["id"] for result in results]
|
return [result["predecessor_id"] for result in results]
|
||||||
|
|
||||||
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||||
if edge_label is not None:
|
if edge_label is not None:
|
||||||
query = """
|
query = """
|
||||||
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
|
MATCH (node)-[r]->(successor)
|
||||||
RETURN successor.id AS id
|
WHERE node.id = $node_id AND type(r) = $edge_label
|
||||||
|
RETURN successor.id AS successor_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(
|
||||||
|
|
@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["id"] for result in results]
|
return [result["successor_id"] for result in results]
|
||||||
else:
|
else:
|
||||||
query = """
|
query = """
|
||||||
MATCH (node:`{node_id}`)<-[r]-(successor)
|
MATCH (node)-[r]->(successor)
|
||||||
RETURN successor.id AS id
|
WHERE node.id = $node_id
|
||||||
|
RETURN successor.id AS successor_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.query(
|
results = await self.query(
|
||||||
|
|
@ -318,12 +370,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["id"] for result in results]
|
return [result["successor_id"] for result in results]
|
||||||
|
|
||||||
async def get_neighbours(self, node_id: str) -> list[str]:
|
async def get_neighbours(self, node_id: str) -> list[str]:
|
||||||
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
|
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))
|
||||||
|
|
||||||
return [*results[0], *results[1]]
|
return [*predecessor_ids, *successor_ids]
|
||||||
|
|
||||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||||
query = f"""
|
query = f"""
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
|
||||||
|
|
||||||
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
return self.graph.has_node(node_id)
|
||||||
|
|
||||||
async def add_node(
|
async def add_node(
|
||||||
self,
|
self,
|
||||||
|
|
@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
async def get_graph(self):
|
async def get_graph(self):
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
|
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||||
|
return self.graph.has_edge(from_node, to_node, key = edge_label)
|
||||||
|
|
||||||
|
async def has_edges(self, edges):
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for (from_node, to_node, edge_label) in edges:
|
||||||
|
if await self.has_edge(from_node, to_node, edge_label):
|
||||||
|
result.append((from_node, to_node, edge_label))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: str,
|
from_node: str,
|
||||||
|
|
@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
if not self.graph.has_node(node_id):
|
if not self.graph.has_node(node_id):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
neighbour_ids = list(self.graph.neighbors(node_id))
|
predecessor_ids, successor_ids = await asyncio.gather(
|
||||||
|
self.get_predecessor_ids(node_id),
|
||||||
|
self.get_successor_ids(node_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
neighbour_ids = predecessor_ids + successor_ids
|
||||||
|
|
||||||
if len(neighbour_ids) == 0:
|
if len(neighbour_ids) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = result["id"],
|
id = result["id"],
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = 1,
|
score = 0,
|
||||||
) for result in results.to_dict("index").values()]
|
) for result in results.to_dict("index").values()]
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
|
|
@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: str = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: List[float] = None,
|
||||||
limit: int = 10,
|
limit: int = 5,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
):
|
):
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
|
|
@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||||
|
|
||||||
|
result_values = list(results.to_dict("index").values())
|
||||||
|
|
||||||
|
min_value = 100
|
||||||
|
max_value = 0
|
||||||
|
|
||||||
|
for result in result_values:
|
||||||
|
value = float(result["_distance"])
|
||||||
|
if value > max_value:
|
||||||
|
max_value = value
|
||||||
|
if value < min_value:
|
||||||
|
min_value = value
|
||||||
|
|
||||||
|
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = str(result["id"]),
|
id = str(result["id"]),
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = float(result["_distance"]),
|
score = normalized_values[value_index],
|
||||||
) for result in results.to_dict("index").values()]
|
) for value_index, result in enumerate(result_values)]
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
from uuid import UUID
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class ScoredResult(BaseModel):
|
class ScoredResult(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
score: float
|
score: float # Lower score is better
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
|
import logging
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
from ..models.DataPoint import DataPoint
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger("QDrantAdapter")
|
||||||
|
|
||||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||||
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
||||||
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
||||||
|
|
@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
points = [convert_to_qdrant_point(point) for point in data_points]
|
points = [convert_to_qdrant_point(point) for point in data_points]
|
||||||
|
|
||||||
result = await client.upload_points(
|
try:
|
||||||
collection_name = collection_name,
|
result = await client.upload_points(
|
||||||
points = points
|
collection_name = collection_name,
|
||||||
)
|
points = points
|
||||||
|
)
|
||||||
await client.close()
|
return result
|
||||||
|
except Exception as error:
|
||||||
return result
|
logger.error("Error uploading data points to Qdrant: %s", str(error))
|
||||||
|
raise error
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
client = self.get_qdrant_client()
|
client = self.get_qdrant_client()
|
||||||
|
|
@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: Optional[List[float]] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = None,
|
limit: int = 5,
|
||||||
with_vector: bool = False
|
with_vector: bool = False
|
||||||
):
|
):
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
from ..models.DataPoint import DataPoint
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger("WeaviateAdapter")
|
||||||
|
|
||||||
class WeaviateAdapter(VectorDBInterface):
|
class WeaviateAdapter(VectorDBInterface):
|
||||||
name = "Weaviate"
|
name = "Weaviate"
|
||||||
|
|
@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
vector = vector
|
vector = vector
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data_points = list(map(convert_to_weaviate_data_points, data_points))
|
||||||
objects = list(map(convert_to_weaviate_data_points, data_points))
|
|
||||||
|
|
||||||
collection = self.get_collection(collection_name)
|
collection = self.get_collection(collection_name)
|
||||||
|
|
||||||
with collection.batch.dynamic() as batch:
|
try:
|
||||||
for data_row in objects:
|
if len(data_points) > 1:
|
||||||
batch.add_object(
|
return collection.data.insert_many(data_points)
|
||||||
properties = data_row.properties,
|
else:
|
||||||
vector = data_row.vector
|
return collection.data.insert(data_points[0])
|
||||||
)
|
# with collection.batch.dynamic() as batch:
|
||||||
|
# for point in data_points:
|
||||||
return
|
# batch.add_object(
|
||||||
# return self.get_collection(collection_name).data.insert_many(objects)
|
# uuid = point.uuid,
|
||||||
|
# properties = point.properties,
|
||||||
|
# vector = point.vector
|
||||||
|
# )
|
||||||
|
except Exception as error:
|
||||||
|
logger.error("Error creating data points: %s", str(error))
|
||||||
|
raise error
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
from weaviate.classes.query import Filter
|
from weaviate.classes.query import Filter
|
||||||
|
|
|
||||||
|
|
@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f
|
||||||
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
|
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
|
||||||
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
|
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
|
||||||
|
|
||||||
The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
|
The aim is to achieve simplicity and clarity in the knowledge graph.
|
||||||
# 1. Labeling Nodes
|
# 1. Labeling Nodes
|
||||||
**Consistency**: Ensure you use basic or elementary types for node labels.
|
**Consistency**: Ensure you use basic or elementary types for node labels.
|
||||||
- For example, when you identify an entity representing a person, always label it as **"Person"**.
|
- For example, when you identify an entity representing a person, always label it as **"Person"**.
|
||||||
- Avoid using more specific terms like "Mathematician" or "Scientist".
|
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
|
||||||
- Don't use too generic terms like "Entity".
|
- Don't use too generic terms like "Entity".
|
||||||
**Node IDs**: Never utilize integers as node IDs.
|
**Node IDs**: Never utilize integers as node IDs.
|
||||||
- Node IDs should be names or human-readable identifiers found in the text.
|
- Node IDs should be names or human-readable identifiers found in the text.
|
||||||
# 2. Handling Numerical Data and Dates
|
# 2. Handling Numerical Data and Dates
|
||||||
- For example, when you identify an entity representing a date, always label it as **"Date"**.
|
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
|
||||||
- Extract the date in the format "YYYY-MM-DD"
|
- Extract the date in the format "YYYY-MM-DD"
|
||||||
- If not possible to extract the whole date, extract month or year, or both if available.
|
- If not possible to extract the whole date, extract month or year, or both if available.
|
||||||
- **Property Format**: Properties must be in a key-value format.
|
- **Property Format**: Properties must be in a key-value format.
|
||||||
|
|
@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a
|
||||||
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID.
|
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID.
|
||||||
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
|
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
|
||||||
# 4. Strict Compliance
|
# 4. Strict Compliance
|
||||||
Adhere to the rules strictly. Non-compliance will result in termination"""
|
Adhere to the rules strictly. Non-compliance will result in termination
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
class Keyword(BaseModel):
|
class Keyword(BaseModel):
|
||||||
id: str
|
uuid: str
|
||||||
text: str
|
text: str
|
||||||
chunk_id: str
|
chunk_id: str
|
||||||
document_id: str
|
document_id: str
|
||||||
|
|
@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
|
||||||
DataPoint[Keyword](
|
DataPoint[Keyword](
|
||||||
id = str(classification_type_id),
|
id = str(classification_type_id),
|
||||||
payload = Keyword.parse_obj({
|
payload = Keyword.parse_obj({
|
||||||
"id": str(classification_type_id),
|
"uuid": str(classification_type_id),
|
||||||
"text": classification_type_label,
|
"text": classification_type_label,
|
||||||
"chunk_id": str(data_chunk.chunk_id),
|
"chunk_id": str(data_chunk.chunk_id),
|
||||||
"document_id": str(data_chunk.document_id),
|
"document_id": str(data_chunk.document_id),
|
||||||
|
|
@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
|
||||||
DataPoint[Keyword](
|
DataPoint[Keyword](
|
||||||
id = str(classification_subtype_id),
|
id = str(classification_subtype_id),
|
||||||
payload = Keyword.parse_obj({
|
payload = Keyword.parse_obj({
|
||||||
"id": str(classification_subtype_id),
|
"uuid": str(classification_subtype_id),
|
||||||
"text": classification_subtype_label,
|
"text": classification_subtype_label,
|
||||||
"chunk_id": str(data_chunk.chunk_id),
|
"chunk_id": str(data_chunk.chunk_id),
|
||||||
"document_id": str(data_chunk.document_id),
|
"document_id": str(data_chunk.document_id),
|
||||||
|
|
@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
|
||||||
)
|
)
|
||||||
))
|
))
|
||||||
edges.append((
|
edges.append((
|
||||||
str(classification_type_id),
|
|
||||||
str(classification_subtype_id),
|
str(classification_subtype_id),
|
||||||
"contains",
|
str(classification_type_id),
|
||||||
|
"is_subtype_of",
|
||||||
dict(
|
dict(
|
||||||
relationship_name = "contains",
|
relationship_name = "contains",
|
||||||
source_node_id = str(classification_type_id),
|
source_node_id = str(classification_type_id),
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,77 @@
|
||||||
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from uuid import uuid5, NAMESPACE_OID
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||||
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
||||||
from .extract_knowledge_graph import extract_content_graph
|
from .extract_knowledge_graph import extract_content_graph
|
||||||
|
|
||||||
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
|
class EntityNode(BaseModel):
|
||||||
|
uuid: str
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
|
||||||
chunk_graphs = await asyncio.gather(
|
chunk_graphs = await asyncio.gather(
|
||||||
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
|
has_collection = await vector_engine.has_collection(collection_name)
|
||||||
graph_type_node_ids = list(set(type_ids))
|
|
||||||
graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
|
if not has_collection:
|
||||||
existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
|
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
|
||||||
|
|
||||||
|
processed_nodes = {}
|
||||||
|
type_node_edges = []
|
||||||
|
entity_node_edges = []
|
||||||
|
type_entity_edges = []
|
||||||
|
|
||||||
|
for (chunk_index, chunk) in enumerate(data_chunks):
|
||||||
|
chunk_graph = chunk_graphs[chunk_index]
|
||||||
|
for node in chunk_graph.nodes:
|
||||||
|
type_node_id = generate_node_id(node.type)
|
||||||
|
entity_node_id = generate_node_id(node.id)
|
||||||
|
|
||||||
|
if type_node_id not in processed_nodes:
|
||||||
|
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
|
||||||
|
processed_nodes[type_node_id] = True
|
||||||
|
|
||||||
|
if entity_node_id not in processed_nodes:
|
||||||
|
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
|
||||||
|
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
|
||||||
|
processed_nodes[entity_node_id] = True
|
||||||
|
|
||||||
|
graph_node_edges = [
|
||||||
|
(edge.source_node_id, edge.target_node_id, edge.relationship_name) \
|
||||||
|
for edge in chunk_graph.edges
|
||||||
|
]
|
||||||
|
|
||||||
|
existing_edges = await graph_engine.has_edges([
|
||||||
|
*type_node_edges,
|
||||||
|
*entity_node_edges,
|
||||||
|
*type_entity_edges,
|
||||||
|
*graph_node_edges,
|
||||||
|
])
|
||||||
|
|
||||||
|
existing_edges_map = {}
|
||||||
|
existing_nodes_map = {}
|
||||||
|
|
||||||
|
for edge in existing_edges:
|
||||||
|
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
|
||||||
|
existing_nodes_map[edge[0]] = True
|
||||||
|
|
||||||
graph_nodes = []
|
graph_nodes = []
|
||||||
graph_edges = []
|
graph_edges = []
|
||||||
|
data_points = []
|
||||||
|
|
||||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
for (chunk_index, chunk) in enumerate(data_chunks):
|
||||||
graph = chunk_graphs[chunk_index]
|
graph = chunk_graphs[chunk_index]
|
||||||
|
|
@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model:
|
||||||
|
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
node_id = generate_node_id(node.id)
|
node_id = generate_node_id(node.id)
|
||||||
|
node_name = generate_name(node.name)
|
||||||
|
|
||||||
graph_nodes.append((
|
type_node_id = generate_node_id(node.type)
|
||||||
node_id,
|
type_node_name = generate_name(node.type)
|
||||||
dict(
|
|
||||||
id = node_id,
|
if node_id not in existing_nodes_map:
|
||||||
chunk_id = str(chunk.chunk_id),
|
node_data = dict(
|
||||||
document_id = str(chunk.document_id),
|
uuid = node_id,
|
||||||
name = node.name,
|
name = node_name,
|
||||||
type = node.type.lower().capitalize(),
|
type = node_name,
|
||||||
description = node.description,
|
description = node.description,
|
||||||
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
)
|
)
|
||||||
))
|
|
||||||
|
|
||||||
graph_edges.append((
|
graph_nodes.append((
|
||||||
str(chunk.chunk_id),
|
node_id,
|
||||||
node_id,
|
dict(
|
||||||
"contains",
|
**node_data,
|
||||||
dict(
|
properties = json.dumps(node.properties),
|
||||||
relationship_name = "contains",
|
)
|
||||||
source_node_id = str(chunk.chunk_id),
|
))
|
||||||
target_node_id = node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
type_node_id = generate_node_id(node.type)
|
data_points.append(DataPoint[EntityNode](
|
||||||
|
id = str(uuid5(NAMESPACE_OID, node_id)),
|
||||||
|
payload = node_data,
|
||||||
|
embed_field = "name",
|
||||||
|
))
|
||||||
|
|
||||||
if type_node_id not in existing_type_nodes_map:
|
existing_nodes_map[node_id] = True
|
||||||
node_name = node.type.lower().capitalize()
|
|
||||||
|
|
||||||
type_node = dict(
|
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
|
||||||
id = type_node_id,
|
|
||||||
name = node_name,
|
if edge_key not in existing_edges_map:
|
||||||
type = node_name,
|
graph_edges.append((
|
||||||
|
str(chunk.chunk_id),
|
||||||
|
node_id,
|
||||||
|
"contains_entity",
|
||||||
|
dict(
|
||||||
|
relationship_name = "contains_entity",
|
||||||
|
source_node_id = str(chunk.chunk_id),
|
||||||
|
target_node_id = node_id,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
# Add relationship between entity type and entity itself: "Jake is Person"
|
||||||
|
graph_edges.append((
|
||||||
|
node_id,
|
||||||
|
type_node_id,
|
||||||
|
"is_entity_type",
|
||||||
|
dict(
|
||||||
|
relationship_name = "is_entity_type",
|
||||||
|
source_node_id = type_node_id,
|
||||||
|
target_node_id = node_id,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
existing_edges_map[edge_key] = True
|
||||||
|
|
||||||
|
if type_node_id not in existing_nodes_map:
|
||||||
|
type_node_data = dict(
|
||||||
|
uuid = type_node_id,
|
||||||
|
name = type_node_name,
|
||||||
|
type = type_node_id,
|
||||||
|
description = type_node_name,
|
||||||
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_nodes.append((type_node_id, type_node))
|
graph_nodes.append((type_node_id, dict(
|
||||||
existing_type_nodes_map[type_node_id] = type_node
|
**type_node_data,
|
||||||
|
properties = json.dumps(node.properties)
|
||||||
|
)))
|
||||||
|
|
||||||
graph_edges.append((
|
data_points.append(DataPoint[EntityNode](
|
||||||
str(chunk.chunk_id),
|
id = str(uuid5(NAMESPACE_OID, type_node_id)),
|
||||||
type_node_id,
|
payload = type_node_data,
|
||||||
"contains_entity_type",
|
embed_field = "name",
|
||||||
dict(
|
))
|
||||||
relationship_name = "contains_entity_type",
|
|
||||||
source_node_id = str(chunk.chunk_id),
|
|
||||||
target_node_id = type_node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add relationship between entity type and entity itself: "Jake is Person"
|
existing_nodes_map[type_node_id] = True
|
||||||
graph_edges.append((
|
|
||||||
type_node_id,
|
|
||||||
node_id,
|
|
||||||
"is_entity_type",
|
|
||||||
dict(
|
|
||||||
relationship_name = "is_entity_type",
|
|
||||||
source_node_id = type_node_id,
|
|
||||||
target_node_id = node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add relationship that came from graphs.
|
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
|
||||||
for edge in graph.edges:
|
|
||||||
|
if edge_key not in existing_edges_map:
|
||||||
graph_edges.append((
|
graph_edges.append((
|
||||||
generate_node_id(edge.source_node_id),
|
str(chunk.chunk_id),
|
||||||
generate_node_id(edge.target_node_id),
|
type_node_id,
|
||||||
edge.relationship_name,
|
"contains_entity_type",
|
||||||
dict(
|
dict(
|
||||||
relationship_name = edge.relationship_name,
|
relationship_name = "contains_entity_type",
|
||||||
source_node_id = generate_node_id(edge.source_node_id),
|
source_node_id = str(chunk.chunk_id),
|
||||||
target_node_id = generate_node_id(edge.target_node_id),
|
target_node_id = type_node_id,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
|
|
||||||
await graph_engine.add_nodes(graph_nodes)
|
existing_edges_map[edge_key] = True
|
||||||
|
|
||||||
await graph_engine.add_edges(graph_edges)
|
# Add relationship that came from graphs.
|
||||||
|
for edge in graph.edges:
|
||||||
|
source_node_id = generate_node_id(edge.source_node_id)
|
||||||
|
target_node_id = generate_node_id(edge.target_node_id)
|
||||||
|
relationship_name = generate_name(edge.relationship_name)
|
||||||
|
edge_key = source_node_id + target_node_id + relationship_name
|
||||||
|
|
||||||
|
if edge_key not in existing_edges_map:
|
||||||
|
graph_edges.append((
|
||||||
|
generate_node_id(edge.source_node_id),
|
||||||
|
generate_node_id(edge.target_node_id),
|
||||||
|
edge.relationship_name,
|
||||||
|
dict(
|
||||||
|
relationship_name = generate_name(edge.relationship_name),
|
||||||
|
source_node_id = generate_node_id(edge.source_node_id),
|
||||||
|
target_node_id = generate_node_id(edge.target_node_id),
|
||||||
|
properties = json.dumps(edge.properties),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
existing_edges_map[edge_key] = True
|
||||||
|
|
||||||
|
if len(data_points) > 0:
|
||||||
|
await vector_engine.create_data_points(collection_name, data_points)
|
||||||
|
|
||||||
|
if len(graph_nodes) > 0:
|
||||||
|
await graph_engine.add_nodes(graph_nodes)
|
||||||
|
|
||||||
|
if len(graph_edges) > 0:
|
||||||
|
await graph_engine.add_edges(graph_edges)
|
||||||
|
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def generate_name(name: str) -> str:
|
||||||
|
return name.lower().replace(" ", "_").replace("'", "")
|
||||||
|
|
||||||
def generate_node_id(node_id: str) -> str:
|
def generate_node_id(node_id: str) -> str:
|
||||||
return node_id.upper().replace(" ", "_").replace("'", "")
|
return node_id.lower().replace(" ", "_").replace("'", "")
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
|
from .tasks.Task import Task
|
||||||
from .operations.run_tasks import run_tasks
|
from .operations.run_tasks import run_tasks
|
||||||
from .operations.run_parallel import run_tasks_parallel
|
from .operations.run_parallel import run_tasks_parallel
|
||||||
|
|
|
||||||
|
|
@ -8,27 +8,29 @@ async def main():
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
yield i + 1
|
yield i + 1
|
||||||
|
|
||||||
async def add_one(num):
|
async def add_one(nums):
|
||||||
yield num + 1
|
|
||||||
|
|
||||||
async def multiply_by_two(nums):
|
|
||||||
for num in nums:
|
for num in nums:
|
||||||
yield num * 2
|
yield num + 1
|
||||||
|
|
||||||
async def add_one_to_batched_data(num):
|
async def multiply_by_two(num):
|
||||||
|
yield num * 2
|
||||||
|
|
||||||
|
async def add_one_single(num):
|
||||||
yield num + 1
|
yield num + 1
|
||||||
|
|
||||||
pipeline = run_tasks([
|
pipeline = run_tasks([
|
||||||
Task(number_generator, task_config = {"batch_size": 1}),
|
Task(number_generator),
|
||||||
Task(add_one, task_config = {"batch_size": 5}),
|
Task(add_one, task_config = {"batch_size": 5}),
|
||||||
Task(multiply_by_two, task_config = {"batch_size": 1}),
|
Task(multiply_by_two, task_config = {"batch_size": 1}),
|
||||||
Task(add_one_to_batched_data),
|
Task(add_one_single),
|
||||||
], 10)
|
], 10)
|
||||||
|
|
||||||
|
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
|
||||||
|
index = 0
|
||||||
async for result in pipeline:
|
async for result in pipeline:
|
||||||
print("\n")
|
|
||||||
print(result)
|
print(result)
|
||||||
print("\n")
|
assert result == results[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
import asyncio
|
||||||
|
from queue import Queue
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||||
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
|
|
||||||
|
async def pipeline(data_queue):
|
||||||
|
async def queue_consumer():
|
||||||
|
while not data_queue.is_closed:
|
||||||
|
if not data_queue.empty():
|
||||||
|
yield data_queue.get()
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
async def add_one(num):
|
||||||
|
yield num + 1
|
||||||
|
|
||||||
|
async def multiply_by_two(num):
|
||||||
|
yield num * 2
|
||||||
|
|
||||||
|
tasks_run = run_tasks([
|
||||||
|
Task(queue_consumer),
|
||||||
|
Task(add_one),
|
||||||
|
Task(multiply_by_two),
|
||||||
|
])
|
||||||
|
|
||||||
|
results = [2, 4, 6, 8, 10, 12, 14, 16, 18]
|
||||||
|
index = 0
|
||||||
|
async for result in tasks_run:
|
||||||
|
print(result)
|
||||||
|
assert result == results[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
data_queue = Queue()
|
||||||
|
data_queue.is_closed = False
|
||||||
|
|
||||||
|
async def queue_producer():
|
||||||
|
for i in range(0, 10):
|
||||||
|
data_queue.put(i)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
data_queue.is_closed = True
|
||||||
|
|
||||||
|
await asyncio.gather(pipeline(data_queue), queue_producer())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -4,29 +4,30 @@ from ..tasks.Task import Task
|
||||||
|
|
||||||
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
||||||
|
|
||||||
async def run_tasks(tasks: [Task], data):
|
async def run_tasks(tasks: [Task], data = None):
|
||||||
if len(tasks) == 0:
|
if len(tasks) == 0:
|
||||||
yield data
|
yield data
|
||||||
return
|
return
|
||||||
|
|
||||||
|
args = [data] if data is not None else []
|
||||||
|
|
||||||
running_task = tasks[0]
|
running_task = tasks[0]
|
||||||
batch_size = running_task.task_config["batch_size"]
|
|
||||||
leftover_tasks = tasks[1:]
|
leftover_tasks = tasks[1:]
|
||||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
||||||
# next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(running_task.executable):
|
if inspect.isasyncgenfunction(running_task.executable):
|
||||||
logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
|
logger.info("Running async generator task: `%s`", running_task.executable.__name__)
|
||||||
try:
|
try:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
async_iterator = running_task.run(data)
|
async_iterator = running_task.run(*args)
|
||||||
|
|
||||||
async for partial_result in async_iterator:
|
async for partial_result in async_iterator:
|
||||||
results.append(partial_result)
|
results.append(partial_result)
|
||||||
|
|
||||||
if len(results) == batch_size:
|
if len(results) == next_task_batch_size:
|
||||||
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
|
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -37,7 +38,7 @@ async def run_tasks(tasks: [Task], data):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
|
logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Error occurred while running async generator task: `%s`\n%s\n",
|
"Error occurred while running async generator task: `%s`\n%s\n",
|
||||||
|
|
@ -48,15 +49,15 @@ async def run_tasks(tasks: [Task], data):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
elif inspect.isgeneratorfunction(running_task.executable):
|
elif inspect.isgeneratorfunction(running_task.executable):
|
||||||
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
|
logger.info("Running generator task: `%s`", running_task.executable.__name__)
|
||||||
try:
|
try:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for partial_result in running_task.run(data):
|
for partial_result in running_task.run(*args):
|
||||||
results.append(partial_result)
|
results.append(partial_result)
|
||||||
|
|
||||||
if len(results) == batch_size:
|
if len(results) == next_task_batch_size:
|
||||||
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
|
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -67,7 +68,7 @@ async def run_tasks(tasks: [Task], data):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
|
logger.info("Finished generator task: `%s`", running_task.executable.__name__)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Error occurred while running generator task: `%s`\n%s\n",
|
"Error occurred while running generator task: `%s`\n%s\n",
|
||||||
|
|
@ -78,13 +79,35 @@ async def run_tasks(tasks: [Task], data):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
elif inspect.iscoroutinefunction(running_task.executable):
|
elif inspect.iscoroutinefunction(running_task.executable):
|
||||||
task_result = await running_task.run(data)
|
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
|
||||||
|
try:
|
||||||
|
task_result = await running_task.run(*args)
|
||||||
|
|
||||||
async for result in run_tasks(leftover_tasks, task_result):
|
async for result in run_tasks(leftover_tasks, task_result):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
|
||||||
|
except Exception as error:
|
||||||
|
logger.error(
|
||||||
|
"Error occurred while running coroutine task: `%s`\n%s\n",
|
||||||
|
running_task.executable.__name__,
|
||||||
|
str(error),
|
||||||
|
exc_info = True,
|
||||||
|
)
|
||||||
|
|
||||||
elif inspect.isfunction(running_task.executable):
|
elif inspect.isfunction(running_task.executable):
|
||||||
task_result = running_task.run(data)
|
logger.info("Running function task: `%s`", running_task.executable.__name__)
|
||||||
|
try:
|
||||||
|
task_result = running_task.run(*args)
|
||||||
|
|
||||||
async for result in run_tasks(leftover_tasks, task_result):
|
async for result in run_tasks(leftover_tasks, task_result):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
logger.info("Finished function task: `%s`", running_task.executable.__name__)
|
||||||
|
except Exception as error:
|
||||||
|
logger.error(
|
||||||
|
"Error occurred while running function task: `%s`\n%s\n",
|
||||||
|
running_task.executable.__name__,
|
||||||
|
str(error),
|
||||||
|
exc_info = True,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,18 @@
|
||||||
from typing import Union, Dict
|
import asyncio
|
||||||
import networkx as nx
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
|
async def search_adjacent(query: str) -> list[(str, str)]:
|
||||||
"""
|
"""
|
||||||
Find the neighbours of a given node in the graph and return their ids and descriptions.
|
Find the neighbours of a given node in the graph and return their ids and descriptions.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
|
- query (str): The query string to filter nodes by.
|
||||||
- query (str): Unused in this implementation but could be used for future enhancements.
|
|
||||||
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
|
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
|
||||||
"""
|
"""
|
||||||
node_id = other_param.get("node_id") if other_param else query
|
node_id = query
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -23,16 +20,24 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
exact_node = await graph_engine.extract_node(node_id)
|
exact_node = await graph_engine.extract_node(node_id)
|
||||||
if exact_node is not None and "id" in exact_node:
|
|
||||||
neighbours = await graph_engine.get_neighbours(exact_node["id"])
|
if exact_node is not None and "uuid" in exact_node:
|
||||||
|
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
|
||||||
else:
|
else:
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
collection_name = "classification"
|
results = await asyncio.gather(
|
||||||
data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5)
|
vector_engine.search("entities", query_text = query, limit = 10),
|
||||||
|
vector_engine.search("classification", query_text = query, limit = 10),
|
||||||
|
)
|
||||||
|
results = [*results[0], *results[1]]
|
||||||
|
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||||
|
|
||||||
if len(data_points) == 0:
|
if len(relevant_results) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
neighbours = await graph_engine.get_neighbours(data_points[0].id)
|
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
|
||||||
|
neighbours = []
|
||||||
|
for neighbour_ids in node_neighbours:
|
||||||
|
neighbours.extend(neighbour_ids)
|
||||||
|
|
||||||
return [node["name"] for node in neighbours]
|
return neighbours
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,15 @@
|
||||||
|
|
||||||
import networkx as nx
|
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
|
||||||
from typing import Union
|
|
||||||
from cognee.shared.data_models import GraphDBType
|
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
|
||||||
|
|
||||||
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
|
async def search_cypher(query: str):
|
||||||
"""
|
"""
|
||||||
Use a Cypher query to search the graph and return the results.
|
Use a Cypher query to search the graph and return the results.
|
||||||
"""
|
"""
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
|
|
||||||
if graph_config.graph_database_provider == "neo4j":
|
if graph_config.graph_database_provider == "neo4j":
|
||||||
result = await graph.run(query)
|
graph_engine = await get_graph_engine()
|
||||||
|
result = await graph_engine.graph().run(query)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported graph engine type.")
|
raise ValueError("Unsupported search type for the used graph engine.")
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,17 @@
|
||||||
from typing import Union, Dict
|
|
||||||
import networkx as nx
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
|
async def search_similarity(query: str) -> list[str, str]:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
|
- query (str): The query string to filter nodes by.
|
||||||
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes.
|
- list(chunk): A list of objects providing information about the chunks related to query.
|
||||||
"""
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
similar_results = await vector_engine.search("chunks", query, limit = 5)
|
similar_results = await vector_engine.search("chunks", query, limit = 5)
|
||||||
results = [{
|
|
||||||
"text": result.payload["text"],
|
results = [result.payload for result in similar_results]
|
||||||
"chunk_id": result.payload["chunk_id"],
|
|
||||||
} for result in similar_results]
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,17 @@
|
||||||
from typing import Union, Dict
|
|
||||||
import networkx as nx
|
|
||||||
from cognee.shared.data_models import ChunkSummaries
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
|
async def search_summary(query: str) -> list:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
|
- query (str): The query string to filter summaries by.
|
||||||
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
|
|
||||||
- other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
|
- list[str, UUID]: A list of objects providing information about the summaries related to query.
|
||||||
"""
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)
|
summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)
|
||||||
summaries = [{
|
|
||||||
"text": summary.payload["text"],
|
summaries = [summary.payload for summary in summaries_results]
|
||||||
"chunk_id": summary.payload["chunk_id"],
|
|
||||||
} for summary in summaries_results]
|
|
||||||
|
|
||||||
return summaries
|
return summaries
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,36 @@
|
||||||
|
import asyncio
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call
|
async def search_traverse(query: str):
|
||||||
|
node_id = query
|
||||||
|
rules = set()
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
results = await vector_engine.search("classification", query_text = query, limit = 10)
|
exact_node = await graph_engine.extract_node(node_id)
|
||||||
|
|
||||||
rules = []
|
if exact_node is not None and "uuid" in exact_node:
|
||||||
|
edges = await graph_engine.get_edges(exact_node["uuid"])
|
||||||
|
|
||||||
if len(results) > 0:
|
for edge in edges:
|
||||||
for result in results:
|
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
||||||
graph_node_id = result.id
|
else:
|
||||||
|
results = await asyncio.gather(
|
||||||
|
vector_engine.search("entities", query_text = query, limit = 10),
|
||||||
|
vector_engine.search("classification", query_text = query, limit = 10),
|
||||||
|
)
|
||||||
|
results = [*results[0], *results[1]]
|
||||||
|
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||||
|
|
||||||
edges = await graph_engine.get_edges(graph_node_id)
|
if len(relevant_results) > 0:
|
||||||
|
for result in relevant_results:
|
||||||
|
graph_node_id = result.id
|
||||||
|
|
||||||
for edge in edges:
|
edges = await graph_engine.get_edges(graph_node_id)
|
||||||
rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
|
||||||
|
|
||||||
return rules
|
for edge in edges:
|
||||||
|
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
||||||
|
|
||||||
|
return list(rules)
|
||||||
|
|
|
||||||
5
cognee/pipelines.py
Normal file
5
cognee/pipelines.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Don't add any more code here, this file is used only for the purpose
|
||||||
|
# of enabling imports from `cognee.pipelines` module.
|
||||||
|
# `from cognee.pipelines import Task` for example.
|
||||||
|
|
||||||
|
from .modules.pipelines import *
|
||||||
|
|
@ -10,12 +10,14 @@ class Node(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
description: str
|
description: str
|
||||||
|
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.")
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
"""Edge in a knowledge graph."""
|
"""Edge in a knowledge graph."""
|
||||||
source_node_id: str
|
source_node_id: str
|
||||||
target_node_id: str
|
target_node_id: str
|
||||||
relationship_name: str
|
relationship_name: str
|
||||||
|
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.")
|
||||||
|
|
||||||
class KnowledgeGraph(BaseModel):
|
class KnowledgeGraph(BaseModel):
|
||||||
"""Knowledge graph."""
|
"""Knowledge graph."""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
@ -38,21 +37,32 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name], root_node_id = "ROOT")
|
await cognee.cognify([dataset_name], root_node_id = "ROOT")
|
||||||
|
|
||||||
search_results = await cognee.search("TRAVERSE", { "query": "Text" })
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
random_node = (await vector_engine.search("entities", "AI"))[0]
|
||||||
|
random_node_name = random_node.payload["name"]
|
||||||
|
|
||||||
|
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
|
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
|
||||||
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
print("\n\nExtracted sentences are:\n")
|
||||||
|
for result in search_results:
|
||||||
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||||
print("\n\nQuery related summaries exist:\n")
|
print("\n\nQuery related summaries exist:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("ADJACENT", { "query": "Articles" })
|
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "ROOT node has no neighbours."
|
assert len(search_results) != 0, "Large language model query found no neighbours."
|
||||||
print("\n\nROOT node has neighbours.\n")
|
print("\n\nLarge language model query found neighbours.\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,21 +33,32 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
random_node = (await vector_engine.search("entities", "AI"))[0]
|
||||||
|
random_node_name = random_node.payload["name"]
|
||||||
|
|
||||||
|
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
|
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
|
||||||
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
print("\n\nExtracted sentences are:\n")
|
||||||
|
for result in search_results:
|
||||||
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||||
print("\n\nQuery related summaries exist:\n")
|
print("\n\nQuery related summaries exist:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
|
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "ROOT node has no neighbours."
|
assert len(search_results) != 0, "Large language model query found no neighbours."
|
||||||
print("\n\nROOT node has neighbours.\n")
|
print("\n\nLarge language model query found neighbours.\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,21 +34,32 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
random_node = (await vector_engine.search("entities", "AI"))[0]
|
||||||
|
random_node_name = random_node.payload["name"]
|
||||||
|
|
||||||
|
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
|
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
|
||||||
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
print("\n\nExtracted sentences are:\n")
|
||||||
|
for result in search_results:
|
||||||
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||||
print("\n\nQuery related summaries exist:\n")
|
print("\n\nQuery related summaries exist:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
|
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "ROOT node has no neighbours."
|
assert len(search_results) != 0, "Large language model query found no neighbours."
|
||||||
print("\n\nROOT node has neighbours.\n")
|
print("\n\nLarge language model query found neighbours.\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,21 +32,32 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
random_node = (await vector_engine.search("entities", "AI"))[0]
|
||||||
|
random_node_name = random_node.payload["name"]
|
||||||
|
|
||||||
|
search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
|
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
|
||||||
|
assert len(search_results) != 0, "The search results list is empty."
|
||||||
|
print("\n\nExtracted sentences are:\n")
|
||||||
|
for result in search_results:
|
||||||
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
search_results = await cognee.search("SUMMARY", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||||
print("\n\nQuery related summaries exist:\n")
|
print("\n\nQuery related summaries exist:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
|
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
|
||||||
assert len(search_results) != 0, "ROOT node has no neighbours."
|
assert len(search_results) != 0, "Large language model query found no neighbours."
|
||||||
print("\n\nROOT node has neighbours.\n")
|
print("\n\nLarge language model query found neighbours.\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue