feat: code graph swe integration

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Co-authored-by: hande-k <handekafkas7@gmail.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
Boris 2024-11-27 09:32:29 +01:00 committed by GitHub
parent 0fb47ba23d
commit 64b8aac86f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
57 changed files with 1494 additions and 478 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

View file

@ -0,0 +1,63 @@
---
name: test | multimedia notebook

on:
  workflow_dispatch:
  pull_request:
    branches:
      - main
    types: [labeled, synchronize]

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

env:
  RUNTIME__LOG_LEVEL: ERROR

jobs:
  get_docs_changes:
    name: docs changes
    uses: ./.github/workflows/get_docs_changes.yml

  run_notebook_test:
    name: test
    needs: get_docs_changes
    # Single bare expression: mixing `${{ }}` with a bare expression inside `if`
    # is fragile in GitHub Actions and a common source of always-true conditions.
    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks'
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out
        # Pinned to a released tag: `@master` is a mutable ref and unsafe for CI.
        uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11.x'
      - name: Install Poetry
        uses: snok/install-poetry@v1.3.2
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Install dependencies
        run: |
          poetry install --no-interaction
          poetry add jupyter --no-interaction
      - name: Execute Jupyter Notebook
        env:
          ENV: 'dev'
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
        run: |
          poetry run jupyter nbconvert \
            --to notebook \
            --execute notebooks/cognee_multimedia_demo.ipynb \
            --output executed_notebook.ipynb \
            --ExecutePreprocessor.timeout=1200

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:
@ -56,12 +57,6 @@ jobs:
- name: Run integration tests - name: Run integration tests
run: poetry run pytest cognee/tests/integration/ run: poetry run pytest cognee/tests/integration/
- name: Run convert_graph_from_code_graph test
run: poetry run pytest cognee/tests/tasks/graph/convert_graph_from_code_graph_test.py
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Run default basic pipeline - name: Run default basic pipeline
env: env:
ENV: 'dev' ENV: 'dev'

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:
@ -56,12 +57,6 @@ jobs:
- name: Run integration tests - name: Run integration tests
run: poetry run pytest cognee/tests/integration/ run: poetry run pytest cognee/tests/integration/
- name: Run convert_graph_from_code_graph test
run: poetry run pytest cognee/tests/tasks/graph/convert_graph_from_code_graph_test.py
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Run default basic pipeline - name: Run default basic pipeline
env: env:
ENV: 'dev' ENV: 'dev'

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:
@ -56,12 +57,6 @@ jobs:
- name: Run integration tests - name: Run integration tests
run: poetry run pytest cognee/tests/integration/ run: poetry run pytest cognee/tests/integration/
- name: Run convert_graph_from_code_graph test
run: poetry run pytest cognee/tests/tasks/graph/convert_graph_from_code_graph_test.py
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Run default basic pipeline - name: Run default basic pipeline
env: env:
ENV: 'dev' ENV: 'dev'

View file

@ -105,37 +105,65 @@ import asyncio
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
async def main(): async def main():
# Reset cognee data # Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data() await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text
text = """ text = """
Natural language processing (NLP) is an interdisciplinary Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval. subfield of computer science and information retrieval.
""" """
# Add text to cognee print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify
await cognee.add(text) await cognee.add(text)
print("Text added successfully.\n")
print("Running cognify to create knowledge graph...\n")
print("Cognify process steps:")
print("1. Classifying the document: Determining the type and category of the input text.")
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
print("4. Adding data points: Storing the extracted chunks for processing.")
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
# Use LLMs and cognee to create knowledge graph # Use LLMs and cognee to create knowledge graph
await cognee.cognify() await cognee.cognify()
print("Cognify process complete.\n")
# Search cognee for insights
query_text = 'Tell me about NLP'
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text
search_results = await cognee.search( search_results = await cognee.search(
SearchType.INSIGHTS, SearchType.INSIGHTS, query_text=query_text
"Tell me about NLP",
) )
print("Search results:")
# Display results # Display results
for result_text in search_results: for result_text in search_results:
print(result_text) print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval
asyncio.run(main()) # Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
#
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == '__main__':
asyncio.run(main())
``` ```
When you run this script, you will see step-by-step messages in the console that help you trace the execution flow and understand what the script is doing at each stage.
A version of this example is here: `examples/python/simple_example.py` A version of this example is here: `examples/python/simple_example.py`
### Create your own memory store ### Create your own memory store

View file

@ -4,7 +4,7 @@ from .config import get_graph_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
async def get_graph_engine() -> GraphDBInterface : async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
config = get_graph_config() config = get_graph_config()

View file

@ -2,7 +2,7 @@
import logging import logging
import asyncio import asyncio
from textwrap import dedent from textwrap import dedent
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Union
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from uuid import UUID from uuid import UUID
from neo4j import AsyncSession from neo4j import AsyncSession
@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
) for record in result] ) for record in result]
return (nodes, edges) return (nodes, edges)
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetch nodes and relationships whose properties match the given attribute values.

    Args:
        attribute_filters (list of dict): Dicts mapping attribute names to the list
            of accepted values, e.g. [{"community": ["1", "2"]}]. Only the first
            dict is applied (kept for backward compatibility with existing callers).

    Returns:
        tuple: (nodes, edges) where nodes is a list of (id, properties) tuples and
        edges is a list of (source, target, type, properties) tuples; an edge is
        included only when BOTH endpoints satisfy the filters.

    Raises:
        ValueError: if attribute_filters is empty, its first dict is empty, or an
            attribute name is not a plain identifier (Cypher-injection guard).
    """
    if not attribute_filters or not attribute_filters[0]:
        raise ValueError("attribute_filters must contain at least one non-empty filter dict.")

    def build_where(alias):
        # Build the WHERE clause for a given node alias. Building it per-alias
        # replaces the old `where_clause.replace('n.', 'm.')`, which corrupted
        # any filter value that happened to contain the substring "n.".
        clauses = []
        for attribute, values in attribute_filters[0].items():
            # Property names cannot be parameterized in Cypher; restrict them to
            # plain identifiers so user input cannot inject query fragments.
            if not attribute.isidentifier():
                raise ValueError(f"Invalid attribute name: {attribute}")
            values_str = ", ".join(
                "'{}'".format(str(value).replace("'", "\\'")) if isinstance(value, str) else str(value)
                for value in values
            )
            clauses.append(f"{alias}.{attribute} IN [{values_str}]")
        return " AND ".join(clauses)

    where_n = build_where("n")

    query_nodes = f"""
    MATCH (n)
    WHERE {where_n}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)
    nodes = [(record["id"], record["properties"]) for record in result_nodes]

    where_m = build_where("m")

    query_edges = f"""
    MATCH (n)-[r]->(m)
    WHERE {where_n} AND {where_m}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)
    edges = [
        (record["source"], record["target"], record["type"], record["properties"])
        for record in result_edges
    ]

    return (nodes, edges)

View file

@ -6,7 +6,7 @@ import json
import asyncio import asyncio
import logging import logging
from re import A from re import A
from typing import Dict, Any, List from typing import Dict, Any, List, Union
from uuid import UUID from uuid import UUID
import aiofiles import aiofiles
import aiofiles.os as aiofiles_os import aiofiles.os as aiofiles_os
@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
logger.info("Graph deleted successfully.") logger.info("Graph deleted successfully.")
except Exception as error: except Exception as error:
logger.error("Failed to delete graph: %s", error) logger.error("Failed to delete graph: %s", error)
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetch nodes and relationships filtered by specified attribute values.

    Args:
        attribute_filters: Dicts mapping attribute names to the list of accepted
            values, e.g. [{"community": ["1", "2"]}]. Only the first dict is
            applied (kept for parity with the Neo4j adapter).

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: list of (node_id, node_properties) tuples.
            - Edges: list of (source_id, target_id, relationship_type, edge_properties)
              tuples, included only when BOTH endpoints satisfy the filters.

    Raises:
        ValueError: if attribute_filters is empty (previously an IndexError).
    """
    if not attribute_filters:
        raise ValueError("attribute_filters must contain at least one filter dict.")

    filters = list(attribute_filters[0].items())

    def matches(properties):
        # A node matches when every filtered attribute holds an accepted value.
        return all(properties.get(attribute) in accepted for attribute, accepted in filters)

    filtered_nodes = [
        (node, data) for node, data in self.graph.nodes(data=True) if matches(data)
    ]

    filtered_edges = [
        (source, target, data.get('relationship_type', 'UNKNOWN'), data)
        for source, target, data in self.graph.edges(data=True)
        if matches(self.graph.nodes[source]) and matches(self.graph.nodes[target])
    ]

    return filtered_nodes, filtered_edges

View file

@ -10,6 +10,7 @@ from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
@ -141,6 +142,34 @@ class LanceDBAdapter(VectorDBInterface):
score = 0, score = 0,
) for result in results.to_dict("index").values()] ) for result in results.to_dict("index").values()]
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
):
    """
    Return every data point in `collection_name` scored by normalized vector distance.

    Args:
        collection_name: Name of the LanceDB table to scan.
        query_text: Text to embed when no query_vector is supplied.
        query_vector: Pre-computed embedding to compare against.
        with_vector: Currently unused; accepted for interface parity with `search`.

    Returns:
        list[ScoredResult]: one result per row with min-max normalized distance
        as the score; empty list when the collection has no rows.

    Raises:
        ValueError: if neither query_text nor query_vector is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    connection = await self.get_connection()
    collection = await connection.open_table(collection_name)

    results = await collection.vector_search(query_vector).to_pandas()
    result_values = list(results.to_dict("index").values())

    # Guard: normalize_distances takes min()/max() over the distances, which
    # raises on an empty sequence — return early for an empty collection.
    if not result_values:
        return []

    normalized_values = normalize_distances(result_values)

    return [ScoredResult(
        id=UUID(result["id"]),
        payload=result["payload"],
        score=normalized_values[value_index],
    ) for value_index, result in enumerate(result_values)]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
@ -148,6 +177,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 5, limit: int = 5,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise ValueError("One of query_text or query_vector must be provided!")
@ -162,26 +192,7 @@ class LanceDBAdapter(VectorDBInterface):
result_values = list(results.to_dict("index").values()) result_values = list(results.to_dict("index").values())
min_value = 100 normalized_values = normalize_distances(result_values)
max_value = 0
for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value
normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
if max_value == min_value:
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
normalized_values = [0 for _ in result_values]
else:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values]
return [ScoredResult( return [ScoredResult(
id = UUID(result["id"]), id = UUID(result["id"]),

View file

@ -11,6 +11,7 @@ from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
@ -22,6 +23,19 @@ class IndexSchema(DataPoint):
"index_fields": ["text"] "index_fields": ["text"]
} }
def singleton(class_):
    """Class decorator that makes every instantiation return one shared instance.

    Note: decorating a class with this replaces the class name with a factory
    function, so classmethods/static attributes are no longer reachable on it.
    """
    _instances = {}

    def _get_instance(*args, **kwargs):
        # First call constructs the instance; later calls ignore their arguments
        # and hand back the cached object.
        if class_ not in _instances:
            _instances[class_] = class_(*args, **kwargs)
        return _instances[class_]

    return _get_instance
@singleton
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__( def __init__(
@ -162,6 +176,53 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) for result in results ) for result in results
] ]
async def get_distances_of_collection(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """
    Score every data point in `collection_name` by cosine distance to the query.

    Args:
        collection_name: Name of the PGVector table to scan.
        query_text: Text to embed when no query_vector is supplied.
        query_vector: Pre-computed embedding to compare against.
        with_vector: Currently unused; accepted for interface parity with `search`.

    Returns:
        List[ScoredResult]: one result per row, ordered by ascending distance.
        The score is the raw cosine distance — normalization is still TODO
        (see the LanceDB adapter for the normalized variant).

    Raises:
        ValueError: if neither query_text nor query_vector is given.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    # Get the PGVectorDataPoint table for this collection from the database.
    PGVectorDataPoint = await self.get_table(collection_name)

    async with self.get_async_session() as session:
        # Rank all rows by cosine distance to the query vector.
        result = await session.execute(
            select(
                PGVectorDataPoint,
                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
                    "similarity"
                ),
            )
            .order_by("similarity")
        )
        # Materialize while the session is open.
        rows = list(result)

    # TODO: Add normalization of the similarity score.
    return [
        ScoredResult(
            id=UUID(str(row.id)),
            payload=row.payload,
            score=row.similarity,
        ) for row in rows
    ]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,

View file

@ -10,5 +10,3 @@ async def create_db_and_tables():
await vector_engine.create_database() await vector_engine.create_database()
async with vector_engine.engine.begin() as connection: async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

View file

@ -0,0 +1,26 @@
from typing import List
def normalize_distances(result_values: List[dict]) -> List[float]:
    """
    Min-max normalize the "_distance" field of a list of result records.

    Args:
        result_values: Records each containing a numeric "_distance" entry.

    Returns:
        List[float]: distances rescaled to [0, 1]; all zeros when every distance
        is identical (avoids division by zero); [] for empty input (previously
        this raised ValueError from min()/max() of an empty sequence).
    """
    if not result_values:
        return []

    # The original computed min/max in a manual loop and then immediately
    # recomputed them with min()/max() — the loop was dead code and is removed.
    distances = [float(result["_distance"]) for result in result_values]
    min_value = min(distances)
    max_value = max(distances)

    if max_value == min_value:
        # Avoid division by zero: assign all normalized values to 0.
        return [0 for _ in distances]

    value_range = max_value - min_value
    return [(distance - min_value) / value_range for distance in distances]

View file

@ -11,6 +11,7 @@ class DataPoint(BaseModel):
__tablename__ = "data_point" __tablename__ = "data_point"
id: UUID = Field(default_factory = uuid4) id: UUID = Field(default_factory = uuid4)
updated_at: Optional[datetime] = datetime.now(timezone.utc) updated_at: Optional[datetime] = datetime.now(timezone.utc)
topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = { _metadata: Optional[MetaData] = {
"index_fields": [] "index_fields": []
} }

View file

@ -87,6 +87,9 @@ class OpenAIAdapter(LLMInterface):
transcription = litellm.transcription( transcription = litellm.transcription(
model = self.transcription_model, model = self.transcription_model,
file = Path(input), file = Path(input),
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_retries = 5, max_retries = 5,
) )
@ -112,6 +115,9 @@ class OpenAIAdapter(LLMInterface):
}, },
], ],
}], }],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_tokens = 300, max_tokens = 300,
max_retries = 5, max_retries = 5,
) )

View file

@ -1,9 +1,12 @@
from typing import List, Dict, Union import numpy as np
from typing import List, Dict, Union
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
from cognee.infrastructure.databases.graph import get_graph_engine import heapq
from graphistry import edges
class CogneeGraph(CogneeAbstractGraph): class CogneeGraph(CogneeAbstractGraph):
""" """
@ -39,26 +42,33 @@ class CogneeGraph(CogneeAbstractGraph):
def get_node(self, node_id: str) -> Node: def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None) return self.nodes.get(node_id, None)
def get_edges(self, node_id: str) -> List[Edge]: def get_edges_of_node(self, node_id: str) -> List[Edge]:
node = self.get_node(node_id) node = self.get_node(node_id)
if node: if node:
return node.skeleton_edges return node.skeleton_edges
else: else:
raise ValueError(f"Node with id {node_id} does not exist.") raise ValueError(f"Node with id {node_id} does not exist.")
def get_edges(self) -> List[Edge]:
    """Return all edges of this graph.

    BUG FIX: previously this returned the module-level `edges` name imported
    from graphistry at the top of the file, not the graph's edge collection.
    """
    return self.edges
async def project_graph_from_db(self, async def project_graph_from_db(self,
adapter: Union[GraphDBInterface], adapter: Union[GraphDBInterface],
node_properties_to_project: List[str], node_properties_to_project: List[str],
edge_properties_to_project: List[str], edge_properties_to_project: List[str],
directed = True, directed = True,
node_dimension = 1, node_dimension = 1,
edge_dimension = 1) -> None: edge_dimension = 1,
memory_fragment_filter = []) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise ValueError("Dimensions must be positive integers") raise ValueError("Dimensions must be positive integers")
try: try:
nodes_data, edges_data = await adapter.get_graph_data() if len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
if not nodes_data: if not nodes_data:
raise ValueError("No node data retrieved from the database.") raise ValueError("No node data retrieved from the database.")
@ -89,3 +99,81 @@ class CogneeGraph(CogneeAbstractGraph):
print(f"Error projecting graph: {e}") print(f"Error projecting graph: {e}")
except Exception as ex: except Exception as ex:
print(f"Unexpected error: {ex}") print(f"Unexpected error: {ex}")
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
    """Attach vector-distance scores from search results to the matching graph nodes.

    `node_distances` maps a category name to a list of scored results; each
    result's id selects the node and its score becomes the node's
    "vector_distance" attribute. Unknown node ids are reported and skipped.
    """
    for _category, scored_results in node_distances.items():
        for scored_result in scored_results:
            node_id = str(scored_result.id)
            node = self.get_node(node_id)
            if node is None:
                print(f"Node with id {node_id} not found in the graph.")
                continue
            node.add_attribute("vector_distance", scored_result.score)
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
    """Score every edge by cosine distance between the query and its relationship type.

    Each distinct relationship type is embedded once; the distance (1 - cosine
    similarity, rounded to 4 decimals) is written to the edge's "vector_distance"
    attribute. Edges with a missing/unknown type are reported and skipped.

    TODO: once edge embeddings live in the vector db, mirror the node mapping.
    """
    try:
        # Embed the query once.
        query_vector = (await vector_engine.embed_data([query]))[0]
        if query_vector is None or len(query_vector) == 0:
            raise ValueError("Failed to generate query embedding.")

        # Collect the distinct relationship types present on the edges.
        relationship_types = list({
            edge.attributes.get('relationship_type')
            for edge in self.edges
            if edge.attributes.get('relationship_type')
        })

        # Embed each distinct type once and precompute its distance to the query.
        type_embeddings = await vector_engine.embed_data(relationship_types)
        query_norm = np.linalg.norm(query_vector)
        distance_by_type = {}
        for relationship_type, embedding in zip(relationship_types, type_embeddings):
            edge_vector = np.array(embedding)
            cosine_similarity = np.dot(query_vector, edge_vector) / (
                query_norm * np.linalg.norm(edge_vector)
            )
            distance_by_type[relationship_type] = round(1 - cosine_similarity, 4)

        # Stamp the precomputed distance onto every edge.
        for edge in self.edges:
            relationship_type = edge.attributes.get('relationship_type')
            if not relationship_type or relationship_type not in distance_by_type:
                print(f"Edge {edge} has an unknown or missing relationship type.")
                continue
            edge.attributes["vector_distance"] = distance_by_type[relationship_type]

    except Exception as ex:
        print(f"Error mapping vector distances to edges: {ex}")
async def calculate_top_triplet_importances(self, k: int) -> List:
    """
    Return the k most relevant edges (triplets) by combined vector distance.

    A triplet's total distance is the sum of the vector distances of its source
    node, target node, and edge (missing scores count as 0). The k edges with
    the SMALLEST totals are kept; the result is ordered from largest to smallest
    kept total.

    BUG FIX: the signature previously read `k = int`, defaulting k to the builtin
    type object, which made `len(min_heap) > k` raise TypeError when k was
    omitted. k is now a required integer; calls that passed k are unaffected.
    """
    min_heap = []
    for index, edge in enumerate(self.edges):
        source_node = self.get_node(edge.node1.id)
        target_node = self.get_node(edge.node2.id)

        source_distance = source_node.attributes.get("vector_distance", 0) if source_node else 0
        target_distance = target_node.attributes.get("vector_distance", 0) if target_node else 0
        edge_distance = edge.attributes.get("vector_distance", 0)

        total_distance = source_distance + target_distance + edge_distance

        # Negate so the heap root holds the LARGEST total; popping it when the
        # heap exceeds k keeps exactly the k smallest totals. The index breaks
        # ties so edges are never compared directly.
        heapq.heappush(min_heap, (-total_distance, index, edge))
        if len(min_heap) > k:
            heapq.heappop(min_heap)

    return [edge for _, _, edge in sorted(min_heap)]

View file

@ -1,5 +1,5 @@
import numpy as np import numpy as np
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any, Union
class Node: class Node:
""" """
@ -21,6 +21,7 @@ class Node:
raise ValueError("Dimension must be a positive integer") raise ValueError("Dimension must be a positive integer")
self.id = node_id self.id = node_id
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.skeleton_neighbours = [] self.skeleton_neighbours = []
self.skeleton_edges = [] self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int) self.status = np.ones(dimension, dtype=int)
@ -55,6 +56,12 @@ class Node:
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
self.attributes[key] = value
def get_attribute(self, key: str) -> Union[str, int, float]:
return self.attributes[key]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Node({self.id}, attributes={self.attributes})" return f"Node({self.id}, attributes={self.attributes})"
@ -87,6 +94,7 @@ class Edge:
self.node1 = node1 self.node1 = node1
self.node2 = node2 self.node2 = node2
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.directed = directed self.directed = directed
self.status = np.ones(dimension, dtype=int) self.status = np.ones(dimension, dtype=int)
@ -95,6 +103,12 @@ class Edge:
raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
self.attributes[key] = value
def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
return self.attributes[key]
def __repr__(self) -> str: def __repr__(self) -> str:
direction = "->" if self.directed else "--" direction = "->" if self.directed else "--"
return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})" return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"

View file

@ -2,3 +2,4 @@ from .expand_with_nodes_and_edges import expand_with_nodes_and_edges
from .get_graph_from_model import get_graph_from_model from .get_graph_from_model import get_graph_from_model
from .get_model_instance_from_graph import get_model_instance_from_graph from .get_model_instance_from_graph import get_model_instance_from_graph
from .retrieve_existing_edges import retrieve_existing_edges from .retrieve_existing_edges import retrieve_existing_edges
from .convert_node_to_data_point import convert_node_to_data_point

View file

@ -0,0 +1,23 @@
from cognee.infrastructure.engine import DataPoint
def convert_node_to_data_point(node_data: dict) -> DataPoint:
subclass = find_subclass_by_name(DataPoint, node_data["type"])
return subclass(**node_data)
def get_all_subclasses(cls):
subclasses = []
for subclass in cls.__subclasses__():
subclasses.append(subclass)
subclasses.extend(get_all_subclasses(subclass)) # Recursively get subclasses
return subclasses
def find_subclass_by_name(cls, name):
for subclass in get_all_subclasses(cls):
if subclass.__name__ == name:
return subclass
return None

View file

@ -2,9 +2,18 @@ from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): async def get_graph_from_model(
data_point: DataPoint,
include_root = True,
added_nodes = None,
added_edges = None,
visited_properties = None,
):
nodes = [] nodes = []
edges = [] edges = []
added_nodes = added_nodes or {}
added_edges = added_edges or {}
visited_properties = visited_properties or {}
data_point_properties = {} data_point_properties = {}
excluded_properties = set() excluded_properties = set()
@ -13,10 +22,27 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
if field_name == "_metadata": if field_name == "_metadata":
continue continue
if field_value is None:
excluded_properties.add(field_name)
continue
if isinstance(field_value, DataPoint): if isinstance(field_value, DataPoint):
excluded_properties.add(field_name) excluded_properties.add(field_name)
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
if property_key in visited_properties:
return [], []
visited_properties[property_key] = 0
property_nodes, property_edges = await get_graph_from_model(
field_value,
True,
added_nodes,
added_edges,
visited_properties,
)
for node in property_nodes: for node in property_nodes:
if str(node.id) not in added_nodes: if str(node.id) not in added_nodes:
@ -47,7 +73,20 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
excluded_properties.add(field_name) excluded_properties.add(field_name)
for item in field_value: for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) property_key = f"{str(data_point.id)}{field_name}{str(item.id)}"
if property_key in visited_properties:
return [], []
visited_properties[property_key] = 0
property_nodes, property_edges = await get_graph_from_model(
item,
True,
added_nodes,
added_edges,
visited_properties,
)
for node in property_nodes: for node in property_nodes:
if str(node.id) not in added_nodes: if str(node.id) not in added_nodes:

View file

View file

View file

@ -0,0 +1,25 @@
from uuid import UUID
from enum import Enum
from typing import Callable, Dict
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
own_document_ids = await get_document_ids_for_user(user.id)
retrieved_results = await diffusion_retriever(query, user)
filtered_search_results = []
return retrieved_results
async def diffusion_retriever(query: str, user, community_filter = []) -> list:
raise(NotImplementedError)

View file

@ -0,0 +1,25 @@
from uuid import UUID
from enum import Enum
from typing import Callable, Dict
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
own_document_ids = await get_document_ids_for_user(user.id)
retrieved_results = await g_retriever(query, user)
filtered_search_results = []
return retrieved_results
async def g_retriever(query: str, user, community_filter = []) -> list:
raise(NotImplementedError)

View file

@ -0,0 +1,119 @@
import asyncio
from uuid import UUID
from enum import Enum
from typing import Callable, Dict
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
def format_triplets(edges):
print("\n\n\n")
def filter_attributes(obj, attributes):
"""Helper function to filter out non-None properties, including nested dicts."""
result = {}
for attr in attributes:
value = getattr(obj, attr, None)
if value is not None:
# If the value is a dict, extract relevant keys from it
if isinstance(value, dict):
nested_values = {k: v for k, v in value.items() if k in attributes and v is not None}
result[attr] = nested_values
else:
result[attr] = value
return result
triplets = []
for edge in edges:
node1 = edge.node1
node2 = edge.node2
edge_attributes = edge.attributes
node1_attributes = node1.attributes
node2_attributes = node2.attributes
# Filter only non-None properties
node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
# Create the formatted triplet
triplet = (
f"Node1: {node1_info}\n"
f"Edge: {edge_info}\n"
f"Node2: {node2_info}\n\n\n" # Add three blank lines for separation
)
triplets.append(triplet)
return "".join(triplets)
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
own_document_ids = await get_document_ids_for_user(user.id)
retrieved_results = await run_two_step_retriever(query, user)
filtered_search_results = []
return retrieved_results
def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is just for now to fix vector db duplicates
results_dict = {}
for collection, results in zip(collections, results):
seen_ids = set()
unique_results = []
for result in results:
if result.id not in seen_ids:
unique_results.append(result)
seen_ids.add(result.id)
else:
print(f"Duplicate found in collection '{collection}': {result.id}")
results_dict[collection] = unique_results
return results_dict
async def run_two_step_retriever(query: str, user, community_filter = []) -> list:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
collections = ["Entity_name", "TextSummary_text", 'EntityType_name', 'DocumentChunk_text']
results = await asyncio.gather(
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
)
############################################# This part is a quick fix til we don't fix the vector db inconsistency
node_distances = delete_duplicated_vector_db_elements(collections, results)# :TODO: Change when vector db is fixed
# results_dict = {collection: result for collection, result in zip(collections, results)}
##############################################
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(graph_engine,
node_properties_to_project=['id',
'description',
'name',
'type',
'text'],
edge_properties_to_project=['id',
'relationship_name'])
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)# :TODO: This should be coming from vector db
results = await memory_fragment.calculate_top_triplet_importances(k=5)
print(format_triplets(results))
print(f'Query was the following:{query}' )
return results

View file

@ -1,13 +1,28 @@
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint): class Repository(DataPoint):
path: str path: str
type: Optional[str] = "Repository"
class CodeFile(DataPoint): class CodeFile(DataPoint):
extracted_id: str # actually file path extracted_id: str # actually file path
type: Optional[str] = "CodeFile"
source_code: Optional[str] = None
part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None
depends_directly_on: Optional[List["CodeFile"]] = None
contains: Optional[List["CodePart"]] = None
_metadata: dict = {
"index_fields": ["source_code"]
}
class CodePart(DataPoint):
type: str type: str
# part_of: Optional[CodeFile]
source_code: str source_code: str
part_of: Repository type: Optional[str] = "CodePart"
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"] "index_fields": ["source_code"]
@ -18,3 +33,6 @@ class CodeRelationship(DataPoint):
target_id: str target_id: str
type: str # between files type: str # between files
relation: str # depends on or depends directly relation: str # depends on or depends directly
CodeFile.model_rebuild()
CodePart.model_rebuild()

View file

@ -1,7 +1,7 @@
import os import os
import asyncio import asyncio
import argparse import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
@ -15,7 +15,7 @@ def main():
print(f"Error: The provided repository path does not exist: {repo_path}") print(f"Error: The provided repository path does not exist: {repo_path}")
return return
graph = asyncio.run(get_repo_dependency_graph(repo_path)) graph = asyncio.run(get_repo_file_dependencies(repo_path))
graph = asyncio.run(enrich_dependency_graph(graph)) graph = asyncio.run(enrich_dependency_graph(graph))
for node in graph.nodes: for node in graph.nodes:
print(f"Node: {node}") print(f"Node: {node}")

View file

@ -1,7 +1,7 @@
import os import os
import asyncio import asyncio
import argparse import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph
@ -16,7 +16,7 @@ def main():
print(f"Error: The provided repository path does not exist: {repo_path}") print(f"Error: The provided repository path does not exist: {repo_path}")
return return
graph = asyncio.run(get_repo_dependency_graph(repo_path)) graph = asyncio.run(get_repo_file_dependencies(repo_path))
graph = asyncio.run(enrich_dependency_graph(graph)) graph = asyncio.run(enrich_dependency_graph(graph))
graph = expand_dependency_graph(graph) graph = expand_dependency_graph(graph)
for node in graph.nodes: for node in graph.nodes:

View file

@ -1,7 +1,7 @@
import os import os
import asyncio import asyncio
import argparse import argparse
from cognee.tasks.repo_processor.get_repo_dependency_graph import get_repo_dependency_graph from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
def main(): def main():
@ -14,7 +14,7 @@ def main():
print(f"Error: The provided repository path does not exist: {repo_path}") print(f"Error: The provided repository path does not exist: {repo_path}")
return return
graph = asyncio.run(get_repo_dependency_graph(repo_path)) graph = asyncio.run(get_repo_file_dependencies(repo_path))
for node in graph.nodes: for node in graph.nodes:
print(f"Node: {node}") print(f"Node: {node}")

View file

@ -1,16 +1,51 @@
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument from cognee.modules.data.processing.document_types import (
Document,
PdfDocument,
AudioDocument,
ImageDocument,
TextDocument,
)
EXTENSION_TO_DOCUMENT_CLASS = { EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, "pdf": PdfDocument, # Text documents
"audio": AudioDocument, "txt": TextDocument,
"image": ImageDocument, "png": ImageDocument, # Image documents
"txt": TextDocument "dwg": ImageDocument,
"xcf": ImageDocument,
"jpg": ImageDocument,
"jpx": ImageDocument,
"apng": ImageDocument,
"gif": ImageDocument,
"webp": ImageDocument,
"cr2": ImageDocument,
"tif": ImageDocument,
"bmp": ImageDocument,
"jxr": ImageDocument,
"psd": ImageDocument,
"ico": ImageDocument,
"heic": ImageDocument,
"avif": ImageDocument,
"aac": AudioDocument, # Audio documents
"mid": AudioDocument,
"mp3": AudioDocument,
"m4a": AudioDocument,
"ogg": AudioDocument,
"flac": AudioDocument,
"wav": AudioDocument,
"amr": AudioDocument,
"aiff": AudioDocument,
} }
def classify_documents(data_documents: list[Data]) -> list[Document]: def classify_documents(data_documents: list[Data]) -> list[Document]:
documents = [ documents = [
EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location, name=data_item.name) EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](
id=data_item.id,
title=f"{data_item.name}.{data_item.extension}",
raw_data_location=data_item.raw_data_location,
name=data_item.name,
)
for data_item in data_documents for data_item in data_documents
] ]
return documents return documents

View file

@ -1,54 +0,0 @@
import os
import networkx as nx
from cognee.shared.CodeGraphEntities import CodeFile, CodeRelationship, Repository
from cognee.tasks.storage import add_data_points
async def convert_graph_from_code_graph(
graph: nx.DiGraph, repo_path: str
) -> tuple[str, list[CodeFile], list[CodeRelationship]]:
code_objects = code_objects_from_di_graph(graph, repo_path)
add_data_points(code_objects)
return code_objects
def create_code_file(path, type, repo):
abspath = os.path.abspath(path)
with open(abspath, "r") as f:
source_code = f.read()
code_file = CodeFile(
extracted_id = abspath,
type = type,
source_code = source_code,
part_of = repo,
)
return code_file
def code_objects_from_di_graph(
graph: nx.DiGraph, repo_path: str
) -> tuple[Repository, list[CodeFile], list[CodeRelationship]]:
repo = Repository(path=repo_path)
code_files = [
create_code_file(os.path.join(repo_path, path), "python_file", repo)
for path in graph.nodes
]
code_relationships = [
CodeRelationship(
os.path.join(repo_path, source),
os.path.join(repo_path, target),
"python_file",
graph.get_edge_data(source, target)["relation"],
)
for source, target in graph.edges
]
return (repo, code_files, code_relationships)

View file

@ -24,11 +24,13 @@ async def extract_graph_from_data(
(chunk, chunk_graph) for chunk, chunk_graph in zip(data_chunks, chunk_graphs) (chunk, chunk_graph) for chunk, chunk_graph in zip(data_chunks, chunk_graphs)
] ]
existing_edges_map = await retrieve_existing_edges( existing_edges_map = await retrieve_existing_edges(
chunk_and_chunk_graphs, graph_engine chunk_and_chunk_graphs,
graph_engine,
) )
graph_nodes, graph_edges = expand_with_nodes_and_edges( graph_nodes, graph_edges = expand_with_nodes_and_edges(
chunk_and_chunk_graphs, existing_edges_map chunk_and_chunk_graphs,
existing_edges_map,
) )
if len(graph_nodes) > 0: if len(graph_nodes) > 0:

View file

@ -4,4 +4,4 @@ logger = logging.getLogger("task:repo_processor")
from .enrich_dependency_graph import enrich_dependency_graph from .enrich_dependency_graph import enrich_dependency_graph
from .expand_dependency_graph import expand_dependency_graph from .expand_dependency_graph import expand_dependency_graph
from .get_repo_dependency_graph import get_repo_dependency_graph from .get_repo_file_dependencies import get_repo_file_dependencies

View file

@ -1,25 +1,38 @@
import asyncio
import networkx as nx import networkx as nx
from typing import Dict, List from typing import Dict, List
from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.modules.graph.utils import get_graph_from_model, convert_node_to_data_point
from cognee.infrastructure.databases.graph import get_graph_engine
def topologically_sort_subgraph(subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph) -> List[str]: def topologically_sort_subgraph(subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph) -> List[str]:
"""Performs a topological sort on a subgraph based on node indegrees.""" """Performs a topological sort on a subgraph based on node indegrees."""
results = [] results = []
remaining_nodes = subgraph_node_to_indegree.copy() remaining_nodes = subgraph_node_to_indegree.copy()
while remaining_nodes: while remaining_nodes:
next_node = min(remaining_nodes, key=remaining_nodes.get) next_node = min(remaining_nodes, key=remaining_nodes.get)
results.append(next_node) results.append(next_node)
for successor in graph.successors(next_node): for successor in graph.successors(next_node):
if successor in remaining_nodes: if successor in remaining_nodes:
remaining_nodes[successor] -= 1 remaining_nodes[successor] -= 1
remaining_nodes.pop(next_node) remaining_nodes.pop(next_node)
return results return results
def topologically_sort(graph: nx.DiGraph) -> List[str]: def topologically_sort(graph: nx.DiGraph) -> List[str]:
"""Performs a topological sort on the entire graph.""" """Performs a topological sort on the entire graph."""
subgraphs = (graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph)) subgraphs = (graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph))
topological_order = [] topological_order = []
for subgraph in subgraphs: for subgraph in subgraphs:
node_to_indegree = { node_to_indegree = {
node: len(list(subgraph.successors(node))) node: len(list(subgraph.successors(node)))
@ -28,29 +41,84 @@ def topologically_sort(graph: nx.DiGraph) -> List[str]:
topological_order.extend( topological_order.extend(
topologically_sort_subgraph(node_to_indegree, subgraph) topologically_sort_subgraph(node_to_indegree, subgraph)
) )
return topological_order return topological_order
def node_enrich_and_connect(graph: nx.MultiDiGraph, topological_order: List[str], node: str) -> None: async def node_enrich_and_connect(
graph: nx.MultiDiGraph,
topological_order: List[str],
node: CodeFile,
data_points_map: Dict[str, DataPoint],
) -> None:
"""Adds 'depends_on' edges to the graph based on topological order.""" """Adds 'depends_on' edges to the graph based on topological order."""
topological_rank = topological_order.index(node) topological_rank = topological_order.index(node.id)
graph.nodes[node]['topological_rank'] = topological_rank node.topological_rank = topological_rank
node_descendants = nx.descendants(graph, node) node_descendants = nx.descendants(graph, node.id)
if graph.has_edge(node,node):
node_descendants.add(node) if graph.has_edge(node.id, node.id):
for desc in node_descendants: node_descendants.add(node.id)
if desc not in topological_order[:topological_rank+1]:
new_connections = []
graph_engine = await get_graph_engine()
for desc_id in node_descendants:
if desc_id not in topological_order[:topological_rank + 1]:
continue continue
graph.add_edge(node, desc, relation='depends_on')
async def enrich_dependency_graph(graph: nx.DiGraph) -> nx.MultiDiGraph: if desc_id in data_points_map:
desc = data_points_map[desc_id]
else:
node_data = await graph_engine.extract_node(desc_id)
desc = convert_node_to_data_point(node_data)
new_connections.append(desc)
node.depends_directly_on = node.depends_directly_on or []
node.depends_directly_on.extend(new_connections)
async def enrich_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
"""Enriches the graph with topological ranks and 'depends_on' edges.""" """Enriches the graph with topological ranks and 'depends_on' edges."""
graph = nx.MultiDiGraph(graph) nodes = []
edges = []
for data_point in data_points:
graph_nodes, graph_edges = await get_graph_from_model(data_point)
nodes.extend(graph_nodes)
edges.extend(graph_edges)
graph = nx.MultiDiGraph()
simple_nodes = [(node.id, node.model_dump()) for node in nodes]
graph.add_nodes_from(simple_nodes)
graph.add_edges_from(edges)
topological_order = topologically_sort(graph) topological_order = topologically_sort(graph)
node_rank_map = {node: idx for idx, node in enumerate(topological_order)} node_rank_map = {node: idx for idx, node in enumerate(topological_order)}
for node in graph.nodes:
if node not in node_rank_map: # for node_id, node in tqdm(graph.nodes(data = True), desc = "Enriching dependency graph", unit = "node"):
# if node_id not in node_rank_map:
# continue
# data_points.append(node_enrich_and_connect(graph, topological_order, node))
data_points_map = {data_point.id: data_point for data_point in data_points}
data_points_futures = []
for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
if data_point.id not in node_rank_map:
continue continue
node_enrich_and_connect(graph, topological_order, node)
return graph if isinstance(data_point, CodeFile):
data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
# yield data_point
await asyncio.gather(*data_points_futures)
return data_points

View file

@ -1,28 +1,43 @@
import networkx as nx from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart
from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts
from cognee.tasks.repo_processor import logger from cognee.tasks.repo_processor import logger
def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts) -> None:
def _add_code_parts_nodes_and_edges(graph, parent_node_id, part_type, code_parts):
"""Add code part nodes and edges for a specific part type.""" """Add code part nodes and edges for a specific part type."""
if not code_parts: if not code_parts:
logger.debug(f"No code parts to add for parent_node_id {parent_node_id} and part_type {part_type}.") logger.debug(f"No code parts to add for node {code_file.id} and part_type {part_type}.")
return return
part_nodes = []
for idx, code_part in enumerate(code_parts): for idx, code_part in enumerate(code_parts):
if not code_part.strip(): if not code_part.strip():
logger.warning(f"Empty code part in parent_node_id {parent_node_id} and part_type {part_type}.") logger.warning(f"Empty code part in node {code_file.id} and part_type {part_type}.")
continue continue
part_node_id = f"{parent_node_id}_{part_type}_{idx}"
graph.add_node(part_node_id, source_code=code_part, node_type=part_type) part_node_id = uuid5(NAMESPACE_OID, f"{code_file.id}_{part_type}_{idx}")
graph.add_edge(parent_node_id, part_node_id, relation="contains")
part_nodes.append(CodePart(
id = part_node_id,
type = part_type,
# part_of = code_file,
source_code = code_part,
))
# graph.add_node(part_node_id, source_code=code_part, node_type=part_type)
# graph.add_edge(parent_node_id, part_node_id, relation="contains")
code_file.contains = code_file.contains or []
code_file.contains.extend(part_nodes)
def _process_single_node(graph, node_id, node_data): def _process_single_node(code_file: CodeFile) -> None:
"""Process a single Python file node.""" """Process a single Python file node."""
graph.nodes[node_id]["node_type"] = "python_file" node_id = code_file.id
source_code = node_data.get("source_code", "") source_code = code_file.source_code
if not source_code.strip(): if not source_code.strip():
logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.") logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.")
@ -35,15 +50,14 @@ def _process_single_node(graph, node_id, node_data):
return return
for part_type, code_parts in code_parts_dict.items(): for part_type, code_parts in code_parts_dict.items():
_add_code_parts_nodes_and_edges(graph, node_id, part_type, code_parts) _add_code_parts_nodes_and_edges(code_file, part_type, code_parts)
def expand_dependency_graph(graph: nx.MultiDiGraph) -> nx.MultiDiGraph: async def expand_dependency_graph(data_points: list[DataPoint]) -> list[DataPoint]:
"""Process Python file nodes, adding code part nodes and edges.""" """Process Python file nodes, adding code part nodes and edges."""
expanded_graph = graph.copy() # for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
for node_id, node_data in graph.nodes(data=True): for data_point in data_points:
if not node_data: # Check if node_data is empty if isinstance(data_point, CodeFile):
logger.warning(f"Node {node_id} has no data. Skipping.") _process_single_node(data_point)
continue
_process_single_node(expanded_graph, node_id, node_data) return data_points
return expanded_graph

View file

@ -1,61 +0,0 @@
import os
import aiofiles
import networkx as nx
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
async def get_py_path_and_source(file_path, repo_path):
relative_path = os.path.relpath(file_path, repo_path)
try:
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
source_code = await f.read()
return relative_path, source_code
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return relative_path, None
async def get_py_files_dict(repo_path):
"""Get .py files and their source code"""
if not os.path.exists(repo_path):
return {}
py_files_paths = (
os.path.join(root, file)
for root, _, files in os.walk(repo_path) for file in files if file.endswith(".py")
)
py_files_dict = {}
for file_path in py_files_paths:
relative_path, source_code = await get_py_path_and_source(file_path, repo_path)
py_files_dict[relative_path] = {"source_code": source_code}
return py_files_dict
def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool = True) -> tuple:
if relative_paths:
file_path = os.path.relpath(file_path, repo_path)
dependency = os.path.relpath(dependency, repo_path)
return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_dependency_graph(repo_path: str) -> nx.DiGraph:
"""Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path)
dependency_graph = nx.DiGraph()
dependency_graph.add_nodes_from(py_files_dict.items())
for file_path, metadata in py_files_dict.items():
source_code = metadata.get("source_code")
if source_code is None:
continue
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
dependency_graph.add_edges_from(dependency_edges)
return dependency_graph

View file

@ -0,0 +1,87 @@
import os
from uuid import NAMESPACE_OID, uuid5
import aiofiles
from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, Repository
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
async def get_py_path_and_source(file_path):
try:
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
source_code = await f.read()
return file_path, source_code
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return file_path, None
async def get_py_files_dict(repo_path):
"""Get .py files and their source code"""
if not os.path.exists(repo_path):
return {}
py_files_paths = (
os.path.join(root, file)
for root, _, files in os.walk(repo_path) for file in files if file.endswith(".py")
)
py_files_dict = {}
for file_path in py_files_paths:
absolute_path = os.path.abspath(file_path)
relative_path, source_code = await get_py_path_and_source(absolute_path)
py_files_dict[relative_path] = {"source_code": source_code}
return py_files_dict
def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool = False) -> tuple:
if relative_paths:
file_path = os.path.relpath(file_path, repo_path)
dependency = os.path.relpath(dependency, repo_path)
return (file_path, dependency, {"relation": "depends_directly_on"})
async def get_repo_file_dependencies(repo_path: str) -> list[DataPoint]:
"""Generate a dependency graph for Python files in the given repository path."""
py_files_dict = await get_py_files_dict(repo_path)
repo = Repository(
id = uuid5(NAMESPACE_OID, repo_path),
path = repo_path,
)
data_points = [repo]
# dependency_graph = nx.DiGraph()
# dependency_graph.add_nodes_from(py_files_dict.items())
async for file_path, metadata in tqdm(py_files_dict.items(), desc="Repo dependency graph", unit="file"):
source_code = metadata.get("source_code")
if source_code is None:
continue
dependencies = await get_local_script_dependencies(os.path.join(repo_path, file_path), repo_path)
data_points.append(CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
part_of = repo,
depends_on = [
CodeFile(
id = uuid5(NAMESPACE_OID, dependency),
extracted_id = dependency,
part_of = repo,
) for dependency in dependencies
] if len(dependencies) else None,
))
# dependency_edges = [get_edge(file_path, dependency, repo_path) for dependency in dependencies]
# dependency_graph.add_edges_from(dependency_edges)
return data_points
# return dependency_graph

View file

@ -1,3 +1,4 @@
import asyncio
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
@ -8,11 +9,13 @@ async def add_data_points(data_points: list[DataPoint]):
nodes = [] nodes = []
edges = [] edges = []
for data_point in data_points: results = await asyncio.gather(*[
property_nodes, property_edges = get_graph_from_model(data_point) get_graph_from_model(data_point) for data_point in data_points
])
nodes.extend(property_nodes) for result_nodes, result_edges in results:
edges.extend(property_edges) nodes.extend(result_nodes)
edges.extend(result_edges)
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()

View file

@ -16,6 +16,9 @@ async def index_data_points(data_points: list[DataPoint]):
data_point_type = type(data_point) data_point_type = type(data_point)
for field_name in data_point._metadata["index_fields"]: for field_name in data_point._metadata["index_fields"]:
if getattr(data_point, field_name, None) is None:
continue
index_name = f"{data_point_type.__tablename__}.{field_name}" index_name = f"{data_point_type.__tablename__}.{field_name}"
if index_name not in created_indexes: if index_name not in created_indexes:
@ -35,12 +38,21 @@ async def index_data_points(data_points: list[DataPoint]):
return data_points return data_points
def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> list[DataPoint]: def get_data_points_from_model(data_point: DataPoint, added_data_points = None, visited_properties = None) -> list[DataPoint]:
data_points = [] data_points = []
added_data_points = added_data_points or {}
visited_properties = visited_properties or {}
for field_name, field_value in data_point: for field_name, field_value in data_point:
if isinstance(field_value, DataPoint): if isinstance(field_value, DataPoint):
new_data_points = get_data_points_from_model(field_value, added_data_points) property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
if property_key in visited_properties:
return []
visited_properties[property_key] = True
new_data_points = get_data_points_from_model(field_value, added_data_points, visited_properties)
for new_point in new_data_points: for new_point in new_data_points:
if str(new_point.id) not in added_data_points: if str(new_point.id) not in added_data_points:
@ -49,7 +61,14 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
for field_value_item in field_value: for field_value_item in field_value:
new_data_points = get_data_points_from_model(field_value_item, added_data_points) property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
if property_key in visited_properties:
return []
visited_properties[property_key] = True
new_data_points = get_data_points_from_model(field_value_item, added_data_points, visited_properties)
for new_point in new_data_points: for new_point in new_data_points:
if str(new_point.id) not in added_data_points: if str(new_point.id) not in added_data_points:
@ -79,4 +98,3 @@ if __name__ == "__main__":
data_points = get_data_points_from_model(person) data_points = get_data_points_from_model(person)
print(data_points) print(data_points)

View file

@ -4,6 +4,7 @@ from uuid import uuid5
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.shared.CodeGraphEntities import CodeFile from cognee.shared.CodeGraphEntities import CodeFile
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
@ -12,13 +13,16 @@ from .models import CodeSummary
async def summarize_code( async def summarize_code(
code_files: list[CodeFile], summarization_model: Type[BaseModel] code_files: list[DataPoint],
) -> list[CodeFile]: summarization_model: Type[BaseModel],
) -> list[DataPoint]:
if len(code_files) == 0: if len(code_files) == 0:
return code_files return code_files
code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)]
file_summaries = await asyncio.gather( file_summaries = await asyncio.gather(
*[extract_summary(file.source_code, summarization_model) for file in code_files] *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points]
) )
summaries = [ summaries = [
@ -27,9 +31,9 @@ async def summarize_code(
made_from = file, made_from = file,
text = file_summaries[file_index].summary, text = file_summaries[file_index].summary,
) )
for (file_index, file) in enumerate(code_files) for (file_index, file) in enumerate(code_files_data_points)
] ]
await add_data_points(summaries) await add_data_points(summaries)
return code_files, summaries return code_files

View file

@ -1,51 +0,0 @@
import random
import string
import numpy as np
from cognee.shared.CodeGraphEntities import CodeFile, CodeRelationship
def random_str(n, spaces=True):
candidates = string.ascii_letters + string.digits
if spaces:
candidates += " "
return "".join(random.choice(candidates) for _ in range(n))
def code_graph_test_data_generation():
nodes = [
CodeFile(
extracted_id=random_str(10, spaces=False),
type="file",
source_code=random_str(random.randrange(50, 500)),
)
for _ in range(100)
]
n_nodes = len(nodes)
first_source = np.random.randint(0, n_nodes)
reached_nodes = {first_source}
last_iteration = [first_source]
edges = []
while len(reached_nodes) < n_nodes:
for source in last_iteration:
last_iteration = []
tries = 0
while ((len(last_iteration) == 0 or tries < 500)) and (
len(reached_nodes) < n_nodes
):
tries += 1
target = np.random.randint(n_nodes)
if target not in reached_nodes:
last_iteration.append(target)
edges.append(
CodeRelationship(
source_id=nodes[source].extracted_id,
target_id=nodes[target].extracted_id,
type="files",
relation="depends",
)
)
reached_nodes = reached_nodes.union(set(last_iteration))
return (nodes, edges)

View file

@ -1,27 +0,0 @@
import asyncio
import pytest
from cognee.shared.CodeGraphEntities import Repository
from cognee.tasks.graph.convert_graph_from_code_graph import (
convert_graph_from_code_graph,
)
from cognee.tests.tasks.graph.code_graph_test_data_generation import (
code_graph_test_data_generation,
)
def test_convert_graph_from_code_graph():
repo = Repository(path="test/repo/path")
nodes, edges = code_graph_test_data_generation()
repo_out, nodes_out, edges_out = asyncio.run(
convert_graph_from_code_graph(repo, nodes, edges)
)
assert repo == repo_out, f"{repo = } != {repo_out = }"
for node_in, node_out in zip(nodes, nodes_out):
assert node_in == node_out, f"{node_in = } != {node_out = }"
for edge_in, edge_out in zip(edges, edges_out):
assert edge_in == edge_out, f"{edge_in = } != {edge_out = }"

View file

@ -0,0 +1,100 @@
import asyncio
import random
import time
from typing import List
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
random.seed(1500)
class Repository(DataPoint):
path: str
class CodeFile(DataPoint):
part_of: Repository
contains: List["CodePart"] = []
depends_on: List["CodeFile"] = []
source_code: str
class CodePart(DataPoint):
part_of: CodeFile
source_code: str
CodeFile.model_rebuild()
CodePart.model_rebuild()
def nanoseconds_to_largest_unit(nanoseconds):
# Define conversion factors
conversion_factors = {
'weeks': 7 * 24 * 60 * 60 * 1e9,
'days': 24 * 60 * 60 * 1e9,
'hours': 60 * 60 * 1e9,
'minutes': 60 * 1e9,
'seconds': 1e9,
'miliseconds': 1e6,
'microseconds': 1e3,
}
# Iterate through conversion factors to find the largest unit
for unit, factor in conversion_factors.items():
converted_value = nanoseconds / factor
if converted_value >= 1:
return converted_value, unit
# If nanoseconds is smaller than a second
return nanoseconds, 'nanoseconds'
async def test_circular_reference_extraction():
repo = Repository(path = "repo1")
code_files = [CodeFile(
id = uuid5(NAMESPACE_OID, f"file{file_index}"),
source_code = "source code",
part_of = repo,
contains = [],
depends_on = [CodeFile(
id = uuid5(NAMESPACE_OID, f"file{random_id}"),
source_code = "source code",
part_of = repo,
depends_on = [],
) for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))]],
) for file_index in range(1500)]
for code_file in code_files:
code_file.contains.extend([CodePart(
part_of = code_file,
source_code = f"Part {part_index}",
) for part_index in range(random.randint(1, 20))])
nodes = []
edges = []
start = time.perf_counter_ns()
results = await asyncio.gather(*[
get_graph_from_model(code_file) for code_file in code_files
])
time_to_run = time.perf_counter_ns() - start
print(nanoseconds_to_largest_unit(time_to_run))
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
# for code_file in code_files:
# model_nodes, model_edges = get_graph_from_model(code_file)
# nodes.extend(model_nodes)
# edges.extend(model_edges)
assert len(nodes) == 1501
assert len(edges) == 1501 * 20 + 1500 * 5
if __name__ == "__main__":
asyncio.run(test_circular_reference_extraction())

View file

@ -8,7 +8,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly.""" """Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2) node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1" assert node.id == "node1"
assert node.attributes == {"attr1": "value1"} assert node.attributes == {"attr1": "value1", 'vector_distance': np.inf}
assert len(node.status) == 2 assert len(node.status) == 2
assert np.all(node.status == 1) assert np.all(node.status == 1)
@ -95,7 +95,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1 assert edge.node1 == node1
assert edge.node2 == node2 assert edge.node2 == node2
assert edge.attributes == {"weight": 10} assert edge.attributes == {'vector_distance': np.inf,"weight": 10}
assert edge.directed is False assert edge.directed is False
assert len(edge.status) == 2 assert len(edge.status) == 2
assert np.all(edge.status == 1) assert np.all(edge.status == 1)

View file

@ -77,11 +77,11 @@ def test_get_edges_success(setup_graph):
graph.add_node(node2) graph.add_node(node2)
edge = Edge(node1, node2) edge = Edge(node1, node2)
graph.add_edge(edge) graph.add_edge(edge)
assert edge in graph.get_edges("node1") assert edge in graph.get_edges_of_node("node1")
def test_get_edges_nonexistent_node(setup_graph): def test_get_edges_nonexistent_node(setup_graph):
"""Test retrieving edges for a nonexistent node raises an exception.""" """Test retrieving edges for a nonexistent node raises an exception."""
graph = setup_graph graph = setup_graph
with pytest.raises(ValueError, match="Node with id nonexistent does not exist."): with pytest.raises(ValueError, match="Node with id nonexistent does not exist."):
graph.get_edges("nonexistent") graph.get_edges_of_node("nonexistent")

View file

@ -46,7 +46,7 @@ services:
- 7687:7687 - 7687:7687
environment: environment:
- NEO4J_AUTH=neo4j/pleaseletmein - NEO4J_AUTH=neo4j/pleaseletmein
- NEO4J_PLUGINS=["apoc"] - NEO4J_PLUGINS=["apoc", "graph-data-science"]
networks: networks:
- cognee-network - cognee-network

View file

@ -8,30 +8,63 @@ from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
import cognee import cognee
from cognee.shared.data_models import SummarizedContent
from cognee.shared.utils import render_graph
from cognee.tasks.repo_processor import (
enrich_dependency_graph,
expand_dependency_graph,
get_repo_file_dependencies,
)
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code
from cognee.modules.pipelines import Task, run_tasks
from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from evals.eval_utils import download_instances from evals.eval_utils import download_instances
from evals.eval_utils import ingest_repos
from evals.eval_utils import download_github_repo
from evals.eval_utils import delete_repo
async def generate_patch_with_cognee(instance):
async def generate_patch_with_cognee(instance, search_type=SearchType.CHUNKS):
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system()
dataset_name = "SWE_test_data" #dataset_name = "SWE_test_data"
code_text = instance["text"]
await cognee.add([code_text], dataset_name) #await cognee.add('', dataset_name = dataset_name)
await code_graph_pipeline([dataset_name])
graph_engine = await get_graph_engine() # repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
with open(graph_engine.filename, "r") as f:
graph_str = f.read() repo_path = '/Users/borisarzentar/Projects/graphrag'
tasks = [
Task(get_repo_file_dependencies),
Task(add_data_points),
Task(enrich_dependency_graph),
Task(expand_dependency_graph),
Task(add_data_points),
# Task(summarize_code, summarization_model = SummarizedContent),
]
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
async for result in pipeline:
print(result)
print('Here we have the repo under the repo_path')
await render_graph()
problem_statement = instance['problem_statement'] problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_instructions.txt") instructions = read_query_prompt("patch_gen_instructions.txt")
graph_str = 'HERE WE SHOULD PASS THE TRIPLETS FROM GRAPHRAG'
prompt = "\n".join([ prompt = "\n".join([
instructions, instructions,
"<patch>", "<patch>",
@ -41,14 +74,18 @@ async def generate_patch_with_cognee(instance, search_type=SearchType.CHUNKS):
graph_str graph_str
]) ])
return 0
''' :TODO: We have to find out how do we do the generation
llm_client = get_llm_client() llm_client = get_llm_client()
answer_prediction = await llm_client.acreate_structured_output( answer_prediction = await llm_client.acreate_structured_output(
text_input=problem_statement, text_input=problem_statement,
system_prompt=prompt, system_prompt=prompt,
response_model=str, response_model=str,
) )
return answer_prediction
return answer_prediction
'''
async def generate_patch_without_cognee(instance): async def generate_patch_without_cognee(instance):
problem_statement = instance['problem_statement'] problem_statement = instance['problem_statement']
@ -71,11 +108,16 @@ async def get_preds(dataset, with_cognee=True):
model_name = "without_cognee" model_name = "without_cognee"
pred_func = generate_patch_without_cognee pred_func = generate_patch_without_cognee
for instance in dataset:
await pred_func(instance)
'''
preds = [{"instance_id": instance["instance_id"], preds = [{"instance_id": instance["instance_id"],
"model_patch": await pred_func(instance), "model_patch": await pred_func(instance),
"model_name_or_path": model_name} for instance in dataset] "model_name_or_path": model_name} for instance in dataset]
'''
return preds return 0
async def main(): async def main():
@ -115,4 +157,5 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main(), debug=True) asyncio.run(main(), debug=True)

View file

@ -8,7 +8,8 @@ from swebench.inference.make_datasets.create_instance import make_code_text
from swebench.inference.make_datasets.utils import (AutoContextManager, from swebench.inference.make_datasets.utils import (AutoContextManager,
ingest_directory_contents) ingest_directory_contents)
from tqdm.auto import tqdm from tqdm.auto import tqdm
from git import Repo
import shutil
def ingest_files(filenames): def ingest_files(filenames):
files_dict = dict() files_dict = dict()
@ -101,3 +102,56 @@ def download_instances(
dataset = create_dataset(input_instances_with_text) dataset = create_dataset(input_instances_with_text)
dataset.save_to_disk(path) dataset.save_to_disk(path)
return dataset return dataset
def download_github_repo(instance, output_dir):
"""
Downloads a GitHub repository and checks out the specified commit.
Args:
instance (dict): Dictionary containing 'repo', 'base_commit', and 'instance_id'.
output_dir (str): Directory to store the downloaded repositories.
Returns:
str: Path to the downloaded repository.
"""
repo_owner_repo = instance['repo']
base_commit = instance['base_commit']
instance_id = instance['instance_id']
repo_url = f"https://github.com/{repo_owner_repo}.git"
repo_path = os.path.abspath(os.path.join(output_dir, instance_id))
# Clone repository if it doesn't already exist
if not os.path.exists(repo_path):
print(f"Cloning {repo_url} to {repo_path}...")
Repo.clone_from(repo_url, repo_path)
else:
print(f"Repository already exists at {repo_path}.")
repo = Repo(repo_path)
repo.git.checkout(base_commit)
return repo_path
def delete_repo(repo_path):
"""
Deletes the specified repository directory.
Args:
repo_path (str): Path to the repository to delete.
Returns:
None
"""
try:
if os.path.exists(repo_path):
shutil.rmtree(repo_path)
print(f"Deleted repository at {repo_path}.")
else:
print(f"Repository path {repo_path} does not exist. Nothing to delete.")
except Exception as e:
print(f"Error deleting repository at {repo_path}: {e}")

View file

@ -1,22 +1,11 @@
import argparse
import asyncio import asyncio
import os
from cognee.modules.pipelines import Task, run_tasks from cognee.modules.pipelines import Task, run_tasks
from cognee.shared.CodeGraphEntities import CodeRelationship, Repository
from cognee.shared.data_models import SummarizedContent
from cognee.tasks.code.get_local_dependencies_checker import (
get_local_script_dependencies,
)
from cognee.tasks.graph.convert_graph_from_code_graph import (
create_code_file,
convert_graph_from_code_graph,
)
from cognee.tasks.repo_processor import ( from cognee.tasks.repo_processor import (
enrich_dependency_graph, enrich_dependency_graph,
expand_dependency_graph, expand_dependency_graph,
get_repo_dependency_graph, get_repo_file_dependencies,
) )
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code from cognee.tasks.summarization import summarize_code
@ -24,58 +13,24 @@ async def print_results(pipeline):
async for result in pipeline: async for result in pipeline:
print(result) print(result)
async def get_local_script_dependencies_wrapper(script_path, repo_path):
dependencies = await get_local_script_dependencies(script_path, repo_path)
return (script_path, dependencies)
async def scan_repo(path, condition):
futures = []
for root, dirs, files in os.walk(path):
for file in files:
if condition(file):
futures.append(
get_local_script_dependencies_wrapper(
os.path.abspath(f"{root}/{file}"), path
)
)
results = await asyncio.gather(*futures)
code_files = {}
code_relationships = []
for abspath, dependencies in results:
code_file, abspath = create_code_file(abspath, "python_file")
code_files[abspath] = code_file
for dependency in dependencies:
dependency_code_file, dependency_abspath = create_code_file(
dependency, "python_file"
)
code_files[dependency_abspath] = dependency_code_file
code_relationship = CodeRelationship(
source_id=abspath,
target_id=dependency_abspath,
type="files",
relation="depends_on",
)
code_relationships.append(code_relationship)
return (Repository(path=path), list(code_files.values()), code_relationships)
if __name__ == "__main__": if __name__ == "__main__":
'''
parser = argparse.ArgumentParser(description="Process a file path") parser = argparse.ArgumentParser(description="Process a file path")
parser.add_argument("path", help="Path to the file") parser.add_argument("path", help="Path to the file")
args = parser.parse_args() args = parser.parse_args()
abspath = os.path.abspath(args.path or ".") abspath = os.path.abspath(args.path or ".")
'''
abspath = '/Users/laszlohajdu/Documents/Github/RAW_GIT_REPOS/astropy__astropy-12907'
tasks = [ tasks = [
Task(get_repo_dependency_graph), Task(get_repo_file_dependencies),
Task(add_data_points),
Task(enrich_dependency_graph), Task(enrich_dependency_graph),
Task(expand_dependency_graph), Task(expand_dependency_graph),
Task(convert_graph_from_code_graph), Task(add_data_points),
Task(summarize_code, summarization_model = SummarizedContent), # Task(summarize_code, summarization_model = SummarizedContent),
] ]
pipeline = run_tasks(tasks, abspath, "cognify_code_pipeline") pipeline = run_tasks(tasks, abspath, "cognify_code_pipeline")
asyncio.run(print_results(pipeline)) asyncio.run(print_results(pipeline))

View file

@ -1,32 +1,6 @@
import cognee import cognee
import asyncio import asyncio
from cognee.api.v1.search import SearchType from cognee.pipelines.retriever.two_steps_retriever import two_step_retriever
job_position = """0:Senior Data Scientist (Machine Learning)
Company: TechNova Solutions
Location: San Francisco, CA
Job Description:
TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.
Responsibilities:
Develop and implement advanced machine learning algorithms and models.
Analyze large, complex datasets to extract meaningful patterns and insights.
Collaborate with cross-functional teams to integrate predictive models into products.
Stay updated with the latest advancements in machine learning and data science.
Mentor junior data scientists and provide technical guidance.
Qualifications:
Masters or Ph.D. in Data Science, Computer Science, Statistics, or a related field.
5+ years of experience in data science and machine learning.
Proficient in Python, R, and SQL.
Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).
Strong problem-solving skills and attention to detail.
Candidate CVs
"""
job_1 = """ job_1 = """
CV 1: Relevant CV 1: Relevant
@ -195,7 +169,7 @@ async def main(enable_steps):
# Step 2: Add text # Step 2: Add text
if enable_steps.get("add_text"): if enable_steps.get("add_text"):
text_list = [job_position, job_1, job_2, job_3, job_4, job_5] text_list = [job_1, job_2, job_3, job_4, job_5]
for text in text_list: for text in text_list:
await cognee.add(text) await cognee.add(text)
print(f"Added text: {text[:35]}...") print(f"Added text: {text[:35]}...")
@ -206,24 +180,21 @@ async def main(enable_steps):
print("Knowledge graph created.") print("Knowledge graph created.")
# Step 4: Query insights # Step 4: Query insights
if enable_steps.get("search_insights"): if enable_steps.get("retriever"):
search_results = await cognee.search( await two_step_retriever('Who has Phd?')
SearchType.INSIGHTS,
{'query': 'Which applicant has the most relevant experience in data science?'}
)
print("Search results:")
for result_text in search_results:
print(result_text)
if __name__ == '__main__': if __name__ == '__main__':
# Flags to enable/disable steps # Flags to enable/disable steps
rebuild_kg = True
retrieve = True
steps_to_enable = { steps_to_enable = {
"prune_data": True, "prune_data": rebuild_kg,
"prune_system": True, "prune_system": rebuild_kg,
"add_text": True, "add_text": rebuild_kg,
"cognify": True, "cognify": rebuild_kg,
"search_insights": True "retriever": retrieve
} }
asyncio.run(main(steps_to_enable)) asyncio.run(main(steps_to_enable))

View file

@ -0,0 +1,48 @@
import os
import asyncio
import pathlib
import cognee
from cognee.api.v1.search import SearchType
# Prerequisites:
# 1. Copy `.env.template` and rename it to `.env`.
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
# LLM_API_KEY = "your_key_here"
async def main():
# Create a clean slate for cognee -- reset data and system state
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# cognee knowledge graph will be created based on the text
# and description of these files
mp3_file_path = os.path.join(
pathlib.Path(__file__).parent.parent.parent,
".data/multimedia/text_to_speech.mp3",
)
png_file_path = os.path.join(
pathlib.Path(__file__).parent.parent.parent,
".data/multimedia/example.png",
)
# Add the files, and make it available for cognify
await cognee.add([mp3_file_path, png_file_path])
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()
# Query cognee for summaries of the data in the multimedia files
search_results = await cognee.search(
SearchType.SUMMARIES,
query_text="What is in the multimedia files?",
)
# Display search results
for result_text in search_results:
print(result_text)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
@ -11,29 +10,57 @@ from cognee.api.v1.search import SearchType
async def main(): async def main():
# Create a clean slate for cognee -- reset data and system state # Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text # cognee knowledge graph will be created based on this text
text = """ text = """
Natural language processing (NLP) is an interdisciplinary Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval. subfield of computer science and information retrieval.
""" """
print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify # Add the text, and make it available for cognify
await cognee.add(text) await cognee.add(text)
print("Text added successfully.\n")
print("Running cognify to create knowledge graph...\n")
print("Cognify process steps:")
print("1. Classifying the document: Determining the type and category of the input text.")
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
print("4. Adding data points: Storing the extracted chunks for processing.")
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
# Use LLMs and cognee to create knowledge graph # Use LLMs and cognee to create knowledge graph
await cognee.cognify() await cognee.cognify()
print("Cognify process complete.\n")
query_text = 'Tell me about NLP'
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text # Query cognee for insights on the added text
search_results = await cognee.search( search_results = await cognee.search(
SearchType.INSIGHTS, query_text='Tell me about NLP' SearchType.INSIGHTS, query_text=query_text
) )
# Display search results print("Search results:")
# Display results
for result_text in search_results: for result_text in search_results:
print(result_text) print(result_text)
# Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main()) asyncio.run(main())

View file

@ -265,7 +265,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "df16431d0f48b006", "id": "df16431d0f48b006",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -304,7 +304,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"id": "9086abf3af077ab4", "id": "9086abf3af077ab4",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -349,7 +349,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "a9de0cc07f798b7f", "id": "a9de0cc07f798b7f",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -393,7 +393,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "185ff1c102d06111", "id": "185ff1c102d06111",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -437,7 +437,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "d55ce4c58f8efb67", "id": "d55ce4c58f8efb67",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -479,7 +479,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"id": "ca4ecc32721ad332", "id": "ca4ecc32721ad332",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -529,14 +529,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"id": "bce39dc6", "id": "bce39dc6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"\n", "\n",
"# # Setting environment variables\n", "# Setting environment variables\n",
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n", "if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n", " os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
"\n", "\n",
@ -546,24 +546,26 @@
"if \"LLM_API_KEY\" not in os.environ:\n", "if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n", " os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n", "\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" # \"neo4j\" or \"networkx\"\n", "# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n", "# Not needed if using networkx\n",
"#GRAPH_DATABASE_URL=\"\"\n", "#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#GRAPH_DATABASE_USERNAME=\"\"\n", "#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#GRAPH_DATABASE_PASSWORD=\"\"\n", "#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n", "\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" # \"qdrant\", \"weaviate\" or \"lancedb\"\n", "# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"# Not needed if using \"lancedb\"\n", "os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n", "# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n", "# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n", "\n",
"# Database provider\n", "# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n", "os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n", "\n",
"# Database name\n", "# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n", "os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n", "\n",
"# Postgres specific parameters (Only if Postgres is run)\n", "# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n", "# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n", "# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n", "# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
@ -620,7 +622,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"id": "7c431fdef4921ae0", "id": "7c431fdef4921ae0",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {

View file

@ -52,7 +52,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -90,23 +90,23 @@
"# \"neo4j\" or \"networkx\"\n", "# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n", "os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n", "# Not needed if using networkx\n",
"#GRAPH_DATABASE_URL=\"\"\n", "#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#GRAPH_DATABASE_USERNAME=\"\"\n", "#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#GRAPH_DATABASE_PASSWORD=\"\"\n", "#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n", "\n",
"# \"qdrant\", \"weaviate\" or \"lancedb\"\n", "# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n", "os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\"\n", "# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n", "# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n", "# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n", "\n",
"# Database provider\n", "# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n", "os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n", "\n",
"# Database name\n", "# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n", "os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n", "\n",
"# Postgres specific parameters (Only if Postgres is run)\n", "# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n", "# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n", "# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n", "# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
@ -130,8 +130,6 @@
"\n", "\n",
"from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n", "from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n",
"from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n", "from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n",
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
"from cognee.shared.utils import render_graph\n",
"from cognee.modules.users.models import User\n", "from cognee.modules.users.models import User\n",
"from cognee.modules.users.methods import get_default_user\n", "from cognee.modules.users.methods import get_default_user\n",
"from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n", "from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n",
@ -196,6 +194,9 @@
"source": [ "source": [
"import graphistry\n", "import graphistry\n",
"\n", "\n",
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
"from cognee.shared.utils import render_graph\n",
"\n",
"# Get graph\n", "# Get graph\n",
"graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n", "graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n",
"graph_engine = await get_graph_engine()\n", "graph_engine = await get_graph_engine()\n",

View file

@ -0,0 +1,169 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cognee GraphRAG with Multimedia files"
]
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"## Load Data\n",
"\n",
"We will use a few sample multimedia files which we have on GitHub for easy access."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# cognee knowledge graph will be created based on the text\n",
"# and description of these files\n",
"mp3_file_path = os.path.join(\n",
" os.path.abspath(''), \"../\",\n",
" \".data/multimedia/text_to_speech.mp3\",\n",
")\n",
"png_file_path = os.path.join(\n",
" os.path.abspath(''), \"../\",\n",
" \".data/multimedia/example.png\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set environment variables"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Setting environment variables\n",
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
"\n",
"if \"GRAPHISTRY_PASSWORD\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
"\n",
"if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n",
"# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n",
"#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n",
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n",
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n",
"# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n",
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Cognee with multimedia files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cognee\n",
"\n",
"# Create a clean slate for cognee -- reset data and system state\n",
"await cognee.prune.prune_data()\n",
"await cognee.prune.prune_system(metadata=True)\n",
"\n",
"# Add multimedia files and make them available for cognify\n",
"await cognee.add([mp3_file_path, png_file_path])\n",
"\n",
"# Create knowledge graph with cognee\n",
"await cognee.cognify()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query Cognee for summaries related to multimedia files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"# Query cognee for summaries of the data in the multimedia files\n",
"search_results = await cognee.search(\n",
" SearchType.SUMMARIES,\n",
" query_text=\"What is in the multimedia files?\",\n",
")\n",
"\n",
"# Display search results\n",
"for result_text in search_results:\n",
" print(result_text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

5
poetry.lock generated
View file

@ -6171,11 +6171,6 @@ files = [
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},