Merge remote-tracking branch 'origin/main' into COG-206
Commit 2717272403
29 changed files with 558 additions and 242 deletions

@@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
from .api.v1.datasets.datasets import datasets
from .api.v1.search.search import search, SearchType
from .api.v1.prune import prune

# Pipelines
from .modules import pipelines

@@ -8,7 +8,7 @@ import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

@@ -1,18 +1,11 @@
import asyncio
import hashlib
import logging
import uuid
from typing import Union

from fastapi_users import fastapi_users
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
get_user_permissions, get_async_session_context, fast_api_users_init
# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users
get_async_session_context, fast_api_users_init
from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument

@@ -62,8 +55,6 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
out = await has_permission_document(active_user.current_user(active=True), file["id"], "write", session)

if out:

async with update_status_lock:
task_status = get_task_status([dataset_name])

@@ -89,9 +80,9 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
root_node_id = "ROOT"

tasks = [
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(
save_data_chunks,

@@ -11,7 +11,6 @@ from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.utils import send_telemetry

class SearchType(Enum):

@@ -63,9 +62,6 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:

async def specific_search(query_params: List[SearchParameters]) -> List:
graph_client = await get_graph_engine()
graph = graph_client.graph

search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SUMMARY: search_summary,

@@ -81,7 +77,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
task = search_func(**search_param.params, graph = graph)
task = search_func(**search_param.params)
search_tasks.append(task)

# Use asyncio.gather to run all scheduled tasks concurrently

@@ -92,7 +88,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:

send_telemetry("cognee.search")

return results
return results[0] if len(results) == 1 else results

@@ -50,6 +50,17 @@ class Neo4jAdapter(GraphDBInterface):
async def graph(self):
return await self.get_session()

async def has_node(self, node_id: str) -> bool:
results = self.query(
"""
MATCH (n)
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
{"node_id": node_id}
)
return results[0]["node_exists"] if len(results) > 0 else False

async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
node_id = node_id.replace(":", "_")

@@ -157,6 +168,39 @@ class Neo4jAdapter(GraphDBInterface):

return await self.query(query, params)

async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
RETURN COUNT(relationship) > 0 AS edge_exists
"""

edge_exists = await self.query(query)
return edge_exists

async def has_edges(self, edges):
query = """
UNWIND $edges AS edge
MATCH (a)-[r]->(b)
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
"""

try:
params = {
"edges": [{
"from_node": edge[0],
"to_node": edge[1],
"relationship_name": edge[2],
} for edge in edges],
}

results = await self.query(query, params)
return [result["edge_exists"] for result in results]
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error

async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")

@@ -198,8 +242,12 @@ class Neo4jAdapter(GraphDBInterface):
},
} for edge in edges]

results = await self.query(query, dict(edges = edges))
return results
try:
results = await self.query(query, dict(edges = edges))
return results
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error

async def get_edges(self, node_id: str):
query = """

@@ -261,8 +309,9 @@ class Neo4jAdapter(GraphDBInterface):
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN predecessor.id AS predecessor_id
"""

results = await self.query(

@@ -273,11 +322,12 @@ class Neo4jAdapter(GraphDBInterface):
)
)

return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)-[r]->(predecessor)
RETURN predecessor.id AS id
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id
RETURN predecessor.id AS predecessor_id
"""

results = await self.query(

@@ -287,13 +337,14 @@ class Neo4jAdapter(GraphDBInterface):
)
)

return [result["id"] for result in results]
return [result["predecessor_id"] for result in results]

async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id AND type(r) = $edge_label
RETURN successor.id AS successor_id
"""

results = await self.query(

@@ -304,11 +355,12 @@ class Neo4jAdapter(GraphDBInterface):
),
)

return [result["id"] for result in results]
return [result["successor_id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)<-[r]-(successor)
RETURN successor.id AS id
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id
RETURN successor.id AS successor_id
"""

results = await self.query(

@@ -318,12 +370,12 @@ class Neo4jAdapter(GraphDBInterface):
)
)

return [result["id"] for result in results]
return [result["successor_id"] for result in results]

async def get_neighbours(self, node_id: str) -> list[str]:
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))

return [*results[0], *results[1]]
return [*predecessor_ids, *successor_ids]

async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
query = f"""

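For orientation, here is a small sketch (not part of the commit) of the parameter payload the new has_edges query unwinds; the edge triples are invented sample values that mirror the (from_node, to_node, relationship_name) tuples built by expand_knowledge_graph later in this diff.

# Illustrative only: the triples below are made-up sample values.
edges = [
    ("chunk_1", "entity_jake", "contains_entity"),
    ("entity_jake", "person", "is_entity_type"),
]

params = {
    "edges": [{
        "from_node": from_node,
        "to_node": to_node,
        "relationship_name": relationship_name,
    } for (from_node, to_node, relationship_name) in edges],
}
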
@@ -2,6 +2,7 @@

import os
import json
import asyncio
import logging
from typing import Dict, Any, List
import aiofiles

@@ -25,6 +26,8 @@ class NetworkXAdapter(GraphDBInterface):
self.filename = filename

async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id)

async def add_node(
self,

@@ -45,6 +48,18 @@ class NetworkXAdapter(GraphDBInterface):
async def get_graph(self):
return self.graph

async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label)

async def has_edges(self, edges):
result = []

for (from_node, to_node, edge_label) in edges:
if await self.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label))

return result

async def add_edge(
self,
from_node: str,

@@ -154,7 +169,12 @@ class NetworkXAdapter(GraphDBInterface):
if not self.graph.has_node(node_id):
return []

neighbour_ids = list(self.graph.neighbors(node_id))
predecessor_ids, successor_ids = await asyncio.gather(
self.get_predecessor_ids(node_id),
self.get_successor_ids(node_id),
)

neighbour_ids = predecessor_ids + successor_ids

if len(neighbour_ids) == 0:
return []

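As a standalone illustration (not part of the commit) of the new has_edge/has_edges checks: the key-based lookup behaves as below on a plain networkx MultiDiGraph, which the key argument suggests the adapter wraps; the node and edge names are invented.

import networkx as nx

graph = nx.MultiDiGraph()
graph.add_edge("chunk_1", "entity_jake", key = "contains_entity")

def has_edge(from_node, to_node, edge_label):
    # same check the adapter performs, with the relationship name used as the edge key
    return graph.has_edge(from_node, to_node, key = edge_label)

edges = [
    ("chunk_1", "entity_jake", "contains_entity"),
    ("chunk_1", "entity_jake", "is_entity_type"),
]
existing = [edge for edge in edges if has_edge(*edge)]
assert existing == [("chunk_1", "entity_jake", "contains_entity")]
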
@@ -101,7 +101,7 @@ class LanceDBAdapter(VectorDBInterface):
return [ScoredResult(
id = result["id"],
payload = result["payload"],
score = 1,
score = 0,
) for result in results.to_dict("index").values()]

async def search(

@@ -109,7 +109,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 10,
limit: int = 5,
with_vector: bool = False,
):
if query_text is None and query_vector is None:

@@ -123,11 +123,25 @@ class LanceDBAdapter(VectorDBInterface):

results = await collection.vector_search(query_vector).limit(limit).to_pandas()

result_values = list(results.to_dict("index").values())

min_value = 100
max_value = 0

for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value

normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]

return [ScoredResult(
id = str(result["id"]),
payload = result["payload"],
score = float(result["_distance"]),
) for result in results.to_dict("index").values()]
score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)]

async def batch_search(
self,

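To make the new scoring explicit: the results are rescaled with a min-max normalization of the _distance values, so the closest hit gets score 0 and the farthest gets 1, keeping "lower is better" (see the ScoredResult comment in the next hunk). A standalone sketch of that rescaling follows; the guard for identical distances is added for the sketch only and is not in the diff.

def normalize_distances(distances: list[float]) -> list[float]:
    # min-max rescaling: closest hit -> 0.0, farthest hit -> 1.0
    min_value, max_value = min(distances), max(distances)
    if max_value == min_value:  # guard added for the sketch only
        return [0.0 for _ in distances]
    return [(value - min_value) / (max_value - min_value) for value in distances]

assert normalize_distances([2.0, 4.0, 6.0]) == [0.0, 0.5, 1.0]
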
@@ -1,8 +1,7 @@
from uuid import UUID
from typing import Any, Dict
from pydantic import BaseModel

class ScoredResult(BaseModel):
id: str
score: float
score: float # Lower score is better
payload: Dict[str, Any]

@@ -1,9 +1,12 @@
import logging
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine

logger = logging.getLogger("QDrantAdapter")

# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")

@@ -102,14 +105,17 @@ class QDrantAdapter(VectorDBInterface):

points = [convert_to_qdrant_point(point) for point in data_points]

result = await client.upload_points(
collection_name = collection_name,
points = points
)

await client.close()

return result
try:
result = await client.upload_points(
collection_name = collection_name,
points = points
)
return result
except Exception as error:
logger.error("Error uploading data points to Qdrant: %s", str(error))
raise error
finally:
await client.close()

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client()

@@ -122,7 +128,7 @@ class QDrantAdapter(VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
limit: int = 5,
with_vector: bool = False
):
if query_text is None and query_vector is None:

@@ -1,10 +1,12 @@
import asyncio
import logging
from typing import List, Optional
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine

logger = logging.getLogger("WeaviateAdapter")

class WeaviateAdapter(VectorDBInterface):
name = "Weaviate"

@@ -78,20 +80,25 @@ class WeaviateAdapter(VectorDBInterface):
vector = vector
)

objects = list(map(convert_to_weaviate_data_points, data_points))
data_points = list(map(convert_to_weaviate_data_points, data_points))

collection = self.get_collection(collection_name)

with collection.batch.dynamic() as batch:
for data_row in objects:
batch.add_object(
properties = data_row.properties,
vector = data_row.vector
)

return
# return self.get_collection(collection_name).data.insert_many(objects)
try:
if len(data_points) > 1:
return collection.data.insert_many(data_points)
else:
return collection.data.insert(data_points[0])
# with collection.batch.dynamic() as batch:
# for point in data_points:
# batch.add_object(
# uuid = point.uuid,
# properties = point.properties,
# vector = point.vector
# )
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter

@@ -2,16 +2,16 @@ You are a top-tier algorithm designed for extracting information in structured f
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.

The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
The aim is to achieve simplicity and clarity in the knowledge graph.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist".
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, always label it as **"Date"**.
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.

@@ -23,4 +23,4 @@ The aim is to achieve simplicity and clarity in the knowledge graph, making it a
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination"""
Adhere to the rules strictly. Non-compliance will result in termination

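To ground the prompt rules above, here is a small hand-written graph that follows them, expressed with the Node/Edge/KnowledgeGraph models touched later in this diff. The field values are invented, and the id field and the nodes/edges containers are assumptions (they sit outside the hunk shown later), so treat this as a sketch rather than the models' exact API.

from cognee.shared.data_models import Node, Edge, KnowledgeGraph

graph = KnowledgeGraph(
    nodes = [
        # basic label "Person"; the profession goes into properties instead of the label
        Node(id = "John Doe", name = "John Doe", type = "Person",
             description = "Person mentioned in the text",
             properties = {"profession": "Mathematician"}),
        # dates get type "Date" and the YYYY-MM-DD format
        Node(id = "1990-05-01", name = "1990-05-01", type = "Date",
             description = "Date mentioned in the text"),
    ],
    edges = [
        Edge(source_node_id = "John Doe", target_node_id = "1990-05-01",
             relationship_name = "born_on"),
    ],
)
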
@@ -29,7 +29,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
vector_engine = get_vector_engine()

class Keyword(BaseModel):
id: str
uuid: str
text: str
chunk_id: str
document_id: str

@@ -61,7 +61,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_type_id),
payload = Keyword.parse_obj({
"id": str(classification_type_id),
"uuid": str(classification_type_id),
"text": classification_type_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),

@@ -100,7 +100,7 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
DataPoint[Keyword](
id = str(classification_subtype_id),
payload = Keyword.parse_obj({
"id": str(classification_subtype_id),
"uuid": str(classification_subtype_id),
"text": classification_subtype_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),

@@ -118,9 +118,9 @@ async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_
)
))
edges.append((
str(classification_type_id),
str(classification_subtype_id),
"contains",
str(classification_type_id),
"is_subtype_of",
dict(
relationship_name = "contains",
source_node_id = str(classification_type_id),

@@ -1,25 +1,77 @@
import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .extract_knowledge_graph import extract_content_graph

async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
class EntityNode(BaseModel):
uuid: str
name: str
type: str
description: str
created_at: datetime
updated_at: datetime

async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)

vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()

type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
graph_type_node_ids = list(set(type_ids))
graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
has_collection = await vector_engine.has_collection(collection_name)

if not has_collection:
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)

processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []

for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)

if type_node_id not in processed_nodes:
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
processed_nodes[type_node_id] = True

if entity_node_id not in processed_nodes:
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
processed_nodes[entity_node_id] = True

graph_node_edges = [
(edge.source_node_id, edge.target_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]

existing_edges = await graph_engine.has_edges([
*type_node_edges,
*entity_node_edges,
*type_entity_edges,
*graph_node_edges,
])

existing_edges_map = {}
existing_nodes_map = {}

for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
existing_nodes_map[edge[0]] = True

graph_nodes = []
graph_edges = []
data_points = []

for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]

@@ -28,90 +80,139 @@ async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model:

for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_name(node.name)

graph_nodes.append((
node_id,
dict(
id = node_id,
chunk_id = str(chunk.chunk_id),
document_id = str(chunk.document_id),
name = node.name,
type = node.type.lower().capitalize(),
type_node_id = generate_node_id(node.type)
type_node_name = generate_name(node.type)

if node_id not in existing_nodes_map:
node_data = dict(
uuid = node_id,
name = node_name,
type = node_name,
description = node.description,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
))

graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains",
dict(
relationship_name = "contains",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
graph_nodes.append((
node_id,
dict(
**node_data,
properties = json.dumps(node.properties),
)
))

type_node_id = generate_node_id(node.type)
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, node_id)),
payload = node_data,
embed_field = "name",
))

if type_node_id not in existing_type_nodes_map:
node_name = node.type.lower().capitalize()
existing_nodes_map[node_id] = True

type_node = dict(
id = type_node_id,
name = node_name,
type = node_name,
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"

if edge_key not in existing_edges_map:
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains_entity",
dict(
relationship_name = "contains_entity",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))

# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
node_id,
type_node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))

existing_edges_map[edge_key] = True

if type_node_id not in existing_nodes_map:
type_node_data = dict(
uuid = type_node_id,
name = type_node_name,
type = type_node_id,
description = type_node_name,
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)

graph_nodes.append((type_node_id, type_node))
existing_type_nodes_map[type_node_id] = type_node
graph_nodes.append((type_node_id, dict(
**type_node_data,
properties = json.dumps(node.properties)
)))

graph_edges.append((
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, type_node_id)),
payload = type_node_data,
embed_field = "name",
))

# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
type_node_id,
node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_nodes_map[type_node_id] = True

# Add relationship that came from graphs.
for edge in graph.edges:
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"

if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = edge.relationship_name,
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))

await graph_engine.add_nodes(graph_nodes)
existing_edges_map[edge_key] = True

await graph_engine.add_edges(graph_edges)
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_name(edge.relationship_name)
edge_key = source_node_id + target_node_id + relationship_name

if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
dict(
relationship_name = generate_name(edge.relationship_name),
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
properties = json.dumps(edge.properties),
),
))
existing_edges_map[edge_key] = True

if len(data_points) > 0:
await vector_engine.create_data_points(collection_name, data_points)

if len(graph_nodes) > 0:
await graph_engine.add_nodes(graph_nodes)

if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)

return data_chunks

def generate_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")

def generate_node_id(node_id: str) -> str:
return node_id.upper().replace(" ", "_").replace("'", "")
return node_id.lower().replace(" ", "_").replace("'", "")

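For quick reference, the two id/name helpers behave as below after this change; this is a standalone copy with invented sample strings, and it highlights the switch from upper-casing to lower-casing node ids.

def generate_name(name: str) -> str:
    return name.lower().replace(" ", "_").replace("'", "")

def generate_node_id(node_id: str) -> str:
    # new behaviour in this commit: lowercase instead of uppercase ids
    return node_id.lower().replace(" ", "_").replace("'", "")

assert generate_node_id("Large Language Model") == "large_language_model"
assert generate_name("John's Laptop") == "johns_laptop"
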
@@ -1,2 +1,3 @@
from .tasks.Task import Task
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel

@@ -8,27 +8,29 @@ async def main():
for i in range(num):
yield i + 1

async def add_one(num):
yield num + 1

async def multiply_by_two(nums):
async def add_one(nums):
for num in nums:
yield num * 2
yield num + 1

async def add_one_to_batched_data(num):
async def multiply_by_two(num):
yield num * 2

async def add_one_single(num):
yield num + 1

pipeline = run_tasks([
Task(number_generator, task_config = {"batch_size": 1}),
Task(number_generator),
Task(add_one, task_config = {"batch_size": 5}),
Task(multiply_by_two, task_config = {"batch_size": 1}),
Task(add_one_to_batched_data),
Task(add_one_single),
], 10)

results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
print("\n")
print(result)
print("\n")
assert result == results[index]
index += 1

if __name__ == "__main__":
asyncio.run(main())

@@ -0,0 +1,46 @@
import asyncio
from queue import Queue
from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.Task import Task

async def pipeline(data_queue):
async def queue_consumer():
while not data_queue.is_closed:
if not data_queue.empty():
yield data_queue.get()
else:
await asyncio.sleep(0.3)

async def add_one(num):
yield num + 1

async def multiply_by_two(num):
yield num * 2

tasks_run = run_tasks([
Task(queue_consumer),
Task(add_one),
Task(multiply_by_two),
])

results = [2, 4, 6, 8, 10, 12, 14, 16, 18]
index = 0
async for result in tasks_run:
print(result)
assert result == results[index]
index += 1

async def main():
data_queue = Queue()
data_queue.is_closed = False

async def queue_producer():
for i in range(0, 10):
data_queue.put(i)
await asyncio.sleep(0.1)
data_queue.is_closed = True

await asyncio.gather(pipeline(data_queue), queue_producer())

if __name__ == "__main__":
asyncio.run(main())

@@ -4,29 +4,30 @@ from ..tasks.Task import Task

logger = logging.getLogger("run_tasks(tasks: [Task], data)")

async def run_tasks(tasks: [Task], data):
async def run_tasks(tasks: [Task], data = None):
if len(tasks) == 0:
yield data
return

args = [data] if data is not None else []

running_task = tasks[0]
batch_size = running_task.task_config["batch_size"]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
# next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1

if inspect.isasyncgenfunction(running_task.executable):
logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
logger.info("Running async generator task: `%s`", running_task.executable.__name__)
try:
results = []

async_iterator = running_task.run(data)
async_iterator = running_task.run(*args)

async for partial_result in async_iterator:
results.append(partial_result)

if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result

results = []

@@ -37,7 +38,7 @@ async def run_tasks(tasks: [Task], data):

results = []

logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running async generator task: `%s`\n%s\n",

@@ -48,15 +49,15 @@ async def run_tasks(tasks: [Task], data):
raise error

elif inspect.isgeneratorfunction(running_task.executable):
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Running generator task: `%s`", running_task.executable.__name__)
try:
results = []

for partial_result in running_task.run(data):
for partial_result in running_task.run(*args):
results.append(partial_result)

if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result

results = []

@@ -67,7 +68,7 @@ async def run_tasks(tasks: [Task], data):

results = []

logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Finished generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running generator task: `%s`\n%s\n",

@@ -78,13 +79,35 @@ async def run_tasks(tasks: [Task], data):
raise error

elif inspect.iscoroutinefunction(running_task.executable):
task_result = await running_task.run(data)
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
try:
task_result = await running_task.run(*args)

async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result

logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running coroutine task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)

elif inspect.isfunction(running_task.executable):
task_result = running_task.run(data)
logger.info("Running function task: `%s`", running_task.executable.__name__)
try:
task_result = running_task.run(*args)

async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result

logger.info("Finished function task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running function task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)

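run_tasks dispatches on the kind of callable each Task wraps, in the order async generator, generator, coroutine, plain function. A minimal standalone illustration of that dispatch follows (the callables here are stand-ins, not part of the codebase):

import inspect

async def async_gen():
    yield 1

def sync_gen():
    yield 1

async def coro():
    return 1

def plain_function():
    return 1

for executable in (async_gen, sync_gen, coro, plain_function):
    # same checks, in the same order, as the branches above
    if inspect.isasyncgenfunction(executable):
        kind = "async generator"
    elif inspect.isgeneratorfunction(executable):
        kind = "generator"
    elif inspect.iscoroutinefunction(executable):
        kind = "coroutine"
    elif inspect.isfunction(executable):
        kind = "plain function"
    print(executable.__name__, "->", kind)
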
@@ -1,21 +1,18 @@
from typing import Union, Dict
import networkx as nx
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine

async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
async def search_adjacent(query: str) -> list[(str, str)]:
"""
Find the neighbours of a given node in the graph and return their ids and descriptions.

Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): Unused in this implementation but could be used for future enhancements.
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
- query (str): The query string to filter nodes by.

Returns:
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
"""
node_id = other_param.get("node_id") if other_param else query
node_id = query

if node_id is None:
return {}

@@ -23,16 +20,24 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param:
graph_engine = await get_graph_engine()

exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["id"])

if exact_node is not None and "uuid" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
else:
vector_engine = get_vector_engine()
collection_name = "classification"
data_points = await vector_engine.search(collection_name, query_text = node_id, limit = 5)
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]

if len(data_points) == 0:
if len(relevant_results) == 0:
return []

neighbours = await graph_engine.get_neighbours(data_points[0].id)
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
neighbours = []
for neighbour_ids in node_neighbours:
neighbours.extend(neighbour_ids)

return [node["name"] for node in neighbours]
return neighbours

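The new fallback path merges hits from the "entities" and "classification" collections and keeps only sufficiently close ones. A standalone sketch of that filter (threshold and limit copied from the code above; the result objects only need a score attribute, and the sample hits are invented):

from collections import namedtuple

Hit = namedtuple("Hit", ["id", "score"])

def pick_relevant(results, threshold = 0.5, limit = 5):
    # lower score is better, so keep hits below the threshold and cap the count
    return [result for result in results if result.score < threshold][:limit]

hits = [Hit("a", 0.2), Hit("b", 0.7), Hit("c", 0.4)]
assert [hit.id for hit in pick_relevant(hits)] == ["a", "c"]
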
@@ -1,18 +1,15 @@

import networkx as nx
from typing import Union
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config

async def search_cypher(query:str, graph: Union[nx.Graph, any]):
async def search_cypher(query: str):
"""
Use a Cypher query to search the graph and return the results.
"""
graph_config = get_graph_config()

if graph_config.graph_database_provider == "neo4j":
result = await graph.run(query)
graph_engine = await get_graph_engine()
result = await graph_engine.graph().run(query)
return result

else:
raise ValueError("Unsupported graph engine type.")
raise ValueError("Unsupported search type for the used graph engine.")

@@ -1,22 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.infrastructure.databases.vector import get_vector_engine

async def search_similarity(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_similarity(query: str) -> list[str, str]:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- query (str): The query string to filter nodes by.

Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'result' attributes.
- list(chunk): A list of objects providing information about the chunks related to query.
"""
vector_engine = get_vector_engine()

similar_results = await vector_engine.search("chunks", query, limit = 5)
results = [{
"text": result.payload["text"],
"chunk_id": result.payload["chunk_id"],
} for result in similar_results]

results = [result.payload for result in similar_results]

return results

@@ -1,24 +1,17 @@
from typing import Union, Dict
import networkx as nx
from cognee.shared.data_models import ChunkSummaries
from cognee.infrastructure.databases.vector import get_vector_engine

async def search_summary(query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
async def search_summary(query: str) -> list:
"""
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): The query string to filter nodes by, e.g., 'SUMMARY'.
- other_param (str, optional): An additional parameter, unused in this implementation but could be for future enhancements.
- query (str): The query string to filter summaries by.

Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
- list[str, UUID]: A list of objects providing information about the summaries related to query.
"""
vector_engine = get_vector_engine()

summaries_results = await vector_engine.search("chunk_summaries", query, limit = 5)
summaries = [{
"text": summary.payload["text"],
"chunk_id": summary.payload["chunk_id"],
} for summary in summaries_results]

summaries = [summary.payload for summary in summaries_results]

return summaries

@@ -1,21 +1,36 @@
import asyncio
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine

async def search_traverse(query: str, graph): # graph must be there in order to be compatible with generic call
async def search_traverse(query: str):
node_id = query
rules = set()

graph_engine = await get_graph_engine()
vector_engine = get_vector_engine()

results = await vector_engine.search("classification", query_text = query, limit = 10)
exact_node = await graph_engine.extract_node(node_id)

rules = []
if exact_node is not None and "uuid" in exact_node:
edges = await graph_engine.get_edges(exact_node["uuid"])

if len(results) > 0:
for result in results:
graph_node_id = result.id
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
else:
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]

edges = await graph_engine.get_edges(graph_node_id)
if len(relevant_results) > 0:
for result in relevant_results:
graph_node_id = result.id

for edge in edges:
rules.append(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
edges = await graph_engine.get_edges(graph_node_id)

return rules
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")

return list(rules)

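search_traverse now collects the edges around the matched nodes and flattens them into de-duplicated "rules". A standalone sketch of that flattening, assuming the (source, target, properties) edge layout the f-string above implies; the sample edges are invented:

edges = [
    ("large_language_model", "technology", {"relationship_name": "is_entity_type"}),
    ("chunk_1", "large_language_model", {"relationship_name": "contains_entity"}),
]

rules = set()
for edge in edges:
    # "source relationship target", collected in a set so repeats collapse
    rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")

# rules now holds strings such as "large_language_model is_entity_type technology"
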
cognee/pipelines.py (new file, 5 additions)

@@ -0,0 +1,5 @@
# Don't add any more code here, this file is used only for the purpose
# of enabling imports from `cognee.pipelines` module.
# `from cognee.pipelines import Task` for example.

from .modules.pipelines import *

@@ -10,12 +10,14 @@ class Node(BaseModel):
name: str
type: str
description: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the node.")

class Edge(BaseModel):
"""Edge in a knowledge graph."""
source_node_id: str
target_node_id: str
relationship_name: str
properties: Optional[Dict[str, Any]] = Field(None, description = "A dictionary of properties associated with the edge.")

class KnowledgeGraph(BaseModel):
"""Knowledge graph."""

@@ -1,4 +1,3 @@

import os
import logging
import pathlib

@@ -38,21 +37,32 @@ async def main():

await cognee.cognify([dataset_name], root_node_id = "ROOT")

search_results = await cognee.search("TRAVERSE", { "query": "Text" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]

search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("ADJACENT", { "query": "Articles" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\nLarge language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

@@ -33,21 +33,32 @@ async def main():

await cognee.cognify([dataset_name])

search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]

search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\nLarge language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

@@ -34,21 +34,32 @@ async def main():

await cognee.cognify([dataset_name])

search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]

search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\nLarge language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")

@@ -32,21 +32,32 @@ async def main():

await cognee.cognify([dataset_name])

search_results = await cognee.search("TRAVERSE", { "query": "Artificial intelligence" })
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entities", "AI"))[0]
random_node_name = random_node.payload["name"]

search_results = await cognee.search("SIMILARITY", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": "Work and computers" })
search_results = await cognee.search("TRAVERSE", { "query": random_node_name })
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("SUMMARY", { "query": random_node_name })
assert len(search_results) != 0, "Query related summaries don't exist."
print("\n\nQuery related summaries exist:\n")
for result in search_results:
print(f"{result}\n")

search_results = await cognee.search("ADJACENT", { "query": "ROOT" })
assert len(search_results) != 0, "ROOT node has no neighbours."
print("\n\nROOT node has neighbours.\n")
search_results = await cognee.search("ADJACENT", { "query": random_node_name })
assert len(search_results) != 0, "Large language model query found no neighbours."
print("\n\nLarge language model query found neighbours.\n")
for result in search_results:
print(f"{result}\n")