Compare commits
5 commits
main
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0163a7b6f6 | ||
|
|
e0294b38ff | ||
|
|
0bbb7e3781 | ||
|
|
ad07e681b8 | ||
|
|
023f98b33e |
4 changed files with 446 additions and 0 deletions
0
cognee/triplet_embedding_poc/__init__.py
Normal file
0
cognee/triplet_embedding_poc/__init__.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from typing import List
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class TripletNode:
|
||||
"""Graph node representation from triplet data."""
|
||||
|
||||
id: str
|
||||
attributes: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TripletEdge:
|
||||
"""Graph edge representation from triplet data."""
|
||||
|
||||
node1: TripletNode
|
||||
node2: TripletNode
|
||||
attributes: dict
|
||||
|
||||
|
||||
class EmbeddedTripletCompletionRetriever(GraphCompletionRetriever):
|
||||
"""Retriever that uses embedded triplets from vector collection instead of brute force search.
|
||||
|
||||
Data Conversion Rationale:
|
||||
GraphCompletionRetriever expects graph objects with .node1/.node2 having .id and .attributes,
|
||||
but our triplets are stored as JSON: {"start_node": {"id", "content"}, "relationship", "end_node": {"id", "content"}}.
|
||||
We convert triplet format → TripletNode/TripletEdge objects to match the expected interface.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str = "Triplets", **kwargs):
|
||||
"""Initialize with configurable collection name."""
|
||||
super().__init__(**kwargs)
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def get_triplets(self, query: str) -> List[TripletEdge]:
|
||||
"""Override parent method to use vector search on triplet collection."""
|
||||
vector_engine = get_vector_engine()
|
||||
search_results = await vector_engine.search(
|
||||
collection_name=self.collection_name, query_text=query, limit=self.top_k
|
||||
)
|
||||
|
||||
triplet_edges = []
|
||||
for search_result in search_results:
|
||||
triplet_edge = self._convert_search_result_to_edge(search_result)
|
||||
triplet_edges.append(triplet_edge)
|
||||
|
||||
return triplet_edges
|
||||
|
||||
def _parse_triplet_payload(self, search_result: ScoredResult) -> dict:
|
||||
"""Extract triplet data from search result payload."""
|
||||
# ScoredResult.payload contains entire DataPoint structure
|
||||
# Actual triplet JSON is nested at payload['payload']
|
||||
triplet_json = search_result.payload["payload"]
|
||||
return json.loads(triplet_json)
|
||||
|
||||
def _create_triplet_node(self, node_data: dict) -> TripletNode:
|
||||
"""Convert triplet node data to graph node format."""
|
||||
return TripletNode(
|
||||
id=node_data["id"],
|
||||
attributes={
|
||||
"text": node_data.get("content", ""),
|
||||
"name": node_data.get("content", "Unnamed Node"),
|
||||
"description": node_data.get("content", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def _create_triplet_edge(
|
||||
self, node1: TripletNode, node2: TripletNode, relationship: str
|
||||
) -> TripletEdge:
|
||||
"""Create graph edge from two nodes and relationship."""
|
||||
return TripletEdge(node1=node1, node2=node2, attributes={"relationship_type": relationship})
|
||||
|
||||
def _convert_search_result_to_edge(self, search_result) -> TripletEdge:
|
||||
"""Main conversion method: ScoredResult → TripletEdge."""
|
||||
# 1. Extract triplet data from search result
|
||||
triplet_data = self._parse_triplet_payload(search_result)
|
||||
|
||||
# 2. Extract triplet components
|
||||
start_node_data = triplet_data["start_node"]
|
||||
relationship = triplet_data["relationship"]
|
||||
end_node_data = triplet_data["end_node"]
|
||||
|
||||
# 3. Create triplet nodes
|
||||
start_node = self._create_triplet_node(start_node_data)
|
||||
end_node = self._create_triplet_node(end_node_data)
|
||||
|
||||
# 4. Create triplet edge
|
||||
triplet_edge = self._create_triplet_edge(start_node, end_node, relationship)
|
||||
|
||||
return triplet_edge
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def main():
|
||||
retriever = EmbeddedTripletCompletionRetriever(collection_name="Triplets", top_k=5)
|
||||
|
||||
# Test triplet retrieval with CV/resume data
|
||||
query = "machine learning experience Python data scientist"
|
||||
print(f"Searching for: {query}")
|
||||
|
||||
triplets = await retriever.get_triplets(query)
|
||||
print(f"Found {len(triplets)} triplets")
|
||||
|
||||
# Test full context generation
|
||||
context = await retriever.get_context(query)
|
||||
print(f"Generated context:\n{context}")
|
||||
|
||||
# Test another query
|
||||
query2 = "Stanford University education background"
|
||||
print(f"\nSecond search: {query2}")
|
||||
triplets2 = await retriever.get_triplets(query2)
|
||||
print(f"Found {len(triplets2)} triplets")
|
||||
|
||||
asyncio.run(main())
|
||||
188
cognee/triplet_embedding_poc/triplet_embedding_poc_example.py
Normal file
188
cognee/triplet_embedding_poc/triplet_embedding_poc_example.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
import asyncio
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import setup_logging, INFO
|
||||
from cognee.triplet_embedding_poc.triplet_embedding_postprocessing import (
|
||||
triplet_embedding_postprocessing,
|
||||
)
|
||||
|
||||
job_1 = """
|
||||
CV 1: Relevant
|
||||
Name: Dr. Emily Carter
|
||||
Contact Information:
|
||||
|
||||
Email: emily.carter@example.com
|
||||
Phone: (555) 123-4567
|
||||
Summary:
|
||||
|
||||
Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.
|
||||
|
||||
Education:
|
||||
|
||||
Ph.D. in Computer Science, Stanford University (2014)
|
||||
B.S. in Mathematics, University of California, Berkeley (2010)
|
||||
Experience:
|
||||
|
||||
Senior Data Scientist, InnovateAI Labs (2016 – Present)
|
||||
Led a team in developing machine learning models for natural language processing applications.
|
||||
Implemented deep learning algorithms that improved prediction accuracy by 25%.
|
||||
Collaborated with cross-functional teams to integrate models into cloud-based platforms.
|
||||
Data Scientist, DataWave Analytics (2014 – 2016)
|
||||
Developed predictive models for customer segmentation and churn analysis.
|
||||
Analyzed large datasets using Hadoop and Spark frameworks.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, R, SQL
|
||||
Machine Learning: TensorFlow, Keras, Scikit-Learn
|
||||
Big Data Technologies: Hadoop, Spark
|
||||
Data Visualization: Tableau, Matplotlib
|
||||
"""
|
||||
|
||||
job_2 = """
|
||||
CV 2: Relevant
|
||||
Name: Michael Rodriguez
|
||||
Contact Information:
|
||||
|
||||
Email: michael.rodriguez@example.com
|
||||
Phone: (555) 234-5678
|
||||
Summary:
|
||||
|
||||
Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.
|
||||
|
||||
Education:
|
||||
|
||||
M.S. in Data Science, Carnegie Mellon University (2013)
|
||||
B.S. in Computer Science, University of Michigan (2011)
|
||||
Experience:
|
||||
|
||||
Senior Data Scientist, Alpha Analytics (2017 – Present)
|
||||
Developed machine learning models to optimize marketing strategies.
|
||||
Reduced customer acquisition cost by 15% through predictive modeling.
|
||||
Data Scientist, TechInsights (2013 – 2017)
|
||||
Analyzed user behavior data to improve product features.
|
||||
Implemented A/B testing frameworks to evaluate product changes.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, Java, SQL
|
||||
Machine Learning: Scikit-Learn, XGBoost
|
||||
Data Visualization: Seaborn, Plotly
|
||||
Databases: MySQL, MongoDB
|
||||
"""
|
||||
|
||||
|
||||
job_3 = """
|
||||
CV 3: Relevant
|
||||
Name: Sarah Nguyen
|
||||
Contact Information:
|
||||
|
||||
Email: sarah.nguyen@example.com
|
||||
Phone: (555) 345-6789
|
||||
Summary:
|
||||
|
||||
Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.
|
||||
|
||||
Education:
|
||||
|
||||
M.S. in Statistics, University of Washington (2014)
|
||||
B.S. in Applied Mathematics, University of Texas at Austin (2012)
|
||||
Experience:
|
||||
|
||||
Data Scientist, QuantumTech (2016 – Present)
|
||||
Designed and implemented machine learning algorithms for financial forecasting.
|
||||
Improved model efficiency by 20% through algorithm optimization.
|
||||
Junior Data Scientist, DataCore Solutions (2014 – 2016)
|
||||
Assisted in developing predictive models for supply chain optimization.
|
||||
Conducted data cleaning and preprocessing on large datasets.
|
||||
Skills:
|
||||
|
||||
Programming Languages: Python, R
|
||||
Machine Learning Frameworks: PyTorch, Scikit-Learn
|
||||
Statistical Analysis: SAS, SPSS
|
||||
Cloud Platforms: AWS, Azure
|
||||
"""
|
||||
|
||||
|
||||
job_4 = """
|
||||
CV 4: Not Relevant
|
||||
Name: David Thompson
|
||||
Contact Information:
|
||||
|
||||
Email: david.thompson@example.com
|
||||
Phone: (555) 456-7890
|
||||
Summary:
|
||||
|
||||
Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.
|
||||
|
||||
Education:
|
||||
|
||||
B.F.A. in Graphic Design, Rhode Island School of Design (2012)
|
||||
Experience:
|
||||
|
||||
Senior Graphic Designer, CreativeWorks Agency (2015 – Present)
|
||||
Led design projects for clients in various industries.
|
||||
Created branding materials that increased client engagement by 30%.
|
||||
Graphic Designer, Visual Innovations (2012 – 2015)
|
||||
Designed marketing collateral, including brochures, logos, and websites.
|
||||
Collaborated with the marketing team to develop cohesive brand strategies.
|
||||
Skills:
|
||||
|
||||
Design Software: Adobe Photoshop, Illustrator, InDesign
|
||||
Web Design: HTML, CSS
|
||||
Specialties: Branding and Identity, Typography
|
||||
"""
|
||||
|
||||
|
||||
job_5 = """
|
||||
CV 5: Not Relevant
|
||||
Name: Jessica Miller
|
||||
Contact Information:
|
||||
|
||||
Email: jessica.miller@example.com
|
||||
Phone: (555) 567-8901
|
||||
Summary:
|
||||
|
||||
Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.
|
||||
|
||||
Education:
|
||||
|
||||
B.A. in Business Administration, University of Southern California (2010)
|
||||
Experience:
|
||||
|
||||
Sales Manager, Global Enterprises (2015 – Present)
|
||||
Managed a sales team of 15 members, achieving a 20% increase in annual revenue.
|
||||
Developed sales strategies that expanded customer base by 25%.
|
||||
Sales Representative, Market Leaders Inc. (2010 – 2015)
|
||||
Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.
|
||||
Skills:
|
||||
|
||||
Sales Strategy and Planning
|
||||
Team Leadership and Development
|
||||
CRM Software: Salesforce, Zoho
|
||||
Negotiation and Relationship Building
|
||||
"""
|
||||
|
||||
|
||||
async def main():
|
||||
pre_graph_creation = True
|
||||
|
||||
if pre_graph_creation:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
text_list = [job_1, job_2, job_3, job_4, job_5]
|
||||
for text in text_list:
|
||||
await cognee.add(text)
|
||||
print(f"Added text: {text[:35]}...")
|
||||
await cognee.cognify()
|
||||
|
||||
await triplet_embedding_postprocessing()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=INFO)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
136
cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py
Normal file
136
cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
from typing import Any, List
|
||||
import uuid
|
||||
import json
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
logger = get_logger("triplet_embedding_poc")
|
||||
|
||||
|
||||
def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
|
||||
start_node = triplet.get("start_node", None)
|
||||
if start_node:
|
||||
start_node_string = start_node.get("content", None)
|
||||
else:
|
||||
start_node_string = ""
|
||||
|
||||
relationship = triplet.get("relationship", "")
|
||||
|
||||
end_node = triplet.get("end_node", None)
|
||||
if end_node:
|
||||
end_node_string = end_node.get("content", None)
|
||||
else:
|
||||
end_node_string = ""
|
||||
|
||||
start_node_type = triplet.get("start_node_type", "")
|
||||
end_node_type = triplet.get("end_node_type", "")
|
||||
|
||||
triplet_str = (
|
||||
start_node_string
|
||||
+ "-"
|
||||
+ start_node_type
|
||||
+ "-"
|
||||
+ relationship
|
||||
+ "-"
|
||||
+ end_node_string
|
||||
+ "-"
|
||||
+ end_node_type
|
||||
)
|
||||
|
||||
triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)
|
||||
|
||||
return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
|
||||
|
||||
|
||||
class TripletDataPoint(DataPoint):
|
||||
"""DataPoint for storing graph triplets with embedded text representation."""
|
||||
|
||||
payload: str
|
||||
text: str
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
def extract_node_data(node_dict):
|
||||
"""Extract relevant data from a node dictionary."""
|
||||
result = {"id": node_dict["id"]}
|
||||
if "metadata" not in node_dict:
|
||||
return result
|
||||
metadata = json.loads(node_dict["metadata"])
|
||||
if "index_fields" not in metadata or not metadata["index_fields"]:
|
||||
return result
|
||||
index_field_name = metadata["index_fields"][0] # Always one entry
|
||||
if index_field_name not in node_dict:
|
||||
return result
|
||||
result["content"] = node_dict[index_field_name]
|
||||
result["index_field_name"] = index_field_name
|
||||
return result
|
||||
|
||||
|
||||
async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
|
||||
graph_engine = await get_graph_engine()
|
||||
counter = 0
|
||||
offset = 0
|
||||
while True:
|
||||
query = f"""
|
||||
MATCH (start_node)-[relationship]->(end_node)
|
||||
RETURN start_node, relationship, end_node
|
||||
SKIP {offset} LIMIT {triplets_batch_size}
|
||||
"""
|
||||
results = await graph_engine.query(query=query)
|
||||
if not results:
|
||||
break
|
||||
payload = [
|
||||
{
|
||||
"start_node": extract_node_data(result["start_node"]),
|
||||
"start_node_type": result["start_node"]["type"],
|
||||
"relationship": result["relationship"][1],
|
||||
"end_node_type": result["end_node"]["type"],
|
||||
"end_node": extract_node_data(result["end_node"]),
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
|
||||
counter += len(payload)
|
||||
logger.info("Processed %d triplets", counter)
|
||||
yield payload
|
||||
offset += triplets_batch_size
|
||||
|
||||
|
||||
async def add_triplets_to_collection(
|
||||
triplets_batch: List[dict], collection_name: str = "Triplets"
|
||||
) -> None:
|
||||
vector_adapter = get_vector_engine()
|
||||
|
||||
for triplet_batch in triplets_batch:
|
||||
data_points = []
|
||||
for triplet in triplet_batch:
|
||||
try:
|
||||
data_point = create_triplet_data_point(triplet)
|
||||
data_points.append(data_point)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Malformed triplet: {triplet}. Error: {e}")
|
||||
|
||||
await vector_adapter.create_data_points(collection_name, data_points)
|
||||
|
||||
|
||||
async def get_triplet_embedding_tasks() -> list[Task]:
|
||||
triplet_embedding_tasks = [
|
||||
Task(get_triplets_from_graph_store, triplets_batch_size=10),
|
||||
Task(add_triplets_to_collection),
|
||||
]
|
||||
|
||||
return triplet_embedding_tasks
|
||||
|
||||
|
||||
async def triplet_embedding_postprocessing():
|
||||
tasks = await get_triplet_embedding_tasks()
|
||||
|
||||
async for result in run_tasks_base(tasks, user=await get_default_user(), data=[]):
|
||||
pass
|
||||
Loading…
Add table
Reference in a new issue