1227 lines
41 KiB
Python
1227 lines
41 KiB
Python
import asyncio
|
|
|
|
# from datetime import datetime
|
|
import json
|
|
from textwrap import dedent
|
|
from uuid import UUID
|
|
from webbrowser import Error
|
|
from typing import List, Dict, Any, Optional, Tuple, Type, Union
|
|
|
|
from falkordb import FalkorDB
|
|
|
|
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|
GraphDBInterface,
|
|
record_graph_changes,
|
|
NodeData,
|
|
EdgeData,
|
|
Node,
|
|
)
|
|
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
|
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
|
|
from cognee.infrastructure.engine import DataPoint
|
|
|
|
|
|
class IndexSchema(DataPoint):
    """
    Define a schema for indexing that includes text data and associated metadata.

    This class inherits from the DataPoint class. It contains a string attribute 'text' and
    a dictionary 'metadata' that specifies the index fields for this schema.
    """

    # The raw text content to be embedded and indexed.
    text: str

    # Class-level default (pydantic-style model field): 'index_fields' names the
    # attributes of this schema that should be embedded.
    metadata: dict = {"index_fields": ["text"]}
|
|
|
|
|
|
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|
"""
|
|
Manage and interact with a graph database using vector embeddings.
|
|
|
|
Public methods include:
|
|
- query
|
|
- embed_data
|
|
- stringify_properties
|
|
- create_data_point_query
|
|
- create_edge_query
|
|
- create_collection
|
|
- has_collection
|
|
- create_data_points
|
|
- create_vector_index
|
|
- has_vector_index
|
|
- index_data_points
|
|
- add_node
|
|
- add_nodes
|
|
- add_edge
|
|
- add_edges
|
|
- has_edges
|
|
- retrieve
|
|
- extract_node
|
|
- extract_nodes
|
|
- get_connections
|
|
- search
|
|
- batch_search
|
|
- get_graph_data
|
|
- delete_data_points
|
|
- delete_node
|
|
- delete_nodes
|
|
- delete_graph
|
|
- prune
|
|
- get_node
|
|
- get_nodes
|
|
- get_neighbors
|
|
- get_graph_metrics
|
|
- get_document_subgraph
|
|
- get_degree_one_nodes
|
|
"""
|
|
|
|
def __init__(
    self,
    database_url: str,
    database_port: int,
    embedding_engine=EmbeddingEngine,
):
    """
    Initialize the adapter with a FalkorDB connection and an embedding engine.

    Parameters:
    -----------

    - database_url (str): Hostname or IP address of the FalkorDB server.
    - database_port (int): TCP port of the FalkorDB server.
    - embedding_engine: Engine used to turn text into vectors.
      NOTE(review): the default value is the EmbeddingEngine *class*, not an
      instance — callers appear expected to always pass a concrete instance;
      confirm before relying on the default.
    """
    self.driver = FalkorDB(
        host=database_url,
        port=database_port,
    )
    self.embedding_engine = embedding_engine
    # Every operation in this adapter runs against this single named graph.
    self.graph_name = "cognee_graph"
|
|
|
|
def query(self, query: str, params: Optional[dict] = None):
    """
    Execute a query against the graph database.

    Handles exceptions during the query execution by logging errors and re-raising the
    exception.

    Parameters:
    -----------

    - query (str): The query string to be executed against the graph database.
    - params (Optional[dict]): A dictionary of parameters to be used in the query.
      (default None, treated as an empty dict)

    Returns:
    --------

    The result of the query execution, returned by the graph database.
    """
    # BUG FIX: the previous signature used a mutable default (`params: dict = {}`),
    # which is shared across calls; normalize None to a fresh dict instead.
    if params is None:
        params = {}

    graph = self.driver.select_graph(self.graph_name)

    try:
        result = graph.query(query, params)
        return result
    except Exception as e:
        print(f"Error executing query: {e}")
        # Bare `raise` re-raises with the original traceback intact.
        raise
|
|
|
|
async def embed_data(self, data: list[str]) -> list[list[float]]:
    """
    Embed a list of text strings into vector representations.

    Parameters:
    -----------

    - data (list[str]): Strings to embed.

    Returns:
    --------

    - list[list[float]]: One embedding vector (list of floats) per input string.
    """
    # Delegate the actual embedding work to the configured engine.
    vectors = await self.embedding_engine.embed_text(data)
    return vectors
|
|
|
|
async def stringify_properties(self, properties: dict) -> str:
    """
    Convert properties dictionary to a string format suitable for database queries.

    Parameters:
    -----------

    - properties (dict): A dictionary containing properties to be converted to string
    format.

    Returns:
    --------

    - str: A comma-separated "key:value" string, with each value quoted/escaped
      according to its type.
    """

    def parse_value(value):
        """
        Convert a value to its string representation based on type for database queries.

        Parameters:
        -----------

        - value: The value to parse into a string representation.

        Returns:
        --------

        Returns the string representation of the value in the appropriate format.
        """
        # NOTE: exact `type(...) is ...` checks are deliberate — unlike isinstance()
        # they exclude subclasses (e.g. bool would otherwise match int and be
        # emitted unquoted).
        if type(value) is UUID:
            return f"'{str(value)}'"
        if type(value) is int or type(value) is float:
            # Numbers are emitted unquoted.
            return value
        if (
            type(value) is list
            and len(value) > 0
            and type(value[0]) is float
            and len(value) == self.embedding_engine.get_vector_size()
        ):
            # A float list of exactly the embedding dimension is treated as a
            # vector literal for FalkorDB.
            return f"'vecf32({value})'"
        # if type(value) is datetime:
        #     return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
        if type(value) is dict:
            # chr(39) is "'" and chr(34) is '"': single quotes inside the JSON are
            # swapped to double quotes so they cannot terminate the outer quoting.
            return f"'{json.dumps(value).replace(chr(39), chr(34))}'"
        if type(value) is str:
            # Escape single quotes and handle special characters
            escaped_value = (
                str(value)
                .replace("'", "\\'")
                .replace('"', '\\"')
                .replace("\n", "\\n")
                .replace("\r", "\\r")
                .replace("\t", "\\t")
            )
            return f"'{escaped_value}'"
        # Fallback: any other type is stringified and single-quoted as-is.
        return f"'{str(value)}'"

    return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
|
|
|
|
async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
    """
    Compose a query to create or update a data point in the database.

    Parameters:
    -----------

    - data_point (DataPoint): An instance of DataPoint containing information about the
    entity.
    - vectorized_values (dict): Positionally indexed vectorized values for the data
      point's embeddable properties (indexed by position, despite the dict annotation —
      a list works; NOTE(review): annotation and usage disagree, confirm intended type).

    Returns:
    --------

    A tuple containing the query string and parameters dictionary.
    """
    # The node label is the concrete DataPoint subclass name.
    node_label = type(data_point).__name__
    property_names = DataPoint.get_embeddable_property_names(data_point)

    # Start from the full model dump, then overwrite each embeddable property with
    # its vectorized value (falling back to the raw attribute when no vector exists
    # at that position).
    properties = {
        **data_point.model_dump(),
        **(
            {
                property_names[index]: (
                    vectorized_values[index]
                    if index < len(vectorized_values)
                    else getattr(data_point, property_name, None)
                )
                for index, property_name in enumerate(property_names)
            }
        ),
    }

    # Clean the properties - remove None values and handle special types
    clean_properties = {}
    for key, value in properties.items():
        if value is not None:
            if isinstance(value, UUID):
                clean_properties[key] = str(value)
            elif isinstance(value, dict):
                clean_properties[key] = json.dumps(value)
            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], float):
                # This is likely a vector - convert to string representation
                clean_properties[key] = f"vecf32({value})"
            else:
                clean_properties[key] = value

    # MERGE keeps the operation idempotent: update the node if it already exists,
    # create it otherwise; updated_at is refreshed either way.
    query = dedent(
        f"""
        MERGE (node:{node_label} {{id: $node_id}})
        SET node += $properties, node.updated_at = timestamp()
        """
    ).strip()

    params = {"node_id": str(data_point.id), "properties": clean_properties}

    return query, params
|
|
|
|
def sanitize_relationship_name(self, relationship_name: str) -> str:
    """
    Sanitize relationship name to be valid for Cypher queries.

    Parameters:
    -----------
    - relationship_name (str): The original relationship name

    Returns:
    --------
    - str: A sanitized relationship name valid for Cypher
    """
    import re

    # Map every non-word character (hyphens, spaces, punctuation, ...) to "_",
    # then collapse runs of underscores and trim them from both ends.
    cleaned = re.sub(r"[^\w]", "_", relationship_name)
    cleaned = re.sub(r"_+", "_", cleaned).strip("_")

    # Cypher identifiers must start with a letter or underscore.
    if cleaned and not (cleaned[0].isalpha() or cleaned[0] == "_"):
        cleaned = "_" + cleaned

    # Fall back to a generic name when nothing survives sanitization.
    return cleaned if cleaned else "RELATIONSHIP"
|
|
|
|
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
    """
    Generate a query to create or update an edge between two nodes in the graph.

    Parameters:
    -----------

    - edge (tuple[str, str, str, dict]): A tuple consisting of source and target node
    IDs, edge type, and edge properties.

    Returns:
    --------

    - str: A string containing the query to be executed for creating the edge.

    NOTE(review): node IDs and properties are interpolated directly into the query
    text rather than passed as parameters — IDs containing quotes would break the
    query (and permit Cypher injection). Safe for internally generated UUIDs, but
    confirm callers never pass untrusted IDs.
    """
    # Sanitize the relationship name for Cypher compatibility
    sanitized_relationship = self.sanitize_relationship_name(edge[2])

    # Add the original relationship name to properties
    # (preserved so the pre-sanitization name can be matched later, e.g. has_edge).
    edge_properties = {**edge[3], "relationship_name": edge[2]}
    properties = await self.stringify_properties(edge_properties)
    properties = f"{{{properties}}}"

    return dedent(
        f"""
        MERGE (source {{id:'{edge[0]}'}})
        MERGE (target {{id: '{edge[1]}'}})
        MERGE (source)-[edge:{sanitized_relationship} {properties}]->(target)
        ON MATCH SET edge.updated_at = timestamp()
        ON CREATE SET edge.updated_at = timestamp()
        """
    ).strip()
|
|
|
|
async def create_collection(self, collection_name: str):
    """
    Create a collection in the graph database with the specified name.

    Intentionally a no-op for FalkorDB: graphs are created implicitly on first
    write, so there is nothing to do here. Kept for VectorDBInterface compatibility.

    Parameters:
    -----------

    - collection_name (str): The name of the collection to be created.
    """
    pass
|
|
|
|
async def has_collection(self, collection_name: str) -> bool:
    """
    Check whether a collection (graph) with the given name exists on the server.

    Parameters:
    -----------

    - collection_name (str): The name of the collection to check for existence.

    Returns:
    --------

    - bool: True if the collection exists, otherwise False.
    """
    # FalkorDB exposes collections as named graphs.
    existing_graphs = self.driver.list_graphs()
    return collection_name in existing_graphs
|
|
|
|
async def create_data_points(self, data_points: list[DataPoint]):
    """
    Add a list of data points to the graph database via batching.

    Embeds all embeddable property values in a single embedding call, then writes
    each data point with a MERGE query.

    Can raise exceptions if there are issues during the database operations.

    Parameters:
    -----------

    - data_points (list[DataPoint]): A list of DataPoint instances to be inserted into
    the database.
    """
    # Flat list of every non-None embeddable value across all data points; a single
    # embed call over this list is cheaper than one call per value.
    embeddable_values = []
    # vector_map[data_point_id][property_name] -> index into embeddable_values
    # (None when the property had no value to embed).
    vector_map = {}

    for data_point in data_points:
        property_names = DataPoint.get_embeddable_property_names(data_point)
        key = str(data_point.id)
        vector_map[key] = {}

        for property_name in property_names:
            property_value = getattr(data_point, property_name, None)

            if property_value is not None:
                vector_map[key][property_name] = len(embeddable_values)
                embeddable_values.append(property_value)
            else:
                vector_map[key][property_name] = None

    vectorized_values = await self.embed_data(embeddable_values)

    for data_point in data_points:
        # Re-assemble, per data point, the vectors in embeddable-property order
        # (None for properties that were never embedded).
        vectorized_data = [
            vectorized_values[vector_map[str(data_point.id)][property_name]]
            if vector_map[str(data_point.id)][property_name] is not None
            else None
            for property_name in DataPoint.get_embeddable_property_names(data_point)
        ]

        query, params = await self.create_data_point_query(data_point, vectorized_data)
        # self.query is synchronous; each data point is written sequentially.
        self.query(query, params)
|
|
|
|
async def create_vector_index(self, index_name: str, index_property_name: str):
    """
    Create a vector index in the specified graph for a given property if it does not already
    exist.

    Parameters:
    -----------

    - index_name (str): The node label the vector index is created for.
    - index_property_name (str): The name of the property on which the vector index will
    be created.
    """
    graph = self.driver.select_graph(self.graph_name)

    # Creating an index that already exists would raise, so check first.
    if not self.has_vector_index(graph, index_name, index_property_name):
        graph.create_node_vector_index(
            # Dimension must match the embedding engine's output size.
            index_name, index_property_name, dim=self.embedding_engine.get_vector_size()
        )
|
|
|
|
def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
    """
    Determine if a vector index exists on the specified property of the given graph.

    Parameters:
    -----------

    - graph: The graph instance to check for the vector index.
    - index_name (str): The name (label) of the index to check for existence.
    - index_property_name (str): The property name associated with the index.

    Returns:
    --------

    - bool: Returns true if the vector index exists, otherwise false (including
      when listing indices fails).
    """
    try:
        indices = graph.list_indices()

        # Each index record is [label, [properties, ...], ...].
        return any(
            index[0] == index_name and index_property_name in index[1]
            for index in indices.result_set
        )
    # BUG FIX: this previously caught `Error` imported from the stdlib `webbrowser`
    # module (an accidental auto-import), so real driver failures were never
    # caught here and propagated instead of returning False.
    except Exception as e:
        print(e)
        return False
|
|
|
|
async def index_data_points(
    self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
    """
    Index a list of data points in the specified graph database based on properties.

    Intentionally a no-op: to be implemented, does not yet have a defined behavior.
    Kept for VectorDBInterface compatibility.

    Parameters:
    -----------

    - index_name (str): The name of the index to be created for the data points.
    - index_property_name (str): The property name on which to index the data points.
    - data_points (list[DataPoint]): A list of DataPoint instances to be indexed.
    """
    pass
|
|
|
|
async def add_node(self, node_id: str, properties: Dict[str, Any]) -> None:
    """
    Add a single node with specified properties to the graph.

    Parameters:
    -----------

    - node_id (str): Unique identifier for the node being added.
    - properties (Dict[str, Any]): A dictionary of properties associated with the node.
    """

    def _normalize(value):
        # Mirror the property coercion used across this adapter.
        if isinstance(value, UUID):
            return str(value)
        if isinstance(value, dict):
            return json.dumps(value)
        if isinstance(value, list) and len(value) > 0 and isinstance(value[0], float):
            # A list of floats is likely an embedding vector - convert to string representation.
            return f"vecf32({value})"
        return value

    # Drop None values and coerce special types; the id always wins.
    clean_properties = {"id": node_id}
    for key, value in properties.items():
        if value is not None:
            clean_properties[key] = _normalize(value)

    query = "MERGE (node {id: $node_id}) SET node += $properties, node.updated_at = timestamp()"
    params = {"node_id": node_id, "properties": clean_properties}

    self.query(query, params)
|
|
|
|
# Helper methods for DataPoint compatibility
async def add_data_point_node(self, node: DataPoint):
    """
    Add a single data point as a node in the graph.

    Thin convenience wrapper over create_data_points for one element.

    Parameters:
    -----------

    - node (DataPoint): An instance of DataPoint to be added to the graph.
    """
    await self.create_data_points([node])
|
|
|
|
async def add_data_point_nodes(self, nodes: list[DataPoint]):
    """
    Add multiple data points as nodes in the graph.

    Thin convenience wrapper over create_data_points.

    Parameters:
    -----------

    - nodes (list[DataPoint]): A list of DataPoint instances to be added to the graph.
    """
    await self.create_data_points(nodes)
|
|
|
|
@record_graph_changes
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
    """
    Add multiple nodes to the graph in a single operation.

    Accepts either (node_id, properties) tuples or DataPoint-like objects.

    Parameters:
    -----------

    - nodes (Union[List[Node], List[DataPoint]]): A list of Node tuples or DataPoint objects to be added to the graph.
    """
    for node in nodes:
        # Tuple form: (node_id, properties).
        if isinstance(node, tuple) and len(node) == 2:
            node_id, node_properties = node
            await self.add_node(node_id, node_properties)
            continue

        # DataPoint form: duck-typed on the attributes we actually need.
        if hasattr(node, "id") and hasattr(node, "model_dump"):
            await self.add_node(str(node.id), node.model_dump())
            continue

        raise ValueError(
            f"Invalid node format: {node}. Expected tuple (node_id, properties) or DataPoint object."
        )
|
|
|
|
async def add_edge(
    self,
    source_id: str,
    target_id: str,
    relationship_name: str,
    properties: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Create a new edge between two nodes in the graph.

    Parameters:
    -----------

    - source_id (str): The unique identifier of the source node.
    - target_id (str): The unique identifier of the target node.
    - relationship_name (str): The name of the relationship to be established by the
    edge.
    - properties (Optional[Dict[str, Any]]): Optional dictionary of properties
    associated with the edge. (default None)
    """
    # Normalize the optional properties to an empty dict.
    edge_properties = {} if properties is None else properties

    edge = (source_id, target_id, relationship_name, edge_properties)
    edge_query = await self.create_edge_query(edge)
    self.query(edge_query)
|
|
|
|
@record_graph_changes
async def add_edges(self, edges: List[EdgeData]) -> None:
    """
    Add multiple edges to the graph in a single operation.

    Each edge must be a (source_id, target_id, relationship_name, properties) tuple.

    Parameters:
    -----------

    - edges (List[EdgeData]): A list of EdgeData objects representing edges to be added.
    """
    for edge in edges:
        # Reject anything that is not a 4-tuple before touching the database.
        if not (isinstance(edge, tuple) and len(edge) == 4):
            raise ValueError(
                f"Invalid edge format: {edge}. Expected tuple (source_id, target_id, relationship_name, properties)."
            )
        source_id, target_id, relationship_name, edge_properties = edge
        await self.add_edge(source_id, target_id, relationship_name, edge_properties)
|
|
|
|
async def has_edges(self, edges):
    """
    Check if the specified edges exist in the graph based on their attributes.

    Parameters:
    -----------

    - edges: A list of (source, target, relationship, ...) edge tuples to check.

    Returns:
    --------

    Returns the subset of input edge tuples that exist in the graph, in input order.
    """
    # Async list comprehension: keep each edge whose existence check succeeds.
    return [
        edge
        for edge in edges
        if await self.has_edge(str(edge[0]), str(edge[1]), edge[2])
    ]
|
|
|
|
async def retrieve(self, data_point_ids: list[UUID]):
    """
    Retrieve data points from the graph based on their IDs.

    Parameters:
    -----------

    - data_point_ids (list[UUID]): A list of UUIDs representing the data points to
    retrieve.

    Returns:
    --------

    Returns the result set containing the retrieved nodes (empty when none match).
    """
    # Node ids are stored as strings, so stringify the UUIDs before matching.
    node_ids = [str(data_point) for data_point in data_point_ids]

    query_result = self.query(
        "MATCH (node) WHERE node.id IN $node_ids RETURN node",
        {"node_ids": node_ids},
    )
    return query_result.result_set
|
|
|
|
async def extract_node(self, data_point_id: UUID):
    """
    Extract the properties of a single node identified by its data point ID.

    Parameters:
    -----------

    - data_point_id (UUID): The UUID of the data point to extract.

    Returns:
    --------

    Returns the properties of the node if found, otherwise None.
    """
    result = await self.retrieve([data_point_id])
    # BUG FIX: guard the completely-empty result set — `result[0]` previously
    # raised IndexError when no rows matched at all.
    node = result[0][0] if result and len(result[0]) > 0 else None
    return node.properties if node else None
|
|
|
|
async def extract_nodes(self, data_point_ids: list[UUID]):
    """
    Extract properties of multiple nodes identified by their data point IDs.

    Thin wrapper over retrieve; returns raw result rows, not unwrapped properties.

    Parameters:
    -----------

    - data_point_ids (list[UUID]): A list of UUIDs representing the data points to
    extract.

    Returns:
    --------

    Returns the matching result rows as returned by retrieve.
    """
    return await self.retrieve(data_point_ids)
|
|
|
|
async def get_connections(self, node_id: UUID) -> list:
    """
    Retrieve connection details (predecessors and successors) for a given node ID.

    Parameters:
    -----------

    - node_id (UUID): The UUID of the node whose connections are to be retrieved.

    Returns:
    --------

    - list: Tuples of (start_node, {"relationship_name": type}, end_node) for every
      incoming and outgoing edge of the node.
    """
    predecessors_query = """
    MATCH (node)<-[relation]-(neighbour)
    WHERE node.id = $node_id
    RETURN neighbour, relation, node
    """
    successors_query = """
    MATCH (node)-[relation]->(neighbour)
    WHERE node.id = $node_id
    RETURN node, relation, neighbour
    """

    # BUG FIX: self.query() is synchronous, so the previous
    # `await asyncio.gather(self.query(...), self.query(...))` handed plain result
    # objects to gather and raised a TypeError. Run the two queries sequentially.
    # Also stringify the UUID — node ids are stored as strings.
    params = {"node_id": str(node_id)}
    predecessors = self.query(predecessors_query, params)
    successors = self.query(successors_query, params)

    connections = []

    # FalkorDB rows are positional lists: [start_node, relation, end_node].
    # NOTE(review): the relation's type is exposed via the Edge object's
    # `relation` attribute in the FalkorDB Python client — confirm against the
    # pinned client version.
    for record in predecessors.result_set:
        relation = record[1]
        connections.append((record[0], {"relationship_name": relation.relation}, record[2]))

    for record in successors.result_set:
        relation = record[1]
        connections.append((record[0], {"relationship_name": relation.relation}, record[2]))

    return connections
|
|
|
|
async def search(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: list[float] = None,
    limit: int = 10,
    with_vector: bool = False,
):
    """
    Search for nodes in a collection based on text or vector query, with optional limitation
    on results.

    Parameters:
    -----------

    - collection_name (str): The name of the collection in which to search. A
      "Label.property" form searches that label's property; a bare name searches
      that property on all nodes.
    - query_text (str): The text to search for (if using text-based query). (default
    None)
    - query_vector (list[float]): The vector representation of the query if using
    vector-based search. (default None)
    - limit (int): Maximum number of results to return from the search. (default 10)
    - with_vector (bool): Flag indicating whether to return vectors with the search
    results. (default False; currently unused in this implementation)

    Returns:
    --------

    Returns the search results as a result set from the graph database, or an empty
    list for vector-only queries (vector search is not implemented here yet).
    """
    if query_text is None and query_vector is None:
        raise MissingQueryParameterError()

    if query_text and not query_vector:
        query_vector = (await self.embed_data([query_text]))[0]

    # For FalkorDB, let's do a simple property-based search instead of vector search for now
    # since the vector index might not be set up correctly
    if "." in collection_name:
        [label, attribute_name] = collection_name.split(".")
    else:
        # If no dot, treat the whole thing as a property search
        label = ""
        attribute_name = collection_name

    # Simple text-based search if we have query_text
    # NOTE(review): label and attribute_name are interpolated into the Cypher text;
    # safe only while collection names are internally generated — confirm callers
    # never pass untrusted collection names.
    if query_text:
        if label:
            query = f"""
            MATCH (n:{label})
            WHERE toLower(toString(n.{attribute_name})) CONTAINS toLower($query_text)
            RETURN n, 1.0 as score
            LIMIT $limit
            """
        else:
            query = f"""
            MATCH (n)
            WHERE toLower(toString(n.{attribute_name})) CONTAINS toLower($query_text)
            RETURN n, 1.0 as score
            LIMIT $limit
            """

        params = {"query_text": query_text, "limit": limit}
        result = self.query(query, params)
        return result.result_set
    else:
        # For vector search, return empty for now since vector indexing needs proper setup
        return []
|
|
|
|
async def batch_search(
    self,
    collection_name: str,
    query_texts: list[str],
    limit: int = None,
    with_vectors: bool = False,
):
    """
    Perform batch search across multiple queries based on text inputs and return results
    asynchronously.

    Parameters:
    -----------

    - collection_name (str): The name of the collection in which to perform the
    searches.
    - query_texts (list[str]): A list of text queries to search for.
    - limit (int): Optional limit for the search results for each query. (default None)
    - with_vectors (bool): Flag indicating whether to return vectors with the results.
    (default False)

    Returns:
    --------

    Returns a list of results for each search query executed in parallel.
    NOTE(review): each search() call here receives only a query_vector (no
    query_text), and search() currently returns [] for vector-only queries — so
    this method returns a list of empty lists until vector search is implemented.
    """
    # Embed all query texts in one call, then fan the searches out concurrently.
    query_vectors = await self.embedding_engine.embed_text(query_texts)

    return await asyncio.gather(
        *[
            self.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                with_vector=with_vectors,
            )
            for query_vector in query_vectors
        ]
    )
|
|
|
|
async def get_graph_data(self):
    """
    Retrieve all nodes and edges from the graph along with their properties.

    Returns:
    --------

    Returns a tuple (nodes, edges) where nodes are (id, properties) pairs and edges
    are (source_id, target_id, type, properties) tuples.
    """
    query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"

    result = self.query(query)

    # Row layout: [internal_id, labels, properties]. The application-level id is
    # taken from the "id" property, not the internal graph id.
    nodes = [
        (
            record[2]["id"],
            record[2],
        )
        for record in result.result_set
    ]

    query = """
    MATCH (n)-[r]->(m)
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result = self.query(query)
    # Row layout: [source_internal_id, target_internal_id, type, properties].
    # NOTE(review): endpoints are read from the edge properties
    # "source_node_id"/"target_node_id", which assumes every edge was written with
    # those properties — edges created without them will raise KeyError; confirm
    # all edge-writing paths set them.
    edges = [
        (
            record[3]["source_node_id"],
            record[3]["target_node_id"],
            record[2],
            record[3],
        )
        for record in result.result_set
    ]

    return (nodes, edges)
|
|
|
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
    """
    Remove specified data points from the graph database based on their IDs.

    Parameters:
    -----------

    - collection_name (str): The name of the collection from which to delete the data
    points (unused here — all deletion happens on the single cognee graph).
    - data_point_ids (list[UUID]): A list of UUIDs representing the data points to
    delete.

    Returns:
    --------

    Returns the result of the deletion operation from the database.
    """
    # Stringify the UUIDs to match how node ids are stored.
    node_ids = [str(data_point) for data_point in data_point_ids]

    return self.query(
        "MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
        {"node_ids": node_ids},
    )
|
|
|
|
async def delete_node(self, node_id: str) -> None:
    """
    Delete a specified node from the graph by its ID.

    Parameters:
    -----------

    - node_id (str): Unique identifier for the node to delete.
    """
    # BUG FIX: previously the id was f-string-interpolated into the query text, so
    # an id containing a single quote broke the query (and allowed Cypher
    # injection). Pass it as a bound parameter instead, matching get_node().
    query = "MATCH (node {id: $node_id}) DETACH DELETE node"
    self.query(query, {"node_id": node_id})
|
|
|
|
async def delete_nodes(self, node_ids: List[str]) -> None:
    """
    Delete multiple nodes from the graph by their identifiers.

    Deletions are performed sequentially, one query per node.

    Parameters:
    -----------

    - node_ids (List[str]): A list of unique identifiers for the nodes to delete.
    """
    for node_id in node_ids:
        await self.delete_node(node_id)
|
|
|
|
async def delete_graph(self):
    """
    Delete the entire graph along with all its indices and nodes.

    Best-effort: any failure is logged and swallowed rather than raised, so callers
    can prune a graph that may not exist.
    """
    try:
        graph = self.driver.select_graph(self.graph_name)

        # Vector indices must be dropped explicitly before deleting the graph.
        # Each index record is [label, [fields...], ...].
        indices = graph.list_indices()
        for index in indices.result_set:
            for field in index[1]:
                graph.drop_node_vector_index(index[0], field)

        graph.delete()
    except Exception as e:
        print(f"Error deleting graph: {e}")
|
|
|
|
async def get_node(self, node_id: str) -> Optional[NodeData]:
    """
    Retrieve a single node from the graph using its ID.

    Parameters:
    -----------

    - node_id (str): Unique identifier of the node to retrieve.

    Returns:
    --------

    The node's properties dict, or None when no node matches.
    """
    query_result = self.query(
        "MATCH (node) WHERE node.id = $node_id RETURN node",
        {"node_id": node_id},
    )

    rows = query_result.result_set
    if not rows:
        return None

    # FalkorDB rows are positional; the node object is the first column.
    return rows[0][0].properties
|
|
|
|
async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
    """
    Retrieve multiple nodes from the graph using their IDs.

    Parameters:
    -----------

    - node_ids (List[str]): A list of unique identifiers for the nodes to retrieve.

    Returns:
    --------

    A list of property dicts, one per matched node (missing ids are simply absent).
    """
    query_result = self.query(
        "MATCH (node) WHERE node.id IN $node_ids RETURN node",
        {"node_ids": node_ids},
    )

    # FalkorDB rows are positional; the node object is the first column.
    rows = query_result.result_set or []
    return [record[0].properties for record in rows]
|
|
|
|
async def get_neighbors(self, node_id: str) -> List[NodeData]:
    """
    Get all neighboring nodes connected to the specified node.

    Parameters:
    -----------

    - node_id (str): Unique identifier of the node for which to retrieve neighbors.

    Returns:
    --------

    A list of property dicts, one per distinct neighbor (edge direction ignored).
    """
    query_result = self.query(
        "MATCH (node)-[]-(neighbor) WHERE node.id = $node_id RETURN DISTINCT neighbor",
        {"node_id": node_id},
    )

    # FalkorDB rows are positional; the neighbor object is the first column.
    rows = query_result.result_set or []
    return [record[0].properties for record in rows]
|
|
|
|
async def get_edges(self, node_id: str) -> List[EdgeData]:
    """
    Retrieve all edges that are connected to the specified node.

    Parameters:
    -----------

    - node_id (str): Unique identifier of the node whose edges are to be retrieved.

    Returns:
    --------

    A list of (source_id, target_id, relationship_name, properties) tuples,
    covering both incoming and outgoing edges.
    """
    query_result = self.query(
        """
        MATCH (n)-[r]-(m)
        WHERE n.id = $node_id
        RETURN n.id AS source_id, m.id AS target_id, type(r) AS relationship_name, properties(r) AS properties
        """,
        {"node_id": node_id},
    )

    # Row layout: [source_id, target_id, relationship_name, properties].
    rows = query_result.result_set or []
    return [(record[0], record[1], record[2], record[3]) for record in rows]
|
|
|
|
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
    """
    Verify if an edge exists between two specified nodes.

    Parameters:
    -----------

    - source_id (str): Unique identifier of the source node.
    - target_id (str): Unique identifier of the target node.
    - relationship_name (str): Name of the relationship to verify (the original,
      pre-sanitization name as stored in the edge's relationship_name property).

    Returns:
    --------

    - bool: True when a matching edge exists, otherwise False.
    """
    # Check both the sanitized relationship type and the original name in properties
    # (create_edge_query stores the unsanitized name as a property, since the edge
    # type itself is sanitized for Cypher).
    sanitized_relationship = self.sanitize_relationship_name(relationship_name)

    result = self.query(
        f"""
        MATCH (source)-[r:{sanitized_relationship}]->(target)
        WHERE source.id = $source_id AND target.id = $target_id
        AND (r.relationship_name = $relationship_name OR NOT EXISTS(r.relationship_name))
        RETURN COUNT(r) > 0 AS edge_exists
        """,
        {
            "source_id": source_id,
            "target_id": target_id,
            "relationship_name": relationship_name,
        },
    )

    if result.result_set and len(result.result_set) > 0:
        # FalkorDB returns scalar results as a list, access by index instead of key
        return result.result_set[0][0]
    return False
|
|
|
|
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
    """
    Fetch metrics and statistics of the graph, possibly including optional details.

    Parameters:
    -----------

    - include_optional (bool): Flag indicating whether to include optional metrics or
    not. When False, the optional entries are reported as -1 (not computed).
    (default False)

    Returns:
    --------

    - Dict[str, Any]: Node/edge counts, degree and density statistics, plus
      (simplified) connected-component and optional metrics.
    """
    # Get basic node and edge counts.
    node_result = self.query("MATCH (n) RETURN count(n) AS node_count")
    edge_result = self.query("MATCH ()-[r]->() RETURN count(r) AS edge_count")

    # FalkorDB returns scalar results positionally in the first row.
    num_nodes = node_result.result_set[0][0] if node_result.result_set else 0
    num_edges = edge_result.result_set[0][0] if edge_result.result_set else 0

    # Guard the divisions against empty/singleton graphs.
    mean_degree = (2 * num_edges) / num_nodes if num_nodes > 0 else 0
    edge_density = num_edges / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0

    metrics = {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        "mean_degree": mean_degree,
        "edge_density": edge_density,
        "num_connected_components": 1,  # Simplified for now
        "sizes_of_connected_components": [num_nodes] if num_nodes > 0 else [],
    }

    if include_optional:
        # Optional metrics - simplified implementation.
        optional_metrics = {
            "num_selfloops": 0,  # Simplified
            "diameter": -1,  # Not implemented
            "avg_shortest_path_length": -1,  # Not implemented
            "avg_clustering": -1,  # Not implemented
        }
    else:
        optional_metrics = {
            "num_selfloops": -1,
            "diameter": -1,
            "avg_shortest_path_length": -1,
            "avg_clustering": -1,
        }

    metrics.update(optional_metrics)
    return metrics
|
|
|
|
async def get_document_subgraph(self, content_hash: str):
|
|
"""
|
|
Get a subgraph related to a specific document by content hash.
|
|
|
|
Parameters:
|
|
-----------
|
|
|
|
- content_hash (str): The content hash of the document to find.
|
|
"""
|
|
query = """
|
|
MATCH (d) WHERE d.id CONTAINS $content_hash
|
|
OPTIONAL MATCH (d)<-[:CHUNK_OF]-(c)
|
|
OPTIONAL MATCH (c)-[:HAS_ENTITY]->(e)
|
|
OPTIONAL MATCH (e)-[:IS_INSTANCE_OF]->(et)
|
|
RETURN d AS document,
|
|
COLLECT(DISTINCT c) AS chunks,
|
|
COLLECT(DISTINCT e) AS orphan_entities,
|
|
COLLECT(DISTINCT c) AS made_from_nodes,
|
|
COLLECT(DISTINCT et) AS orphan_types
|
|
"""
|
|
|
|
result = self.query(query, {"content_hash": f"text_{content_hash}"})
|
|
|
|
if not result.result_set or not result.result_set[0]:
|
|
return None
|
|
|
|
# Convert result to dictionary format
|
|
# FalkorDB returns values by index: document, chunks, orphan_entities, made_from_nodes, orphan_types
|
|
record = result.result_set[0]
|
|
return {
|
|
"document": record[0],
|
|
"chunks": record[1],
|
|
"orphan_entities": record[2],
|
|
"made_from_nodes": record[3],
|
|
"orphan_types": record[4],
|
|
}
|
|
|
|
async def get_degree_one_nodes(self, node_type: str):
|
|
"""
|
|
Get all nodes that have only one connection.
|
|
|
|
Parameters:
|
|
-----------
|
|
|
|
- node_type (str): The type of nodes to filter by, must be 'Entity' or 'EntityType'.
|
|
"""
|
|
if not node_type or node_type not in ["Entity", "EntityType"]:
|
|
raise ValueError("node_type must be either 'Entity' or 'EntityType'")
|
|
|
|
result = self.query(
|
|
f"""
|
|
MATCH (n:{node_type})
|
|
WITH n, COUNT {{ MATCH (n)--() }} as degree
|
|
WHERE degree = 1
|
|
RETURN n
|
|
"""
|
|
)
|
|
|
|
# FalkorDB returns node objects as first element in each record
|
|
return [record[0] for record in result.result_set] if result.result_set else []
|
|
|
|
async def get_nodeset_subgraph(
|
|
self, node_type: Type[Any], node_name: List[str]
|
|
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
|
"""
|
|
Fetch a subgraph consisting of a specific set of nodes and their relationships.
|
|
|
|
Parameters:
|
|
-----------
|
|
|
|
- node_type (Type[Any]): The type of nodes to include in the subgraph.
|
|
- node_name (List[str]): A list of names of the nodes to include in the subgraph.
|
|
"""
|
|
label = node_type.__name__
|
|
|
|
# Find primary nodes of the specified type and names
|
|
primary_query = f"""
|
|
UNWIND $names AS wantedName
|
|
MATCH (n:{label})
|
|
WHERE n.name = wantedName
|
|
RETURN DISTINCT n.id, properties(n) AS properties
|
|
"""
|
|
|
|
primary_result = self.query(primary_query, {"names": node_name})
|
|
if not primary_result.result_set:
|
|
return [], []
|
|
|
|
# FalkorDB returns values by index: id, properties
|
|
primary_ids = [record[0] for record in primary_result.result_set]
|
|
|
|
# Find neighbors of primary nodes
|
|
neighbor_query = """
|
|
MATCH (n)-[]-(neighbor)
|
|
WHERE n.id IN $ids
|
|
RETURN DISTINCT neighbor.id, properties(neighbor) AS properties
|
|
"""
|
|
|
|
neighbor_result = self.query(neighbor_query, {"ids": primary_ids})
|
|
# FalkorDB returns values by index: id, properties
|
|
neighbor_ids = (
|
|
[record[0] for record in neighbor_result.result_set]
|
|
if neighbor_result.result_set
|
|
else []
|
|
)
|
|
|
|
all_ids = list(set(primary_ids + neighbor_ids))
|
|
|
|
# Get all nodes in the subgraph
|
|
nodes_query = """
|
|
MATCH (n)
|
|
WHERE n.id IN $ids
|
|
RETURN n.id, properties(n) AS properties
|
|
"""
|
|
|
|
nodes_result = self.query(nodes_query, {"ids": all_ids})
|
|
nodes = []
|
|
if nodes_result.result_set:
|
|
for record in nodes_result.result_set:
|
|
# FalkorDB returns values by index: id, properties
|
|
nodes.append((record[0], record[1]))
|
|
|
|
# Get edges between these nodes
|
|
edges_query = """
|
|
MATCH (a)-[r]->(b)
|
|
WHERE a.id IN $ids AND b.id IN $ids
|
|
RETURN a.id AS source_id, b.id AS target_id, type(r) AS relationship_name, properties(r) AS properties
|
|
"""
|
|
|
|
edges_result = self.query(edges_query, {"ids": all_ids})
|
|
edges = []
|
|
if edges_result.result_set:
|
|
for record in edges_result.result_set:
|
|
# FalkorDB returns values by index: source_id, target_id, relationship_name, properties
|
|
edges.append(
|
|
(
|
|
record[0], # source_id
|
|
record[1], # target_id
|
|
record[2], # relationship_name
|
|
record[3], # properties
|
|
)
|
|
)
|
|
|
|
return nodes, edges
|
|
|
|
    async def prune(self):
        """
        Prune the graph by deleting the entire graph structure.

        Thin wrapper that delegates to delete_graph() (defined elsewhere in
        this class); kept as a separate method to satisfy the shared
        graph/vector interface contract.
        """
        await self.delete_graph()