Adds node filter functionality for the SF demo (updated)

This commit is contained in:
hajdul88 2025-04-17 16:43:59 +02:00 committed by lxobr
parent 85e5e69494
commit 97974fdc89
8 changed files with 144 additions and 172 deletions

View file

@@ -1,5 +1,6 @@
from typing import Union
from typing import Union, Optional, Type, List
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
@@ -13,6 +14,9 @@ async def search(
user: User = None,
datasets: Union[list[str], str, None] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
# Normalize datasets to a list from here on
if isinstance(datasets, str):
@@ -25,7 +29,14 @@ async def search(
raise UserNotFoundError
filtered_search_results = await search_function(
query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
query_text,
query_type,
datasets,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
return filtered_search_results
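
For context, a minimal usage sketch of the extended signature, mirroring the demo at the end of this commit (the Entity import path and the query are taken from that demo; run inside an async context):

from cognee.modules.engine.models.Entity import Entity
from cognee.modules.search.types import SearchType

# Restrict graph completion to the subgraph around the named Entity nodes.
results = await search(
    query_text="What is computer science?",
    query_type=SearchType.GRAPH_COMPLETION,
    node_type=Entity,
    node_name=["computer science"],
)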

View file

@@ -1,4 +1,4 @@
from typing import Protocol, Optional, Dict, Any
from typing import Protocol, Optional, Dict, Any, Type, List
from abc import abstractmethod
@@ -51,6 +51,10 @@ class GraphDBInterface(Protocol):
):
raise NotImplementedError
@abstractmethod
async def get_subgraph(self, node_type: Type[Any], node_name: List[str]):
raise NotImplementedError
@abstractmethod
async def get_graph_data(self):
raise NotImplementedError
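
A conforming adapter returns the same (nodes, edges) tuples that get_graph_data yields elsewhere in the codebase. A minimal in-memory stub, purely to illustrate the contract (the storage layout is hypothetical, and unlike the Neo4j implementation below it does not expand to neighbours):

from typing import Any, List, Type

class InMemoryAdapterStub:
    """Illustrates the get_subgraph contract; not a real adapter."""

    def __init__(self, nodes, edges):
        self._nodes = nodes  # [(node_id, properties), ...]
        self._edges = edges  # [(source_id, target_id, relationship_name, properties), ...]

    async def get_subgraph(self, node_type: Type[Any], node_name: List[str]):
        # A real adapter would also match node_type against a stored label.
        wanted = set(node_name)
        node_ids = {nid for nid, props in self._nodes if props.get("name") in wanted}
        nodes = [(nid, props) for nid, props in self._nodes if nid in node_ids]
        edges = [e for e in self._edges if e[0] in node_ids and e[1] in node_ids]
        return nodes, edges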

View file

@@ -4,7 +4,7 @@ import json
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
@@ -496,6 +496,58 @@ class Neo4jAdapter(GraphDBInterface):
return (nodes, edges)
async def get_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
label = node_type.__name__
query = f"""
UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""
result = await self.query(query, {"names": node_name})
if not result:
return [], []
raw_nodes = result[0]["rawNodes"]
raw_rels = result[0]["rawRels"]
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
edges = [
(
r["properties"]["source_node_id"],
r["properties"]["target_node_id"],
r["type"],
r["properties"],
)
for r in raw_rels
]
return nodes, edges
async def get_filtered_graph_data(self, attribute_filters):
"""
Fetches nodes and relationships filtered by specified attribute values.
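
In plain terms, the Cypher in get_subgraph collects the nodes of the given label whose name is in the list, expands one hop to their neighbours, and keeps only relationships whose endpoints both fall inside that node set; note that source and target ids are read from the relationship properties (source_node_id, target_node_id) rather than Neo4j internals. A hedged sketch of calling it directly (engine setup as in the search modules; names are illustrative):

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.Entity import Entity

graph_engine = await get_graph_engine()
# One-hop subgraph around the Entity node(s) named "computer science".
nodes, edges = await graph_engine.get_subgraph(
    node_type=Entity, node_name=["computer science"]
)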

View file

@@ -1,5 +1,5 @@
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional, Type
from cognee.exceptions import InvalidValueError
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
@@ -61,12 +61,18 @@ class CogneeGraph(CogneeAbstractGraph):
node_dimension=1,
edge_dimension=1,
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers")
try:
if len(memory_fragment_filter) == 0:
if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_subgraph(
node_type=node_type, node_name=node_name
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
@@ -74,9 +80,11 @@ class CogneeGraph(CogneeAbstractGraph):
)
if not nodes_data:
raise EntityNotFoundError(message="No node data retrieved from the database.")
# TODO: Quick-and-dirty solution for the SF demo, since the list of nodes can legitimately be empty.
return None
if not edges_data:
raise EntityNotFoundError(message="No edge data retrieved from the database.")
# TODO: Quick-and-dirty solution for the SF demo, since the list of edges can legitimately be empty.
return None
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
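
This filter path is normally reached through get_memory_fragment in the brute-force search module (diffed below); a minimal sketch of triggering a filtered projection (Entity import as in the demo):

from cognee.modules.engine.models.Entity import Entity

# Projects only the subgraph around the named Entity nodes instead of the full graph.
memory_fragment = await get_memory_fragment(
    node_type=Entity, node_name=["computer science"]
)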

View file

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Optional, Type, List
from collections import Counter
import string
@@ -19,11 +19,15 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
"""Initialize retriever with prompt paths and search parameters."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
def _get_nodes(self, retrieved_edges: list) -> dict:
"""Creates a dictionary of nodes with their names and content."""
@@ -69,11 +73,16 @@ class GraphCompletionRetriever(BaseRetriever):
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search(
query, top_k=self.top_k, collections=vector_index_collections or None
query,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
)
if len(found_triplets) == 0:
raise NoRelevantDataFound
# TODO: Quick-and-dirty solution for the SF demo, since the list of triplets can legitimately be empty.
return []
return found_triplets
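
A hedged sketch of using the retriever with the new filters (constructor arguments follow the diff above; assuming get_completion takes the query text, as the dispatch in specific_search suggests):

from cognee.modules.engine.models.Entity import Entity

retriever = GraphCompletionRetriever(
    top_k=10,
    node_type=Entity,
    node_name=["computer science"],
)
answer = await retriever.get_completion("What is computer science?")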

View file

@@ -1,6 +1,6 @@
import asyncio
from cognee.shared.logging_utils import get_logger, ERROR
from typing import List, Optional
from typing import List, Optional, Type
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@@ -54,6 +54,8 @@ def format_triplets(edges):
async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine()
@@ -66,6 +68,8 @@ async def get_memory_fragment(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
node_type=node_type,
node_name=node_name,
)
return memory_fragment
@@ -78,6 +82,8 @@ async def brute_force_triplet_search(
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
if user is None:
user = await get_default_user()
@@ -92,6 +98,8 @@ async def brute_force_search(
collections=collections,
properties_to_project=properties_to_project,
memory_fragment=memory_fragment,
node_type=node_type,
node_name=node_name,
)
return retrieved_results
@@ -103,6 +111,8 @@ async def brute_force_search(
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@@ -114,6 +124,8 @@ async def brute_force_search(
collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type (Optional[Type]): Node type to filter the subgraph by.
node_name (Optional[List[str]]): Node names to filter the subgraph by.
Returns:
list: The top triplet results.
@@ -124,7 +136,9 @@ async def brute_force_search(
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(properties_to_project)
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project, node_type=node_type, node_name=node_name
)
if collections is None:
collections = [
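
Putting the pieces together, a minimal sketch of a filtered triplet search (user and collections fall back to defaults inside the function, per the code above; the query is illustrative):

from cognee.modules.engine.models.Entity import Entity

found_triplets = await brute_force_triplet_search(
    "What is computer science?",
    top_k=5,
    node_type=Entity,
    node_name=["computer science"],
)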

View file

@@ -1,5 +1,5 @@
import json
from typing import Callable
from typing import Callable, Optional, Type, List
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id
@@ -11,6 +11,7 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
@@ -28,12 +29,21 @@ async def search(
datasets: list[str],
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
query = await log_query(query_text, query_type.value, user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path
query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
filtered_search_results = []
@@ -51,7 +61,13 @@ async def search(
async def specific_search(
query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
query_type: SearchType,
query: str,
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever().get_completion,
@@ -61,7 +77,10 @@ async def specific_search(
system_prompt_path=system_prompt_path
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path

View file

@@ -1,163 +1,15 @@
import cognee
import asyncio
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.metrics.operations import get_pipeline_run_metrics
from cognee.modules.engine.models.Entity import Entity
from cognee.api.v1.search import SearchType
job_1 = """
CV 1: Relevant
Name: Dr. Emily Carter
Contact Information:
Email: emily.carter@example.com
Phone: (555) 123-4567
Summary:
Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.
Education:
Ph.D. in Computer Science, Stanford University (2014)
B.S. in Mathematics, University of California, Berkeley (2010)
Experience:
Senior Data Scientist, InnovateAI Labs (2016–Present)
Led a team in developing machine learning models for natural language processing applications.
Implemented deep learning algorithms that improved prediction accuracy by 25%.
Collaborated with cross-functional teams to integrate models into cloud-based platforms.
Data Scientist, DataWave Analytics (2014–2016)
Developed predictive models for customer segmentation and churn analysis.
Analyzed large datasets using Hadoop and Spark frameworks.
Skills:
Programming Languages: Python, R, SQL
Machine Learning: TensorFlow, Keras, Scikit-Learn
Big Data Technologies: Hadoop, Spark
Data Visualization: Tableau, Matplotlib
"""
job_2 = """
CV 2: Relevant
Name: Michael Rodriguez
Contact Information:
Email: michael.rodriguez@example.com
Phone: (555) 234-5678
Summary:
Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.
Education:
M.S. in Data Science, Carnegie Mellon University (2013)
B.S. in Computer Science, University of Michigan (2011)
Experience:
Senior Data Scientist, Alpha Analytics (2017–Present)
Developed machine learning models to optimize marketing strategies.
Reduced customer acquisition cost by 15% through predictive modeling.
Data Scientist, TechInsights (2013–2017)
Analyzed user behavior data to improve product features.
Implemented A/B testing frameworks to evaluate product changes.
Skills:
Programming Languages: Python, Java, SQL
Machine Learning: Scikit-Learn, XGBoost
Data Visualization: Seaborn, Plotly
Databases: MySQL, MongoDB
"""
job_3 = """
CV 3: Relevant
Name: Sarah Nguyen
Contact Information:
Email: sarah.nguyen@example.com
Phone: (555) 345-6789
Summary:
Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.
Education:
M.S. in Statistics, University of Washington (2014)
B.S. in Applied Mathematics, University of Texas at Austin (2012)
Experience:
Data Scientist, QuantumTech (2016–Present)
Designed and implemented machine learning algorithms for financial forecasting.
Improved model efficiency by 20% through algorithm optimization.
Junior Data Scientist, DataCore Solutions (2014–2016)
Assisted in developing predictive models for supply chain optimization.
Conducted data cleaning and preprocessing on large datasets.
Skills:
Programming Languages: Python, R
Machine Learning Frameworks: PyTorch, Scikit-Learn
Statistical Analysis: SAS, SPSS
Cloud Platforms: AWS, Azure
"""
job_4 = """
CV 4: Not Relevant
Name: David Thompson
Contact Information:
Email: david.thompson@example.com
Phone: (555) 456-7890
Summary:
Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.
Education:
B.F.A. in Graphic Design, Rhode Island School of Design (2012)
Experience:
Senior Graphic Designer, CreativeWorks Agency (2015–Present)
Led design projects for clients in various industries.
Created branding materials that increased client engagement by 30%.
Graphic Designer, Visual Innovations (2012–2015)
Designed marketing collateral, including brochures, logos, and websites.
Collaborated with the marketing team to develop cohesive brand strategies.
Skills:
Design Software: Adobe Photoshop, Illustrator, InDesign
Web Design: HTML, CSS
Specialties: Branding and Identity, Typography
"""
job_5 = """
CV 5: Not Relevant
Name: Jessica Miller
Contact Information:
Email: jessica.miller@example.com
Phone: (555) 567-8901
Summary:
Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.
Education:
B.A. in Business Administration, University of Southern California (2010)
Experience:
Sales Manager, Global Enterprises (2015–Present)
Managed a sales team of 15 members, achieving a 20% increase in annual revenue.
Developed sales strategies that expanded customer base by 25%.
Sales Representative, Market Leaders Inc. (2010–2015)
Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.
Skills:
Sales Strategy and Planning
Team Leadership and Development
CRM Software: Salesforce, Zoho
Negotiation and Relationship Building
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""
@@ -173,7 +25,7 @@ async def main(enable_steps):
# Step 2: Add text
if enable_steps.get("add_text"):
text_list = [job_1, job_2, job_3, job_4, job_5]
text_list = [job_1]
for text in text_list:
await cognee.add(text)
print(f"Added text: {text[:35]}...")
@@ -191,7 +43,10 @@ async def main(enable_steps):
# Step 5: Query insights
if enable_steps.get("retriever"):
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="Who has experience in design tools?"
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is computer science?",
node_type=Entity,
node_name=["computer science"],
)
print(search_results)
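
For completeness, a hedged sketch of how the demo might be invoked (the real file presumably wires this up below the shown hunk; the step flags are assumptions based on the enable_steps checks visible above):

if __name__ == "__main__":
    # Assumed step flags; only the steps visible in the hunks above are listed.
    steps = {"add_text": True, "retriever": True}
    asyncio.run(main(steps))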