Compare commits

...
Sign in to create a new pull request.

5 commits

4 changed files with 446 additions and 0 deletions

View file

View file

@ -0,0 +1,122 @@
from typing import List
import json
from dataclasses import dataclass
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@dataclass
class TripletNode:
"""Graph node representation from triplet data."""
id: str
attributes: dict
@dataclass
class TripletEdge:
"""Graph edge representation from triplet data."""
node1: TripletNode
node2: TripletNode
attributes: dict
class EmbeddedTripletCompletionRetriever(GraphCompletionRetriever):
"""Retriever that uses embedded triplets from vector collection instead of brute force search.
Data Conversion Rationale:
GraphCompletionRetriever expects graph objects with .node1/.node2 having .id and .attributes,
but our triplets are stored as JSON: {"start_node": {"id", "content"}, "relationship", "end_node": {"id", "content"}}.
We convert triplet format TripletNode/TripletEdge objects to match the expected interface.
"""
def __init__(self, collection_name: str = "Triplets", **kwargs):
"""Initialize with configurable collection name."""
super().__init__(**kwargs)
self.collection_name = collection_name
async def get_triplets(self, query: str) -> List[TripletEdge]:
"""Override parent method to use vector search on triplet collection."""
vector_engine = get_vector_engine()
search_results = await vector_engine.search(
collection_name=self.collection_name, query_text=query, limit=self.top_k
)
triplet_edges = []
for search_result in search_results:
triplet_edge = self._convert_search_result_to_edge(search_result)
triplet_edges.append(triplet_edge)
return triplet_edges
def _parse_triplet_payload(self, search_result: ScoredResult) -> dict:
"""Extract triplet data from search result payload."""
# ScoredResult.payload contains entire DataPoint structure
# Actual triplet JSON is nested at payload['payload']
triplet_json = search_result.payload["payload"]
return json.loads(triplet_json)
def _create_triplet_node(self, node_data: dict) -> TripletNode:
"""Convert triplet node data to graph node format."""
return TripletNode(
id=node_data["id"],
attributes={
"text": node_data.get("content", ""),
"name": node_data.get("content", "Unnamed Node"),
"description": node_data.get("content", ""),
},
)
def _create_triplet_edge(
self, node1: TripletNode, node2: TripletNode, relationship: str
) -> TripletEdge:
"""Create graph edge from two nodes and relationship."""
return TripletEdge(node1=node1, node2=node2, attributes={"relationship_type": relationship})
def _convert_search_result_to_edge(self, search_result) -> TripletEdge:
"""Main conversion method: ScoredResult → TripletEdge."""
# 1. Extract triplet data from search result
triplet_data = self._parse_triplet_payload(search_result)
# 2. Extract triplet components
start_node_data = triplet_data["start_node"]
relationship = triplet_data["relationship"]
end_node_data = triplet_data["end_node"]
# 3. Create triplet nodes
start_node = self._create_triplet_node(start_node_data)
end_node = self._create_triplet_node(end_node_data)
# 4. Create triplet edge
triplet_edge = self._create_triplet_edge(start_node, end_node, relationship)
return triplet_edge
if __name__ == "__main__":
import asyncio
async def main():
retriever = EmbeddedTripletCompletionRetriever(collection_name="Triplets", top_k=5)
# Test triplet retrieval with CV/resume data
query = "machine learning experience Python data scientist"
print(f"Searching for: {query}")
triplets = await retriever.get_triplets(query)
print(f"Found {len(triplets)} triplets")
# Test full context generation
context = await retriever.get_context(query)
print(f"Generated context:\n{context}")
# Test another query
query2 = "Stanford University education background"
print(f"\nSecond search: {query2}")
triplets2 = await retriever.get_triplets(query2)
print(f"Found {len(triplets2)} triplets")
asyncio.run(main())

View file

@ -0,0 +1,188 @@
import asyncio
import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.triplet_embedding_poc.triplet_embedding_postprocessing import (
triplet_embedding_postprocessing,
)
job_1 = """
CV 1: Relevant
Name: Dr. Emily Carter
Contact Information:
Email: emily.carter@example.com
Phone: (555) 123-4567
Summary:
Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.
Education:
Ph.D. in Computer Science, Stanford University (2014)
B.S. in Mathematics, University of California, Berkeley (2010)
Experience:
Senior Data Scientist, InnovateAI Labs (2016 Present)
Led a team in developing machine learning models for natural language processing applications.
Implemented deep learning algorithms that improved prediction accuracy by 25%.
Collaborated with cross-functional teams to integrate models into cloud-based platforms.
Data Scientist, DataWave Analytics (2014 2016)
Developed predictive models for customer segmentation and churn analysis.
Analyzed large datasets using Hadoop and Spark frameworks.
Skills:
Programming Languages: Python, R, SQL
Machine Learning: TensorFlow, Keras, Scikit-Learn
Big Data Technologies: Hadoop, Spark
Data Visualization: Tableau, Matplotlib
"""
job_2 = """
CV 2: Relevant
Name: Michael Rodriguez
Contact Information:
Email: michael.rodriguez@example.com
Phone: (555) 234-5678
Summary:
Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.
Education:
M.S. in Data Science, Carnegie Mellon University (2013)
B.S. in Computer Science, University of Michigan (2011)
Experience:
Senior Data Scientist, Alpha Analytics (2017 Present)
Developed machine learning models to optimize marketing strategies.
Reduced customer acquisition cost by 15% through predictive modeling.
Data Scientist, TechInsights (2013 2017)
Analyzed user behavior data to improve product features.
Implemented A/B testing frameworks to evaluate product changes.
Skills:
Programming Languages: Python, Java, SQL
Machine Learning: Scikit-Learn, XGBoost
Data Visualization: Seaborn, Plotly
Databases: MySQL, MongoDB
"""
job_3 = """
CV 3: Relevant
Name: Sarah Nguyen
Contact Information:
Email: sarah.nguyen@example.com
Phone: (555) 345-6789
Summary:
Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.
Education:
M.S. in Statistics, University of Washington (2014)
B.S. in Applied Mathematics, University of Texas at Austin (2012)
Experience:
Data Scientist, QuantumTech (2016 Present)
Designed and implemented machine learning algorithms for financial forecasting.
Improved model efficiency by 20% through algorithm optimization.
Junior Data Scientist, DataCore Solutions (2014 2016)
Assisted in developing predictive models for supply chain optimization.
Conducted data cleaning and preprocessing on large datasets.
Skills:
Programming Languages: Python, R
Machine Learning Frameworks: PyTorch, Scikit-Learn
Statistical Analysis: SAS, SPSS
Cloud Platforms: AWS, Azure
"""
job_4 = """
CV 4: Not Relevant
Name: David Thompson
Contact Information:
Email: david.thompson@example.com
Phone: (555) 456-7890
Summary:
Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.
Education:
B.F.A. in Graphic Design, Rhode Island School of Design (2012)
Experience:
Senior Graphic Designer, CreativeWorks Agency (2015 Present)
Led design projects for clients in various industries.
Created branding materials that increased client engagement by 30%.
Graphic Designer, Visual Innovations (2012 2015)
Designed marketing collateral, including brochures, logos, and websites.
Collaborated with the marketing team to develop cohesive brand strategies.
Skills:
Design Software: Adobe Photoshop, Illustrator, InDesign
Web Design: HTML, CSS
Specialties: Branding and Identity, Typography
"""
job_5 = """
CV 5: Not Relevant
Name: Jessica Miller
Contact Information:
Email: jessica.miller@example.com
Phone: (555) 567-8901
Summary:
Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.
Education:
B.A. in Business Administration, University of Southern California (2010)
Experience:
Sales Manager, Global Enterprises (2015 Present)
Managed a sales team of 15 members, achieving a 20% increase in annual revenue.
Developed sales strategies that expanded customer base by 25%.
Sales Representative, Market Leaders Inc. (2010 2015)
Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.
Skills:
Sales Strategy and Planning
Team Leadership and Development
CRM Software: Salesforce, Zoho
Negotiation and Relationship Building
"""
async def main():
pre_graph_creation = True
if pre_graph_creation:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
text_list = [job_1, job_2, job_3, job_4, job_5]
for text in text_list:
await cognee.add(text)
print(f"Added text: {text[:35]}...")
await cognee.cognify()
await triplet_embedding_postprocessing()
if __name__ == "__main__":
logger = setup_logging(log_level=INFO)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

View file

@ -0,0 +1,136 @@
from typing import Any, List
import uuid
import json
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.engine import DataPoint
logger = get_logger("triplet_embedding_poc")
def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
start_node = triplet.get("start_node", None)
if start_node:
start_node_string = start_node.get("content", None)
else:
start_node_string = ""
relationship = triplet.get("relationship", "")
end_node = triplet.get("end_node", None)
if end_node:
end_node_string = end_node.get("content", None)
else:
end_node_string = ""
start_node_type = triplet.get("start_node_type", "")
end_node_type = triplet.get("end_node_type", "")
triplet_str = (
start_node_string
+ "-"
+ start_node_type
+ "-"
+ relationship
+ "-"
+ end_node_string
+ "-"
+ end_node_type
)
triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)
return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
class TripletDataPoint(DataPoint):
"""DataPoint for storing graph triplets with embedded text representation."""
payload: str
text: str
metadata: dict = {"index_fields": ["text"]}
def extract_node_data(node_dict):
"""Extract relevant data from a node dictionary."""
result = {"id": node_dict["id"]}
if "metadata" not in node_dict:
return result
metadata = json.loads(node_dict["metadata"])
if "index_fields" not in metadata or not metadata["index_fields"]:
return result
index_field_name = metadata["index_fields"][0] # Always one entry
if index_field_name not in node_dict:
return result
result["content"] = node_dict[index_field_name]
result["index_field_name"] = index_field_name
return result
async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
graph_engine = await get_graph_engine()
counter = 0
offset = 0
while True:
query = f"""
MATCH (start_node)-[relationship]->(end_node)
RETURN start_node, relationship, end_node
SKIP {offset} LIMIT {triplets_batch_size}
"""
results = await graph_engine.query(query=query)
if not results:
break
payload = [
{
"start_node": extract_node_data(result["start_node"]),
"start_node_type": result["start_node"]["type"],
"relationship": result["relationship"][1],
"end_node_type": result["end_node"]["type"],
"end_node": extract_node_data(result["end_node"]),
}
for result in results
]
counter += len(payload)
logger.info("Processed %d triplets", counter)
yield payload
offset += triplets_batch_size
async def add_triplets_to_collection(
triplets_batch: List[dict], collection_name: str = "Triplets"
) -> None:
vector_adapter = get_vector_engine()
for triplet_batch in triplets_batch:
data_points = []
for triplet in triplet_batch:
try:
data_point = create_triplet_data_point(triplet)
data_points.append(data_point)
except Exception as e:
raise ValueError(f"Malformed triplet: {triplet}. Error: {e}")
await vector_adapter.create_data_points(collection_name, data_points)
async def get_triplet_embedding_tasks() -> list[Task]:
triplet_embedding_tasks = [
Task(get_triplets_from_graph_store, triplets_batch_size=10),
Task(add_triplets_to_collection),
]
return triplet_embedding_tasks
async def triplet_embedding_postprocessing():
tasks = await get_triplet_embedding_tasks()
async for result in run_tasks_base(tasks, user=await get_default_user(), data=[]):
pass