feat: add FalkorDB integration
This commit is contained in:
parent
73372df31e
commit
a2b1087c84
29 changed files with 180 additions and 427 deletions
|
|
@ -8,6 +8,7 @@ from uuid import UUID
|
|||
from neo4j import AsyncSession
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
logger = logging.getLogger("Neo4jAdapter")
|
||||
|
|
@ -62,10 +63,11 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async def add_node(self, node: DataPoint):
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
query = dedent("""MERGE (node {id: $node_id})
|
||||
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
||||
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId""")
|
||||
query = """MERGE (node {id: $node_id})
|
||||
ON CREATE SET node += $properties
|
||||
ON MATCH SET node += $properties
|
||||
ON MATCH SET node.updated_at = timestamp()
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||
|
||||
params = {
|
||||
"node_id": str(node.id),
|
||||
|
|
@ -78,8 +80,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
query = """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n {id: node.node_id})
|
||||
ON CREATE SET n += node.properties, n.updated_at = timestamp()
|
||||
ON MATCH SET n += node.properties, n.updated_at = timestamp()
|
||||
ON CREATE SET n += node.properties
|
||||
ON MATCH SET n += node.properties
|
||||
ON MATCH SET n.updated_at = timestamp()
|
||||
WITH n, node.node_id AS label
|
||||
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
|
||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
|
|
@ -134,9 +137,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return await self.query(query, params)
|
||||
|
||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||
query = """
|
||||
MATCH (from_node)-[relationship]->(to_node)
|
||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
|
||||
query = f"""
|
||||
MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`)
|
||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
|
|
@ -176,18 +178,17 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
|
||||
query = dedent("""MATCH (from_node {id: $from_node}),
|
||||
(to_node {id: $to_node})
|
||||
MERGE (from_node)-[r]->(to_node)
|
||||
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||
RETURN r
|
||||
""")
|
||||
query = f"""MATCH (from_node:`{str(from_node)}`
|
||||
{{id: $from_node}}),
|
||||
(to_node:`{str(to_node)}` {{id: $to_node}})
|
||||
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
|
||||
ON CREATE SET r += $properties, r.updated_at = timestamp()
|
||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||
RETURN r"""
|
||||
|
||||
params = {
|
||||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,10 +30,6 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
def __init__(self, filename = "cognee_graph.pkl"):
|
||||
self.filename = filename
|
||||
|
||||
async def get_graph_data(self):
|
||||
await self.load_graph_from_file()
|
||||
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
|
||||
|
||||
async def query(self, query: str, params: dict):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -162,35 +161,6 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
|||
async def extract_nodes(self, data_point_ids: list[str]):
|
||||
return await self.retrieve(data_point_ids)
|
||||
|
||||
async def get_connections(self, node_id: UUID) -> list:
|
||||
predecessors_query = """
|
||||
MATCH (node)<-[relation]-(neighbour)
|
||||
WHERE node.id = $node_id
|
||||
RETURN neighbour, relation, node
|
||||
"""
|
||||
successors_query = """
|
||||
MATCH (node)-[relation]->(neighbour)
|
||||
WHERE node.id = $node_id
|
||||
RETURN node, relation, neighbour
|
||||
"""
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.query(predecessors_query, dict(node_id = node_id)),
|
||||
self.query(successors_query, dict(node_id = node_id)),
|
||||
)
|
||||
|
||||
connections = []
|
||||
|
||||
for neighbour in predecessors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
|
||||
for neighbour in successors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
|
||||
return connections
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
|
|
|
|||
|
|
@ -1,113 +0,0 @@
|
|||
import asyncio
|
||||
from falkordb import FalkorDB
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
|
||||
class FalcorDBAdapter(VectorDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
graph_database_url: str,
|
||||
graph_database_port: int,
|
||||
embedding_engine = EmbeddingEngine,
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host = graph_database_url,
|
||||
port = graph_database_port)
|
||||
self.embedding_engine = embedding_engine
|
||||
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
collections = self.driver.list_graphs()
|
||||
|
||||
return collection_name in collections
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||
self.driver.select_graph(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
graph = self.driver.select_graph(collection_name)
|
||||
|
||||
def stringify_properties(properties: dict) -> str:
|
||||
return ",".join(f"{key}:'{value}'" for key, value in properties.items())
|
||||
|
||||
def create_data_point_query(data_point: DataPoint):
|
||||
node_label = type(data_point.payload).__name__
|
||||
node_properties = stringify_properties(data_point.payload.dict())
|
||||
|
||||
return f"""CREATE (:{node_label} {{{node_properties}}})"""
|
||||
|
||||
query = " ".join([create_data_point_query(data_point) for data_point in data_points])
|
||||
|
||||
graph.query(query)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
graph = self.driver.select_graph(collection_name)
|
||||
|
||||
return graph.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
|
||||
{
|
||||
"node_ids": data_point_ids,
|
||||
},
|
||||
)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: list[float] = None,
|
||||
limit: int = 10,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
graph = self.driver.select_graph(collection_name)
|
||||
|
||||
query = f"""
|
||||
CALL db.idx.vector.queryNodes(
|
||||
null,
|
||||
'text',
|
||||
{limit},
|
||||
{query_vector}
|
||||
) YIELD node, score
|
||||
"""
|
||||
|
||||
result = graph.query(query)
|
||||
|
||||
return result
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_texts: list[str],
|
||||
limit: int = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
return await asyncio.gather(
|
||||
*[self.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = query_vector,
|
||||
limit = limit,
|
||||
with_vector = with_vectors,
|
||||
) for query_vector in query_vectors]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
graph = self.driver.select_graph(collection_name)
|
||||
|
||||
return graph.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
|
||||
{
|
||||
"node_ids": data_point_ids,
|
||||
},
|
||||
)
|
||||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
|||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
from .serialize_data import serialize_data
|
||||
from .serialize_datetime import serialize_datetime
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
|
@ -79,10 +79,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
async def create_data_points(
|
||||
self, collection_name: str, data_points: List[DataPoint]
|
||||
):
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name = collection_name,
|
||||
payload_schema = type(data_points[0]),
|
||||
async with self.get_async_session() as session:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name=collection_name,
|
||||
payload_schema=type(data_points[0]),
|
||||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
|
|
@ -102,10 +107,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
id=data_point.id,
|
||||
vector=data_vectors[data_index],
|
||||
payload=serialize_datetime(data_point.model_dump()),
|
||||
)
|
||||
for (data_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
|
|
@ -127,7 +136,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
|
|
@ -197,19 +206,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
# TODO: Add normalization of similarity score
|
||||
vector_list.append(vector)
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(
|
||||
id = UUID(str(row.id)),
|
||||
payload = row.payload,
|
||||
score = row.similarity
|
||||
) for row in vector_list
|
||||
]
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(
|
||||
id = UUID(row.id),
|
||||
payload = row.payload,
|
||||
score = row.similarity
|
||||
) for row in vector_list
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from uuid import UUID
|
|||
from typing import List, Dict, Optional
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
|||
logger = logging.getLogger("WeaviateAdapter")
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
uuid: str
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
|
|
@ -88,10 +89,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||
vector = data_vectors[data_points.index(data_point)]
|
||||
properties = data_point.model_dump()
|
||||
|
||||
if "id" in properties:
|
||||
properties["uuid"] = str(data_point.id)
|
||||
del properties["id"]
|
||||
properties["uuid"] = properties["id"]
|
||||
del properties["id"]
|
||||
|
||||
return DataObject(
|
||||
uuid = data_point.id,
|
||||
|
|
@ -131,8 +130,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
uuid = str(data_point.id),
|
||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
|
|
@ -179,9 +178,9 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
|
||||
return [
|
||||
ScoredResult(
|
||||
id = UUID(str(result.uuid)),
|
||||
id = UUID(result.id),
|
||||
payload = result.properties,
|
||||
score = 1 - float(result.metadata.score)
|
||||
score = float(result.metadata.score)
|
||||
) for result in search_result.objects
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class TextChunker():
|
|||
else:
|
||||
if len(self.paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
id = chunk_data["chunk_id"],
|
||||
id = str(chunk_data["chunk_id"]),
|
||||
text = chunk_data["text"],
|
||||
word_count = chunk_data["word_count"],
|
||||
is_part_of = self.document,
|
||||
|
|
@ -42,7 +42,7 @@ class TextChunker():
|
|||
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
||||
try:
|
||||
yield DocumentChunk(
|
||||
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
|
||||
text = chunk_text,
|
||||
word_count = self.chunk_size,
|
||||
is_part_of = self.document,
|
||||
|
|
@ -59,7 +59,7 @@ class TextChunker():
|
|||
if len(self.paragraph_chunks) > 0:
|
||||
try:
|
||||
yield DocumentChunk(
|
||||
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
|
||||
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
||||
word_count = self.chunk_size,
|
||||
is_part_of = self.document,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
from .generate_node_id import generate_node_id
|
||||
from .generate_node_name import generate_node_name
|
||||
from .generate_edge_name import generate_edge_name
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
def generate_node_name(name: str) -> str:
|
||||
return name.lower().replace("'", "")
|
||||
return name.lower().replace(" ", "_").replace("'", "")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from datetime import datetime, timezone
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules import data
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||
def get_graph_from_model(data_point: DataPoint, include_root = True):
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
|
@ -16,55 +17,29 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
|
|||
if isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[str(edge_key)] = True
|
||||
property_nodes, property_edges = get_graph_from_model(field_value, True)
|
||||
nodes[:0] = property_nodes
|
||||
edges[:0] = property_edges
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
added_edges[str(edge_key)] = True
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
continue
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
if isinstance(field_value, list):
|
||||
if isinstance(field_value[0], DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||
for item in field_value:
|
||||
property_nodes, property_edges = get_graph_from_model(item, True)
|
||||
nodes[:0] = property_nodes
|
||||
edges[:0] = property_edges
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[edge_key] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
|
|
@ -74,8 +49,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
|
|||
"type": "list"
|
||||
},
|
||||
}))
|
||||
added_edges[edge_key] = True
|
||||
continue
|
||||
continue
|
||||
|
||||
data_point_properties[field_name] = field_value
|
||||
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key):
|
|||
source: str(edge[0]),
|
||||
target: str(edge[1]),
|
||||
edge_key: str(edge[2]),
|
||||
} for edge in graph.edges(keys = True, data = True)]
|
||||
} for edge in graph.edges]
|
||||
|
||||
return pd.DataFrame(edge_list)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
from .extract_graph_from_data import extract_graph_from_data
|
||||
from .extract_graph_from_code import extract_graph_from_code
|
||||
from .query_graph_connections import query_graph_connections
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.engine.models import EntityType, Entity
|
||||
from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name
|
||||
from cognee.modules.engine.utils import generate_node_id, generate_node_name
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
|
||||
|
|
@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
|
|||
for edge in graph.edges:
|
||||
source_node_id = generate_node_id(edge.source_node_id)
|
||||
target_node_id = generate_node_id(edge.target_node_id)
|
||||
relationship_name = generate_edge_name(edge.relationship_name)
|
||||
relationship_name = generate_node_name(edge.relationship_name)
|
||||
|
||||
edge_key = str(source_node_id) + str(target_node_id) + relationship_name
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
|
|||
target_node_id,
|
||||
edge.relationship_name,
|
||||
dict(
|
||||
relationship_name = generate_edge_name(edge.relationship_name),
|
||||
relationship_name = generate_node_name(edge.relationship_name),
|
||||
source_node_id = source_node_id,
|
||||
target_node_id = target_node_id,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
|
|||
else:
|
||||
vector_engine = get_vector_engine()
|
||||
results = await asyncio.gather(
|
||||
vector_engine.search("Entity_name", query_text = query, limit = 5),
|
||||
vector_engine.search("EntityType_name", query_text = query, limit = 5),
|
||||
vector_engine.search("Entity_text", query_text = query, limit = 5),
|
||||
vector_engine.search("EntityType_text", query_text = query, limit = 5),
|
||||
)
|
||||
results = [*results[0], *results[1]]
|
||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class RelationshipModel(BaseModel):
|
||||
type: str
|
||||
source: str
|
||||
target: str
|
||||
|
||||
class NodeModel(BaseModel):
|
||||
node_id: str
|
||||
name: str
|
||||
default_relationship: Optional[RelationshipModel] = None
|
||||
children: List[Union[Dict[str, Any], "NodeModel"]] = Field(default_factory=list)
|
||||
|
||||
NodeModel.model_rebuild()
|
||||
|
||||
|
||||
class OntologyNode(BaseModel):
|
||||
id: str = Field(..., description = "Unique identifier made from node name.")
|
||||
name: str
|
||||
description: str
|
||||
|
||||
class OntologyEdge(BaseModel):
|
||||
id: str
|
||||
source_id: str
|
||||
target_id: str
|
||||
relationship_type: str
|
||||
|
||||
class GraphOntology(BaseModel):
|
||||
nodes: list[OntologyNode]
|
||||
edges: list[OntologyEdge]
|
||||
|
|
@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
|
|||
added_data_points[str(new_point.id)] = True
|
||||
data_points.append(new_point)
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
|
||||
for field_value_item in field_value:
|
||||
new_data_points = get_data_points_from_model(field_value_item, added_data_points)
|
||||
|
||||
|
|
@ -56,8 +56,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
|
|||
added_data_points[str(new_point.id)] = True
|
||||
data_points.append(new_point)
|
||||
|
||||
if (str(data_point.id) not in added_data_points):
|
||||
data_points.append(data_point)
|
||||
data_points.append(data_point)
|
||||
|
||||
return data_points
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from cognee.modules.data.processing.document_types import Document
|
|||
|
||||
class TextSummary(DataPoint):
|
||||
text: str
|
||||
made_from: DocumentChunk
|
||||
chunk: DocumentChunk
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,12 +17,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
|
|||
|
||||
summaries = [
|
||||
TextSummary(
|
||||
id = uuid5(chunk.id, "TextSummary"),
|
||||
made_from = chunk,
|
||||
id = uuid5(chunk.id, "summary"),
|
||||
chunk = chunk,
|
||||
text = chunk_summaries[chunk_index].summary,
|
||||
) for (chunk_index, chunk) in enumerate(data_chunks)
|
||||
]
|
||||
|
||||
await add_data_points(summaries)
|
||||
add_data_points(summaries)
|
||||
|
||||
return data_chunks
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
||||
random_node_name = random_node.payload["name"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
||||
random_node_name = random_node.payload["name"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
|
|
|
|||
|
|
@ -65,8 +65,8 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
||||
random_node_name = random_node.payload["name"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
||||
random_node_name = random_node.payload["name"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
random_node = (await vector_engine.search("Entity", "AI"))[0]
|
||||
random_node_name = random_node.payload["name"]
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def add_data_points(collection_name: str, data_points: list):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class Summary(BaseModel):
|
||||
id: UUID
|
||||
text: str
|
||||
chunk: "Chunk"
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
vector_index = ["text"]
|
||||
|
||||
class Chunk(BaseModel):
|
||||
id: UUID
|
||||
text: str
|
||||
summary: Summary
|
||||
document: "Document"
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
word_count: int
|
||||
chunk_index: int
|
||||
cut_type: str
|
||||
|
||||
vector_index = ["text"]
|
||||
|
||||
class Document(BaseModel):
|
||||
id: UUID
|
||||
chunks: list[Chunk]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class EntityType(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
description: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
vector_index = ["name"]
|
||||
|
||||
class Entity(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
type: EntityType
|
||||
description: str
|
||||
chunks: list[Chunk]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
vector_index = ["name"]
|
||||
|
||||
class OntologyModel(BaseModel):
|
||||
chunks: list[Chunk]
|
||||
|
|
@ -265,7 +265,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"id": "df16431d0f48b006",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -304,7 +304,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"id": "9086abf3af077ab4",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -349,7 +349,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 15,
|
||||
"id": "a9de0cc07f798b7f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -393,7 +393,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 16,
|
||||
"id": "185ff1c102d06111",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -437,7 +437,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 17,
|
||||
"id": "d55ce4c58f8efb67",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -479,7 +479,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 18,
|
||||
"id": "ca4ecc32721ad332",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -572,7 +572,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "9f1a1dbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -758,7 +758,7 @@
|
|||
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
|
||||
"\n",
|
||||
"vector_engine = get_vector_engine()\n",
|
||||
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n",
|
||||
"for result in results:\n",
|
||||
" print(result)"
|
||||
]
|
||||
|
|
@ -788,8 +788,8 @@
|
|||
"source": [
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"\n",
|
||||
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"text\"]\n",
|
||||
"node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"name\"]\n",
|
||||
"\n",
|
||||
"search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n",
|
||||
"print(\"\\n\\Extracted summaries are:\\n\")\n",
|
||||
|
|
@ -881,7 +881,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
"version": "3.9.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
54
poetry.lock
generated
54
poetry.lock
generated
|
|
@ -3215,54 +3215,6 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
|||
embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "ibm-watsonx-ai (>=1.1.2)", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"]
|
||||
tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19,<=1.3.0)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.3.15"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "langchain_core-0.3.15-py3-none-any.whl", hash = "sha256:3d4ca6dbb8ed396a6ee061063832a2451b0ce8c345570f7b086ffa7288e4fa29"},
|
||||
{file = "langchain_core-0.3.15.tar.gz", hash = "sha256:b1a29787a4ffb7ec2103b4e97d435287201da7809b369740dd1e32f176325aba"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
jsonpatch = ">=1.33,<2.0"
|
||||
langsmith = ">=0.1.125,<0.2.0"
|
||||
packaging = ">=23.2,<25"
|
||||
pydantic = {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}
|
||||
PyYAML = ">=5.3"
|
||||
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
|
||||
typing-extensions = ">=4.7"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-text-splitters"
|
||||
version = "0.3.2"
|
||||
description = "LangChain text splitting utilities"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"},
|
||||
{file = "langchain_text_splitters-0.3.2.tar.gz", hash = "sha256:81e6515d9901d6dd8e35fb31ccd4f30f76d44b771890c789dc835ef9f16204df"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
langchain-core = ">=0.3.15,<0.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "langdetect"
|
||||
version = "1.0.9"
|
||||
description = "Language detection library ported from Google's language-detection."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "langdetect-1.0.9-py2-none-any.whl", hash = "sha256:7cbc0746252f19e76f77c0b1690aadf01963be835ef0cd4b56dddf2a8f1dfc2a"},
|
||||
{file = "langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = "*"
|
||||
|
||||
[[package]]
|
||||
name = "langfuse"
|
||||
version = "2.53.9"
|
||||
|
|
@ -5083,8 +5035,8 @@ argon2-cffi = {version = ">=23.1.0,<24", optional = true, markers = "extra == \"
|
|||
bcrypt = {version = ">=4.1.2,<5", optional = true, markers = "extra == \"bcrypt\""}
|
||||
|
||||
[package.extras]
|
||||
argon2 = ["argon2-cffi (>=23.1.0,<24)"]
|
||||
bcrypt = ["bcrypt (>=4.1.2,<5)"]
|
||||
argon2 = ["argon2-cffi (==23.1.0)"]
|
||||
bcrypt = ["bcrypt (==4.1.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
|
|
@ -7782,4 +7734,4 @@ weaviate = ["weaviate-client"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9.0,<3.12"
|
||||
content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236"
|
||||
content-hash = "bb70798562fee44c6daa2f5c7fa4d17165fb76016618c1fc8fd0782c5aa4a6de"
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ langsmith = "0.1.139"
|
|||
langdetect = "1.0.9"
|
||||
posthog = "^3.5.0"
|
||||
lancedb = "0.15.0"
|
||||
litellm = "1.49.1"
|
||||
litellm = "1.38.10"
|
||||
groq = "0.8.0"
|
||||
langfuse = "^2.32.0"
|
||||
pydantic-settings = "^2.2.1"
|
||||
|
|
|
|||
66
tools/daily_twitter_stats.py
Normal file
66
tools/daily_twitter_stats.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import tweepy
|
||||
import requests
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Twitter API credentials from GitHub Secrets
|
||||
API_KEY = '${{ secrets.TWITTER_API_KEY }}'
|
||||
API_SECRET = '${{ secrets.TWITTER_API_SECRET }}'
|
||||
ACCESS_TOKEN = '${{ secrets.TWITTER_ACCESS_TOKEN }}'
|
||||
ACCESS_SECRET = '${{ secrets.TWITTER_ACCESS_SECRET }}'
|
||||
USERNAME = '${{ secrets.TWITTER_USERNAME }}'
|
||||
SEGMENT_WRITE_KEY = '${{ secrets.SEGMENT_WRITE_KEY }}'
|
||||
|
||||
# Initialize Tweepy API
|
||||
auth = tweepy.OAuthHandler(API_KEY, API_SECRET)
|
||||
auth.set_access_token(ACCESS_TOKEN, ACCESS_SECRET)
|
||||
twitter_api = tweepy.API(auth)
|
||||
|
||||
# Segment endpoint
|
||||
SEGMENT_ENDPOINT = 'https://api.segment.io/v1/track'
|
||||
|
||||
|
||||
def get_follower_count(username):
|
||||
try:
|
||||
user = twitter_api.get_user(screen_name=username)
|
||||
return user.followers_count
|
||||
except tweepy.TweepError as e:
|
||||
print(f'Error fetching follower count: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def send_data_to_segment(username, follower_count):
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
data = {
|
||||
'userId': username,
|
||||
'event': 'Follower Count Update',
|
||||
'properties': {
|
||||
'username': username,
|
||||
'follower_count': follower_count,
|
||||
'timestamp': current_time
|
||||
},
|
||||
'timestamp': current_time
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Basic {SEGMENT_WRITE_KEY.encode("utf-8").decode("utf-8")}'
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(SEGMENT_ENDPOINT, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code == 200:
|
||||
print(f'Successfully sent data to Segment for {username}')
|
||||
else:
|
||||
print(f'Failed to send data to Segment. Status code: {response.status_code}, Response: {response.text}')
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f'Error sending data to Segment: {e}')
|
||||
|
||||
|
||||
follower_count = get_follower_count(USERNAME)
|
||||
if follower_count is not None:
|
||||
send_data_to_segment(USERNAME, follower_count)
|
||||
else:
|
||||
print('Failed to retrieve follower count.')
|
||||
Loading…
Add table
Reference in a new issue