Adds Nodefilter functionality for the SF demo (updated)
This commit is contained in:
parent
85e5e69494
commit
97974fdc89
8 changed files with 144 additions and 172 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Union
|
||||
from typing import Union, Optional, Type, List
|
||||
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.exceptions import UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -13,6 +14,9 @@ async def search(
|
|||
user: User = None,
|
||||
datasets: Union[list[str], str, None] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
|
|
@ -25,7 +29,14 @@ async def search(
|
|||
raise UserNotFoundError
|
||||
|
||||
filtered_search_results = await search_function(
|
||||
query_text, query_type, datasets, user, system_prompt_path=system_prompt_path
|
||||
query_text,
|
||||
query_type,
|
||||
datasets,
|
||||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Protocol, Optional, Dict, Any
|
||||
from typing import Protocol, Optional, Dict, Any, Type, List
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
|
|
@ -51,6 +51,10 @@ class GraphDBInterface(Protocol):
|
|||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_subgraph(self, node_type: Type[Any], node_name: List[str]):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_graph_data(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -496,6 +496,58 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return (nodes, edges)
|
||||
|
||||
async def get_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
|
||||
query = f"""
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:`{label}`)
|
||||
WHERE n.name = wantedName
|
||||
WITH collect(DISTINCT n) AS primary
|
||||
|
||||
UNWIND primary AS p
|
||||
OPTIONAL MATCH (p)--(nbr)
|
||||
WITH primary, collect(DISTINCT nbr) AS nbrs
|
||||
WITH primary + nbrs AS nodelist
|
||||
|
||||
UNWIND nodelist AS node
|
||||
WITH collect(DISTINCT node) AS nodes
|
||||
|
||||
MATCH (a)-[r]-(b)
|
||||
WHERE a IN nodes AND b IN nodes
|
||||
WITH nodes, collect(DISTINCT r) AS rels
|
||||
|
||||
RETURN
|
||||
[n IN nodes |
|
||||
{{ id: n.id,
|
||||
properties: properties(n) }}] AS rawNodes,
|
||||
[r IN rels |
|
||||
{{ type: type(r),
|
||||
properties: properties(r) }}] AS rawRels
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"names": node_name})
|
||||
if not result:
|
||||
return [], []
|
||||
|
||||
raw_nodes = result[0]["rawNodes"]
|
||||
raw_rels = result[0]["rawRels"]
|
||||
|
||||
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
|
||||
edges = [
|
||||
(
|
||||
r["properties"]["source_node_id"],
|
||||
r["properties"]["target_node_id"],
|
||||
r["type"],
|
||||
r["properties"],
|
||||
)
|
||||
for r in raw_rels
|
||||
]
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters):
|
||||
"""
|
||||
Fetches nodes and relationships filtered by specified attribute values.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
|
|
@ -61,12 +61,18 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node_dimension=1,
|
||||
edge_dimension=1,
|
||||
memory_fragment_filter=[],
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> None:
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise InvalidValueError(message="Dimensions must be positive integers")
|
||||
|
||||
try:
|
||||
if len(memory_fragment_filter) == 0:
|
||||
if node_type is not None and node_name is not None:
|
||||
nodes_data, edges_data = await adapter.get_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
|
|
@ -74,9 +80,11 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
)
|
||||
|
||||
if not nodes_data:
|
||||
raise EntityNotFoundError(message="No node data retrieved from the database.")
|
||||
#:TODO: quick and dirty solution for sf demo, as the list of nodes can be empty
|
||||
return None
|
||||
if not edges_data:
|
||||
raise EntityNotFoundError(message="No edge data retrieved from the database.")
|
||||
#:TODO: quick and dirty solution for sf demo, as the list of edges can be empty
|
||||
return None
|
||||
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type, List
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
|
|
@ -19,11 +19,15 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
|
||||
def _get_nodes(self, retrieved_edges: list) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
|
|
@ -69,11 +73,16 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query, top_k=self.top_k, collections=vector_index_collections or None
|
||||
query,
|
||||
top_k=self.top_k,
|
||||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
)
|
||||
|
||||
if len(found_triplets) == 0:
|
||||
raise NoRelevantDataFound
|
||||
#:TODO: quick and dirty solution for sf demo, as the triplets can be empty
|
||||
return []
|
||||
|
||||
return found_triplets
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
|
@ -54,6 +54,8 @@ def format_triplets(edges):
|
|||
|
||||
async def get_memory_fragment(
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> CogneeGraph:
|
||||
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
@ -66,6 +68,8 @@ async def get_memory_fragment(
|
|||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
return memory_fragment
|
||||
|
|
@ -78,6 +82,8 @@ async def brute_force_triplet_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
@ -92,6 +98,8 @@ async def brute_force_triplet_search(
|
|||
collections=collections,
|
||||
properties_to_project=properties_to_project,
|
||||
memory_fragment=memory_fragment,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
return retrieved_results
|
||||
|
||||
|
|
@ -103,6 +111,8 @@ async def brute_force_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
|
@ -114,6 +124,8 @@ async def brute_force_search(
|
|||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
|
|
@ -124,7 +136,9 @@ async def brute_force_search(
|
|||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(properties_to_project)
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project, node_type=node_type, node_name=node_name
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional, Type, List
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
|
|
@ -11,6 +11,7 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionR
|
|||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
|
|
@ -28,12 +29,21 @@ async def search(
|
|||
datasets: list[str],
|
||||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(
|
||||
query_type, query_text, user, system_prompt_path=system_prompt_path
|
||||
query_type,
|
||||
query_text,
|
||||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
filtered_search_results = []
|
||||
|
|
@ -51,7 +61,13 @@ async def search(
|
|||
|
||||
|
||||
async def specific_search(
|
||||
query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt"
|
||||
query_type: SearchType,
|
||||
query: str,
|
||||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever().get_completion,
|
||||
|
|
@ -61,7 +77,10 @@ async def specific_search(
|
|||
system_prompt_path=system_prompt_path
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path
|
||||
|
|
|
|||
|
|
@ -1,163 +1,15 @@
|
|||
import cognee
|
||||
import asyncio
|
||||
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.metrics.operations import get_pipeline_run_metrics
|
||||
|
||||
from cognee.modules.engine.models.Entity import Entity
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
||||
job_1 = """
|
||||
CV 1: Relevant
|
||||
Name: Dr. Emily Carter
|
||||
Contact Information:
|
||||
|
||||
Email: emily.carter@example.com
|
||||
Phone: (555) 123-4567
|
||||
Summary:
|
||||
|
||||
Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.
|
||||
|
||||
Education:
|
||||
|
||||
Ph.D. in Computer Science, Stanford University (2014)
|
||||
B.S. in Mathematics, University of California, Berkeley (2010)
|
||||
Experience:
|
||||
|
||||
Senior Data Scientist, InnovateAI Labs (2016 – Present)
|
||||
Led a team in developing machine learning models for natural language processing applications.
|
||||
Implemented deep learning algorithms that improved prediction accuracy by 25%.
|
||||
Collaborated with cross-functional teams to integrate models into cloud-based platforms.
|
||||
Data Scientist, DataWave Analytics (2014 – 2016)
|
||||
Developed predictive models for customer segmentation and churn analysis.
|
||||
Analyzed large datasets using Hadoop and Spark frameworks.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, R, SQL
|
||||
Machine Learning: TensorFlow, Keras, Scikit-Learn
|
||||
Big Data Technologies: Hadoop, Spark
|
||||
Data Visualization: Tableau, Matplotlib
|
||||
"""
|
||||
|
||||
job_2 = """
|
||||
CV 2: Relevant
|
||||
Name: Michael Rodriguez
|
||||
Contact Information:
|
||||
|
||||
Email: michael.rodriguez@example.com
|
||||
Phone: (555) 234-5678
|
||||
Summary:
|
||||
|
||||
Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.
|
||||
|
||||
Education:
|
||||
|
||||
M.S. in Data Science, Carnegie Mellon University (2013)
|
||||
B.S. in Computer Science, University of Michigan (2011)
|
||||
Experience:
|
||||
|
||||
Senior Data Scientist, Alpha Analytics (2017 – Present)
|
||||
Developed machine learning models to optimize marketing strategies.
|
||||
Reduced customer acquisition cost by 15% through predictive modeling.
|
||||
Data Scientist, TechInsights (2013 – 2017)
|
||||
Analyzed user behavior data to improve product features.
|
||||
Implemented A/B testing frameworks to evaluate product changes.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, Java, SQL
|
||||
Machine Learning: Scikit-Learn, XGBoost
|
||||
Data Visualization: Seaborn, Plotly
|
||||
Databases: MySQL, MongoDB
|
||||
"""
|
||||
|
||||
|
||||
job_3 = """
|
||||
CV 3: Relevant
|
||||
Name: Sarah Nguyen
|
||||
Contact Information:
|
||||
|
||||
Email: sarah.nguyen@example.com
|
||||
Phone: (555) 345-6789
|
||||
Summary:
|
||||
|
||||
Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.
|
||||
|
||||
Education:
|
||||
|
||||
M.S. in Statistics, University of Washington (2014)
|
||||
B.S. in Applied Mathematics, University of Texas at Austin (2012)
|
||||
Experience:
|
||||
|
||||
Data Scientist, QuantumTech (2016 – Present)
|
||||
Designed and implemented machine learning algorithms for financial forecasting.
|
||||
Improved model efficiency by 20% through algorithm optimization.
|
||||
Junior Data Scientist, DataCore Solutions (2014 – 2016)
|
||||
Assisted in developing predictive models for supply chain optimization.
|
||||
Conducted data cleaning and preprocessing on large datasets.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, R
|
||||
Machine Learning Frameworks: PyTorch, Scikit-Learn
|
||||
Statistical Analysis: SAS, SPSS
|
||||
Cloud Platforms: AWS, Azure
|
||||
"""
|
||||
|
||||
|
||||
job_4 = """
|
||||
CV 4: Not Relevant
|
||||
Name: David Thompson
|
||||
Contact Information:
|
||||
|
||||
Email: david.thompson@example.com
|
||||
Phone: (555) 456-7890
|
||||
Summary:
|
||||
|
||||
Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.
|
||||
|
||||
Education:
|
||||
|
||||
B.F.A. in Graphic Design, Rhode Island School of Design (2012)
|
||||
Experience:
|
||||
|
||||
Senior Graphic Designer, CreativeWorks Agency (2015 – Present)
|
||||
Led design projects for clients in various industries.
|
||||
Created branding materials that increased client engagement by 30%.
|
||||
Graphic Designer, Visual Innovations (2012 – 2015)
|
||||
Designed marketing collateral, including brochures, logos, and websites.
|
||||
Collaborated with the marketing team to develop cohesive brand strategies.
|
||||
Skills:
|
||||
|
||||
Design Software: Adobe Photoshop, Illustrator, InDesign
|
||||
Web Design: HTML, CSS
|
||||
Specialties: Branding and Identity, Typography
|
||||
"""
|
||||
|
||||
|
||||
job_5 = """
|
||||
CV 5: Not Relevant
|
||||
Name: Jessica Miller
|
||||
Contact Information:
|
||||
|
||||
Email: jessica.miller@example.com
|
||||
Phone: (555) 567-8901
|
||||
Summary:
|
||||
|
||||
Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.
|
||||
|
||||
Education:
|
||||
|
||||
B.A. in Business Administration, University of Southern California (2010)
|
||||
Experience:
|
||||
|
||||
Sales Manager, Global Enterprises (2015 – Present)
|
||||
Managed a sales team of 15 members, achieving a 20% increase in annual revenue.
|
||||
Developed sales strategies that expanded customer base by 25%.
|
||||
Sales Representative, Market Leaders Inc. (2010 – 2015)
|
||||
Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.
|
||||
Skills:
|
||||
|
||||
Sales Strategy and Planning
|
||||
Team Leadership and Development
|
||||
CRM Software: Salesforce, Zoho
|
||||
Negotiation and Relationship Building
|
||||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -173,7 +25,7 @@ async def main(enable_steps):
|
|||
|
||||
# Step 2: Add text
|
||||
if enable_steps.get("add_text"):
|
||||
text_list = [job_1, job_2, job_3, job_4, job_5]
|
||||
text_list = [job_1]
|
||||
for text in text_list:
|
||||
await cognee.add(text)
|
||||
print(f"Added text: {text[:35]}...")
|
||||
|
|
@ -191,7 +43,10 @@ async def main(enable_steps):
|
|||
# Step 5: Query insights
|
||||
if enable_steps.get("retriever"):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Who has experience in design tools?"
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is computer science?",
|
||||
node_type=Entity,
|
||||
node_name=["computer science"],
|
||||
)
|
||||
print(search_results)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue