feat: add FalkorDB integration

Boris Arzentar authored on 2024-11-07 11:17:01 +01:00; committed by Leon Luithlen
parent 73372df31e
commit a2b1087c84
29 changed files with 180 additions and 427 deletions
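For orientation, the FalkorDB Python client that the new adapter wraps is driven roughly as follows. This is a minimal sketch based only on the calls visible in the diff below (FalkorDB(host, port), select_graph, graph.query); the host, port, graph name, and query are illustrative placeholders, not values from the commit.

from falkordb import FalkorDB

# Connect to a FalkorDB server (FalkorDB speaks the Redis protocol; 6379 is the default port).
driver = FalkorDB(host = "localhost", port = 6379)

# A graph plays the role of a collection; the adapter selects one per collection name.
graph = driver.select_graph("cognee_graph")

# Cypher queries are parameterized with a plain dict, as in the adapter's retrieve().
graph.query("MERGE (n {id: $id}) RETURN n", { "id": "example-node" })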


@@ -8,6 +8,7 @@ from uuid import UUID
 from neo4j import AsyncSession
 from neo4j import AsyncGraphDatabase
 from neo4j.exceptions import Neo4jError
+from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 logger = logging.getLogger("Neo4jAdapter")
@@ -62,10 +63,11 @@ class Neo4jAdapter(GraphDBInterface):
 async def add_node(self, node: DataPoint):
 serialized_properties = self.serialize_properties(node.model_dump())
-query = dedent("""MERGE (node {id: $node_id})
-ON CREATE SET node += $properties, node.updated_at = timestamp()
-ON MATCH SET node += $properties, node.updated_at = timestamp()
-RETURN ID(node) AS internal_id, node.id AS nodeId""")
+query = """MERGE (node {id: $node_id})
+ON CREATE SET node += $properties
+ON MATCH SET node += $properties
+ON MATCH SET node.updated_at = timestamp()
+RETURN ID(node) AS internal_id, node.id AS nodeId"""
 params = {
 "node_id": str(node.id),
@@ -78,8 +80,9 @@ class Neo4jAdapter(GraphDBInterface):
 query = """
 UNWIND $nodes AS node
 MERGE (n {id: node.node_id})
-ON CREATE SET n += node.properties, n.updated_at = timestamp()
-ON MATCH SET n += node.properties, n.updated_at = timestamp()
+ON CREATE SET n += node.properties
+ON MATCH SET n += node.properties
+ON MATCH SET n.updated_at = timestamp()
 WITH n, node.node_id AS label
 CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
 RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
@@ -134,9 +137,8 @@ class Neo4jAdapter(GraphDBInterface):
 return await self.query(query, params)
 async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
-query = """
-MATCH (from_node)-[relationship]->(to_node)
-WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
+query = f"""
+MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`)
 RETURN COUNT(relationship) > 0 AS edge_exists
 """
@@ -176,18 +178,17 @@ class Neo4jAdapter(GraphDBInterface):
 async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
 serialized_properties = self.serialize_properties(edge_properties)
-query = dedent("""MATCH (from_node {id: $from_node}),
-(to_node {id: $to_node})
-MERGE (from_node)-[r]->(to_node)
-ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
-ON MATCH SET r += $properties, r.updated_at = timestamp()
-RETURN r
-""")
+query = f"""MATCH (from_node:`{str(from_node)}`
+{{id: $from_node}}),
+(to_node:`{str(to_node)}` {{id: $to_node}})
+MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
+ON CREATE SET r += $properties, r.updated_at = timestamp()
+ON MATCH SET r += $properties, r.updated_at = timestamp()
+RETURN r"""
 params = {
 "from_node": str(from_node),
 "to_node": str(to_node),
-"relationship_name": relationship_name,
 "properties": serialized_properties
 }
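As a usage illustration, calling the reworked add_edge with its unchanged signature might look like this; the adapter instance, UUIDs, relationship name, and properties are hypothetical, chosen only to show how relationship_name now ends up in the query text itself rather than in the parameter map.

from uuid import uuid4

from_id, to_id = uuid4(), uuid4()

# relationship_name becomes the relationship type via the f-string; only the node ids
# and the serialized properties are still passed as query parameters.
await neo4j_adapter.add_edge(from_id, to_id, "is_part_of", { "weight": 1 })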


@@ -30,10 +30,6 @@ class NetworkXAdapter(GraphDBInterface):
 def __init__(self, filename = "cognee_graph.pkl"):
 self.filename = filename
-async def get_graph_data(self):
-await self.load_graph_from_file()
-return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
 async def query(self, query: str, params: dict):
 pass


@@ -1,7 +1,6 @@
 import asyncio
 from textwrap import dedent
 from typing import Any
-from uuid import UUID
 from falkordb import FalkorDB
 from cognee.infrastructure.engine import DataPoint
@@ -162,35 +161,6 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
 async def extract_nodes(self, data_point_ids: list[str]):
 return await self.retrieve(data_point_ids)
-async def get_connections(self, node_id: UUID) -> list:
-predecessors_query = """
-MATCH (node)<-[relation]-(neighbour)
-WHERE node.id = $node_id
-RETURN neighbour, relation, node
-"""
-successors_query = """
-MATCH (node)-[relation]->(neighbour)
-WHERE node.id = $node_id
-RETURN node, relation, neighbour
-"""
-predecessors, successors = await asyncio.gather(
-self.query(predecessors_query, dict(node_id = node_id)),
-self.query(successors_query, dict(node_id = node_id)),
-)
-connections = []
-for neighbour in predecessors:
-neighbour = neighbour["relation"]
-connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
-for neighbour in successors:
-neighbour = neighbour["relation"]
-connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
-return connections
 async def search(
 self,
 collection_name: str,


@@ -1,113 +0,0 @@
import asyncio
from falkordb import FalkorDB
from ..models.DataPoint import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
class FalcorDBAdapter(VectorDBInterface):
def __init__(
self,
graph_database_url: str,
graph_database_port: int,
embedding_engine = EmbeddingEngine,
):
self.driver = FalkorDB(
host = graph_database_url,
port = graph_database_port)
self.embedding_engine = embedding_engine
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
collections = self.driver.list_graphs()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema = None):
self.driver.select_graph(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
graph = self.driver.select_graph(collection_name)
def stringify_properties(properties: dict) -> str:
return ",".join(f"{key}:'{value}'" for key, value in properties.items())
def create_data_point_query(data_point: DataPoint):
node_label = type(data_point.payload).__name__
node_properties = stringify_properties(data_point.payload.dict())
return f"""CREATE (:{node_label} {{{node_properties}}})"""
query = " ".join([create_data_point_query(data_point) for data_point in data_points])
graph.query(query)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
graph = self.driver.select_graph(collection_name)
return graph.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{
"node_ids": data_point_ids,
},
)
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: list[float] = None,
limit: int = 10,
with_vector: bool = False,
):
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
graph = self.driver.select_graph(collection_name)
query = f"""
CALL db.idx.vector.queryNodes(
null,
'text',
{limit},
{query_vector}
) YIELD node, score
"""
result = graph.query(query)
return result
async def batch_search(
self,
collection_name: str,
query_texts: list[str],
limit: int = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
graph = self.driver.select_graph(collection_name)
return graph.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{
"node_ids": data_point_ids,
},
)


@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from cognee.infrastructure.engine import DataPoint
-from .serialize_data import serialize_data
+from .serialize_datetime import serialize_datetime
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
@@ -79,10 +79,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 async def create_data_points(
 self, collection_name: str, data_points: List[DataPoint]
 ):
-if not await self.has_collection(collection_name):
-await self.create_collection(
-collection_name = collection_name,
-payload_schema = type(data_points[0]),
+async with self.get_async_session() as session:
+if not await self.has_collection(collection_name):
+await self.create_collection(
+collection_name=collection_name,
+payload_schema=type(data_points[0]),
+)
+data_vectors = await self.embed_data(
+[data_point.get_embeddable_data() for data_point in data_points]
 )
 data_vectors = await self.embed_data(
@@ -102,10 +107,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 payload = Column(JSON)
 vector = Column(Vector(vector_size))
-def __init__(self, id, payload, vector):
-self.id = id
-self.payload = payload
-self.vector = vector
+pgvector_data_points = [
+PGVectorDataPoint(
+id=data_point.id,
+vector=data_vectors[data_index],
+payload=serialize_datetime(data_point.model_dump()),
+)
+for (data_index, data_point) in enumerate(data_points)
+]
 pgvector_data_points = [
 PGVectorDataPoint(
@@ -127,7 +136,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 await self.create_data_points(f"{index_name}_{index_property_name}", [
 IndexSchema(
 id = data_point.id,
-text = data_point.get_embeddable_data(),
+text = getattr(data_point, data_point._metadata["index_fields"][0]),
 ) for data_point in data_points
 ])
@@ -197,19 +206,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 vector_list = []
-# Extract distances and find min/max for normalization
-for vector in closest_items:
-# TODO: Add normalization of similarity score
-vector_list.append(vector)
-# Create and return ScoredResult objects
-return [
-ScoredResult(
-id = UUID(str(row.id)),
-payload = row.payload,
-score = row.similarity
-) for row in vector_list
-]
+# Create and return ScoredResult objects
+return [
+ScoredResult(
+id = UUID(row.id),
+payload = row.payload,
+score = row.similarity
+) for row in vector_list
+]
 async def batch_search(
 self,


@@ -3,7 +3,6 @@ from uuid import UUID
 from typing import List, Dict, Optional
 from qdrant_client import AsyncQdrantClient, models
-from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
 from cognee.infrastructure.engine import DataPoint
 from ..vector_db_interface import VectorDBInterface
 from ..embeddings.EmbeddingEngine import EmbeddingEngine


@@ -11,6 +11,7 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
 logger = logging.getLogger("WeaviateAdapter")
 class IndexSchema(DataPoint):
+uuid: str
 text: str
 _metadata: dict = {
@@ -88,10 +89,8 @@ class WeaviateAdapter(VectorDBInterface):
 def convert_to_weaviate_data_points(data_point: DataPoint):
 vector = data_vectors[data_points.index(data_point)]
 properties = data_point.model_dump()
-properties["uuid"] = properties["id"]
-if "id" in properties: del properties["id"]
+properties["uuid"] = str(data_point.id)
+del properties["id"]
 return DataObject(
 uuid = data_point.id,
@@ -131,8 +130,8 @@ class WeaviateAdapter(VectorDBInterface):
 async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
 await self.create_data_points(f"{index_name}_{index_property_name}", [
 IndexSchema(
-id = data_point.id,
-text = data_point.get_embeddable_data(),
+uuid = str(data_point.id),
+text = getattr(data_point, data_point._metadata["index_fields"][0]),
 ) for data_point in data_points
 ])
@@ -179,9 +178,9 @@ class WeaviateAdapter(VectorDBInterface):
 return [
 ScoredResult(
-id = UUID(str(result.uuid)),
+id = UUID(result.id),
 payload = result.properties,
-score = 1 - float(result.metadata.score)
+score = float(result.metadata.score)
 ) for result in search_result.objects
 ]


@@ -29,7 +29,7 @@ class TextChunker():
 else:
 if len(self.paragraph_chunks) == 0:
 yield DocumentChunk(
-id = chunk_data["chunk_id"],
+id = str(chunk_data["chunk_id"]),
 text = chunk_data["text"],
 word_count = chunk_data["word_count"],
 is_part_of = self.document,
@@ -42,7 +42,7 @@ class TextChunker():
 chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
 try:
 yield DocumentChunk(
-id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
+id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
 text = chunk_text,
 word_count = self.chunk_size,
 is_part_of = self.document,
@@ -59,7 +59,7 @@ class TextChunker():
 if len(self.paragraph_chunks) > 0:
 try:
 yield DocumentChunk(
-id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
+id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
 text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
 word_count = self.chunk_size,
 is_part_of = self.document,


@@ -1,3 +1,2 @@
 from .generate_node_id import generate_node_id
 from .generate_node_name import generate_node_name
-from .generate_edge_name import generate_edge_name


@@ -1,2 +1,2 @@
 def generate_node_name(name: str) -> str:
-return name.lower().replace("'", "")
+return name.lower().replace(" ", "_").replace("'", "")


@@ -1,8 +1,9 @@
 from datetime import datetime, timezone
 from cognee.infrastructure.engine import DataPoint
+from cognee.modules import data
 from cognee.modules.storage.utils import copy_model
-def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
+def get_graph_from_model(data_point: DataPoint, include_root = True):
 nodes = []
 edges = []
@@ -16,55 +17,29 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
 if isinstance(field_value, DataPoint):
 excluded_properties.add(field_name)
-property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
-for node in property_nodes:
-if str(node.id) not in added_nodes:
-nodes.append(node)
-added_nodes[str(node.id)] = True
-for edge in property_edges:
-edge_key = str(edge[0]) + str(edge[1]) + edge[2]
-if str(edge_key) not in added_edges:
-edges.append(edge)
-added_edges[str(edge_key)] = True
+property_nodes, property_edges = get_graph_from_model(field_value, True)
+nodes[:0] = property_nodes
+edges[:0] = property_edges
 for property_node in get_own_properties(property_nodes, property_edges):
-edge_key = str(data_point.id) + str(property_node.id) + field_name
-if str(edge_key) not in added_edges:
-edges.append((data_point.id, property_node.id, field_name, {
-"source_node_id": data_point.id,
-"target_node_id": property_node.id,
-"relationship_name": field_name,
-"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
-}))
-added_edges[str(edge_key)] = True
+edges.append((data_point.id, property_node.id, field_name, {
+"source_node_id": data_point.id,
+"target_node_id": property_node.id,
+"relationship_name": field_name,
+"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
+}))
 continue
-if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
-excluded_properties.add(field_name)
+if isinstance(field_value, list):
+if isinstance(field_value[0], DataPoint):
+excluded_properties.add(field_name)
 for item in field_value:
-property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
-for node in property_nodes:
-if str(node.id) not in added_nodes:
-nodes.append(node)
-added_nodes[str(node.id)] = True
-for edge in property_edges:
-edge_key = str(edge[0]) + str(edge[1]) + edge[2]
-if str(edge_key) not in added_edges:
-edges.append(edge)
-added_edges[edge_key] = True
-for property_node in get_own_properties(property_nodes, property_edges):
-edge_key = str(data_point.id) + str(property_node.id) + field_name
-if str(edge_key) not in added_edges:
+property_nodes, property_edges = get_graph_from_model(item, True)
+nodes[:0] = property_nodes
+edges[:0] = property_edges
+for property_node in get_own_properties(property_nodes, property_edges):
 edges.append((data_point.id, property_node.id, field_name, {
 "source_node_id": data_point.id,
 "target_node_id": property_node.id,
@@ -74,8 +49,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
 "type": "list"
 },
 }))
-added_edges[edge_key] = True
 continue
 data_point_properties[field_name] = field_value


@@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key):
 source: str(edge[0]),
 target: str(edge[1]),
 edge_key: str(edge[2]),
-} for edge in graph.edges(keys = True, data = True)]
+} for edge in graph.edges]
 return pd.DataFrame(edge_list)


@@ -1,3 +1,2 @@
 from .extract_graph_from_data import extract_graph_from_data
-from .extract_graph_from_code import extract_graph_from_code
 from .query_graph_connections import query_graph_connections


@@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.modules.engine.models import EntityType, Entity
-from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name
+from cognee.modules.engine.utils import generate_node_id, generate_node_name
 from cognee.tasks.storage import add_data_points
 async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
@@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
 for edge in graph.edges:
 source_node_id = generate_node_id(edge.source_node_id)
 target_node_id = generate_node_id(edge.target_node_id)
-relationship_name = generate_edge_name(edge.relationship_name)
+relationship_name = generate_node_name(edge.relationship_name)
 edge_key = str(source_node_id) + str(target_node_id) + relationship_name
@@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
 target_node_id,
 edge.relationship_name,
 dict(
-relationship_name = generate_edge_name(edge.relationship_name),
+relationship_name = generate_node_name(edge.relationship_name),
 source_node_id = source_node_id,
 target_node_id = target_node_id,
 ),


@@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
 else:
 vector_engine = get_vector_engine()
 results = await asyncio.gather(
-vector_engine.search("Entity_name", query_text = query, limit = 5),
-vector_engine.search("EntityType_name", query_text = query, limit = 5),
+vector_engine.search("Entity_text", query_text = query, limit = 5),
+vector_engine.search("EntityType_text", query_text = query, limit = 5),
 )
 results = [*results[0], *results[1]]
 relevant_results = [result for result in results if result.score < 0.5][:5]


@@ -1,31 +0,0 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
class RelationshipModel(BaseModel):
type: str
source: str
target: str
class NodeModel(BaseModel):
node_id: str
name: str
default_relationship: Optional[RelationshipModel] = None
children: List[Union[Dict[str, Any], "NodeModel"]] = Field(default_factory=list)
NodeModel.model_rebuild()
class OntologyNode(BaseModel):
id: str = Field(..., description = "Unique identifier made from node name.")
name: str
description: str
class OntologyEdge(BaseModel):
id: str
source_id: str
target_id: str
relationship_type: str
class GraphOntology(BaseModel):
nodes: list[OntologyNode]
edges: list[OntologyEdge]


@@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
 added_data_points[str(new_point.id)] = True
 data_points.append(new_point)
-if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
+if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
 for field_value_item in field_value:
 new_data_points = get_data_points_from_model(field_value_item, added_data_points)
@@ -56,8 +56,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
 added_data_points[str(new_point.id)] = True
 data_points.append(new_point)
-if (str(data_point.id) not in added_data_points):
-data_points.append(data_point)
+data_points.append(data_point)
 return data_points


@@ -4,8 +4,9 @@ from cognee.modules.data.processing.document_types import Document
 class TextSummary(DataPoint):
 text: str
-made_from: DocumentChunk
+chunk: DocumentChunk
 _metadata: dict = {
 "index_fields": ["text"],
 }


@@ -17,12 +17,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
 summaries = [
 TextSummary(
-id = uuid5(chunk.id, "TextSummary"),
-made_from = chunk,
+id = uuid5(chunk.id, "summary"),
+chunk = chunk,
 text = chunk_summaries[chunk_index].summary,
 ) for (chunk_index, chunk) in enumerate(data_chunks)
 ]
-await add_data_points(summaries)
+add_data_points(summaries)
 return data_chunks


@@ -32,8 +32,8 @@ async def main():
 from cognee.infrastructure.databases.vector import get_vector_engine
 vector_engine = get_vector_engine()
-random_node = (await vector_engine.search("Entity_name", "AI"))[0]
-random_node_name = random_node.payload["text"]
+random_node = (await vector_engine.search("Entity", "AI"))[0]
+random_node_name = random_node.payload["name"]
 search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
 assert len(search_results) != 0, "The search results list is empty."


@@ -36,8 +36,8 @@ async def main():
 from cognee.infrastructure.databases.vector import get_vector_engine
 vector_engine = get_vector_engine()
-random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
-random_node_name = random_node.payload["text"]
+random_node = (await vector_engine.search("Entity", "AI"))[0]
+random_node_name = random_node.payload["name"]
 search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
 assert len(search_results) != 0, "The search results list is empty."


@@ -65,8 +65,8 @@ async def main():
 from cognee.infrastructure.databases.vector import get_vector_engine
 vector_engine = get_vector_engine()
-random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
-random_node_name = random_node.payload["text"]
+random_node = (await vector_engine.search("Entity", "AI"))[0]
+random_node_name = random_node.payload["name"]
 search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
 assert len(search_results) != 0, "The search results list is empty."


@@ -37,8 +37,8 @@ async def main():
 from cognee.infrastructure.databases.vector import get_vector_engine
 vector_engine = get_vector_engine()
-random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
-random_node_name = random_node.payload["text"]
+random_node = (await vector_engine.search("Entity", "AI"))[0]
+random_node_name = random_node.payload["name"]
 search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
 assert len(search_results) != 0, "The search results list is empty."


@@ -35,8 +35,8 @@ async def main():
 from cognee.infrastructure.databases.vector import get_vector_engine
 vector_engine = get_vector_engine()
-random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
-random_node_name = random_node.payload["text"]
+random_node = (await vector_engine.search("Entity", "AI"))[0]
+random_node_name = random_node.payload["name"]
 search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
 assert len(search_results) != 0, "The search results list is empty."


@@ -1,62 +0,0 @@
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel
async def add_data_points(collection_name: str, data_points: list):
pass
class Summary(BaseModel):
id: UUID
text: str
chunk: "Chunk"
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["text"]
class Chunk(BaseModel):
id: UUID
text: str
summary: Summary
document: "Document"
created_at: datetime
updated_at: Optional[datetime]
word_count: int
chunk_index: int
cut_type: str
vector_index = ["text"]
class Document(BaseModel):
id: UUID
chunks: list[Chunk]
created_at: datetime
updated_at: Optional[datetime]
class EntityType(BaseModel):
id: UUID
name: str
description: str
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["name"]
class Entity(BaseModel):
id: UUID
name: str
type: EntityType
description: str
chunks: list[Chunk]
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["name"]
class OntologyModel(BaseModel):
chunks: list[Chunk]


@@ -265,7 +265,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 13,
 "id": "df16431d0f48b006",
 "metadata": {
 "ExecuteTime": {
@@ -304,7 +304,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 14,
 "id": "9086abf3af077ab4",
 "metadata": {
 "ExecuteTime": {
@@ -349,7 +349,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 15,
 "id": "a9de0cc07f798b7f",
 "metadata": {
 "ExecuteTime": {
@@ -393,7 +393,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 16,
 "id": "185ff1c102d06111",
 "metadata": {
 "ExecuteTime": {
@@ -437,7 +437,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 17,
 "id": "d55ce4c58f8efb67",
 "metadata": {
 "ExecuteTime": {
@@ -479,7 +479,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 18,
 "id": "ca4ecc32721ad332",
 "metadata": {
 "ExecuteTime": {
@@ -572,7 +572,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 26,
 "id": "9f1a1dbd",
 "metadata": {},
 "outputs": [],
@@ -758,7 +758,7 @@
 "from cognee.infrastructure.databases.vector import get_vector_engine\n",
 "\n",
 "vector_engine = get_vector_engine()\n",
-"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
+"results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n",
 "for result in results:\n",
 " print(result)"
 ]
@@ -788,8 +788,8 @@
 "source": [
 "from cognee.api.v1.search import SearchType\n",
 "\n",
-"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
-"node_name = node.payload[\"text\"]\n",
+"node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n",
+"node_name = node.payload[\"name\"]\n",
 "\n",
 "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n",
 "print(\"\\n\\Extracted summaries are:\\n\")\n",
@@ -881,7 +881,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.8"
+"version": "3.9.6"
 }
 },
 "nbformat": 4,

poetry.lock (generated)

@@ -3215,54 +3215,6 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "ibm-watsonx-ai (>=1.1.2)", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"]
 tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19,<=1.3.0)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
-[[package]]
-name = "langchain-core"
-version = "0.3.15"
-description = "Building applications with LLMs through composability"
-optional = false
-python-versions = "<4.0,>=3.9"
-files = [
-{file = "langchain_core-0.3.15-py3-none-any.whl", hash = "sha256:3d4ca6dbb8ed396a6ee061063832a2451b0ce8c345570f7b086ffa7288e4fa29"},
-{file = "langchain_core-0.3.15.tar.gz", hash = "sha256:b1a29787a4ffb7ec2103b4e97d435287201da7809b369740dd1e32f176325aba"},
-]
-[package.dependencies]
-jsonpatch = ">=1.33,<2.0"
-langsmith = ">=0.1.125,<0.2.0"
-packaging = ">=23.2,<25"
-pydantic = {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}
-PyYAML = ">=5.3"
-tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
-typing-extensions = ">=4.7"
-[[package]]
-name = "langchain-text-splitters"
-version = "0.3.2"
-description = "LangChain text splitting utilities"
-optional = false
-python-versions = "<4.0,>=3.9"
-files = [
-{file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"},
-{file = "langchain_text_splitters-0.3.2.tar.gz", hash = "sha256:81e6515d9901d6dd8e35fb31ccd4f30f76d44b771890c789dc835ef9f16204df"},
-]
-[package.dependencies]
-langchain-core = ">=0.3.15,<0.4.0"
-[[package]]
-name = "langdetect"
-version = "1.0.9"
-description = "Language detection library ported from Google's language-detection."
-optional = false
-python-versions = "*"
-files = [
-{file = "langdetect-1.0.9-py2-none-any.whl", hash = "sha256:7cbc0746252f19e76f77c0b1690aadf01963be835ef0cd4b56dddf2a8f1dfc2a"},
-{file = "langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0"},
-]
-[package.dependencies]
-six = "*"
 [[package]]
 name = "langfuse"
 version = "2.53.9"
@@ -5083,8 +5035,8 @@ argon2-cffi = {version = ">=23.1.0,<24", optional = true, markers = "extra == \"
 bcrypt = {version = ">=4.1.2,<5", optional = true, markers = "extra == \"bcrypt\""}
 [package.extras]
-argon2 = ["argon2-cffi (>=23.1.0,<24)"]
-bcrypt = ["bcrypt (>=4.1.2,<5)"]
+argon2 = ["argon2-cffi (==23.1.0)"]
+bcrypt = ["bcrypt (==4.1.2)"]
 [[package]]
 name = "pyarrow"
@@ -7782,4 +7734,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0,<3.12"
-content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236"
+content-hash = "bb70798562fee44c6daa2f5c7fa4d17165fb76016618c1fc8fd0782c5aa4a6de"


@@ -59,7 +59,7 @@ langsmith = "0.1.139"
 langdetect = "1.0.9"
 posthog = "^3.5.0"
 lancedb = "0.15.0"
-litellm = "1.49.1"
+litellm = "1.38.10"
 groq = "0.8.0"
 langfuse = "^2.32.0"
 pydantic-settings = "^2.2.1"


@@ -0,0 +1,66 @@
import tweepy
import requests
import json
import base64
from datetime import datetime, timezone
# Twitter API credentials from GitHub Secrets
API_KEY = '${{ secrets.TWITTER_API_KEY }}'
API_SECRET = '${{ secrets.TWITTER_API_SECRET }}'
ACCESS_TOKEN = '${{ secrets.TWITTER_ACCESS_TOKEN }}'
ACCESS_SECRET = '${{ secrets.TWITTER_ACCESS_SECRET }}'
USERNAME = '${{ secrets.TWITTER_USERNAME }}'
SEGMENT_WRITE_KEY = '${{ secrets.SEGMENT_WRITE_KEY }}'
# Initialize Tweepy API
auth = tweepy.OAuthHandler(API_KEY, API_SECRET)
auth.set_access_token(ACCESS_TOKEN, ACCESS_SECRET)
twitter_api = tweepy.API(auth)
# Segment endpoint
SEGMENT_ENDPOINT = 'https://api.segment.io/v1/track'
def get_follower_count(username):
try:
user = twitter_api.get_user(screen_name=username)
return user.followers_count
except tweepy.TweepError as e:
print(f'Error fetching follower count: {e}')
return None
def send_data_to_segment(username, follower_count):
current_time = datetime.now(timezone.utc).isoformat()
data = {
'userId': username,
'event': 'Follower Count Update',
'properties': {
'username': username,
'follower_count': follower_count,
'timestamp': current_time
},
'timestamp': current_time
}
headers = {
'Content-Type': 'application/json',
# HTTP Basic auth: Segment expects base64("<write key>:") rather than the raw key
# (the original encode/decode round-trip was a no-op).
'Authorization': f'Basic {base64.b64encode(f"{SEGMENT_WRITE_KEY}:".encode("utf-8")).decode("utf-8")}'
}
try:
response = requests.post(SEGMENT_ENDPOINT, headers=headers, data=json.dumps(data))
if response.status_code == 200:
print(f'Successfully sent data to Segment for {username}')
else:
print(f'Failed to send data to Segment. Status code: {response.status_code}, Response: {response.text}')
except requests.exceptions.RequestException as e:
print(f'Error sending data to Segment: {e}')
follower_count = get_follower_count(USERNAME)
if follower_count is not None:
send_data_to_segment(USERNAME, follower_count)
else:
print('Failed to retrieve follower count.')