Compare commits
5 commits
main
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0163a7b6f6 | ||
|
|
e0294b38ff | ||
|
|
0bbb7e3781 | ||
|
|
ad07e681b8 | ||
|
|
023f98b33e |
4 changed files with 446 additions and 0 deletions
0
cognee/triplet_embedding_poc/__init__.py
Normal file
0
cognee/triplet_embedding_poc/__init__.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
from typing import List
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TripletNode:
    """Graph node representation from triplet data.

    Mimics the node interface GraphCompletionRetriever expects on an edge's
    endpoints: an ``id`` plus a free-form ``attributes`` dict.
    """

    # Node identifier, taken verbatim from the triplet JSON's node "id".
    id: str
    # Display attributes; populated with "text", "name", "description" keys
    # (see EmbeddedTripletCompletionRetriever._create_triplet_node).
    attributes: dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TripletEdge:
    """Graph edge representation from triplet data.

    Mimics the edge interface GraphCompletionRetriever expects: two endpoint
    nodes plus a free-form ``attributes`` dict.
    """

    # Edge source endpoint.
    node1: TripletNode
    # Edge target endpoint.
    node2: TripletNode
    # Edge attributes; populated with a "relationship_type" key
    # (see EmbeddedTripletCompletionRetriever._create_triplet_edge).
    attributes: dict
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddedTripletCompletionRetriever(GraphCompletionRetriever):
    """Graph-completion retriever backed by pre-embedded triplets.

    Rather than brute-force graph search, matching triplets are fetched from a
    vector collection. Stored triplets are JSON of the form
    ``{"start_node": {"id", "content"}, "relationship", "end_node": {...}}``,
    while GraphCompletionRetriever expects edge objects whose ``.node1`` /
    ``.node2`` expose ``.id`` and ``.attributes`` — hence the conversion into
    TripletNode/TripletEdge objects below.
    """

    def __init__(self, collection_name: str = "Triplets", **kwargs):
        """Record the target collection; all other options go to the parent."""
        super().__init__(**kwargs)
        self.collection_name = collection_name

    async def get_triplets(self, query: str) -> List[TripletEdge]:
        """Vector-search the triplet collection and return the hits as edges."""
        engine = get_vector_engine()
        hits = await engine.search(
            collection_name=self.collection_name, query_text=query, limit=self.top_k
        )
        return [self._convert_search_result_to_edge(hit) for hit in hits]

    def _parse_triplet_payload(self, search_result: ScoredResult) -> dict:
        """Decode the triplet JSON nested at ``payload['payload']`` of a hit."""
        # ScoredResult.payload holds the whole DataPoint; the triplet JSON
        # string lives under its "payload" key.
        return json.loads(search_result.payload["payload"])

    def _create_triplet_node(self, node_data: dict) -> TripletNode:
        """Map a triplet node dict onto the graph-node interface."""
        content = node_data.get("content", "")
        return TripletNode(
            id=node_data["id"],
            attributes={
                "text": content,
                "name": node_data.get("content", "Unnamed Node"),
                "description": content,
            },
        )

    def _create_triplet_edge(
        self, node1: TripletNode, node2: TripletNode, relationship: str
    ) -> TripletEdge:
        """Wrap two endpoint nodes and a relationship label into an edge."""
        return TripletEdge(node1=node1, node2=node2, attributes={"relationship_type": relationship})

    def _convert_search_result_to_edge(self, search_result) -> TripletEdge:
        """ScoredResult → TripletEdge: parse payload, build both nodes, link them."""
        data = self._parse_triplet_payload(search_result)
        return self._create_triplet_edge(
            self._create_triplet_node(data["start_node"]),
            self._create_triplet_node(data["end_node"]),
            data["relationship"],
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import asyncio

    async def _demo():
        """Smoke-test the retriever against an already-populated collection."""
        retriever = EmbeddedTripletCompletionRetriever(collection_name="Triplets", top_k=5)

        # First query: CV/resume-flavoured search.
        query = "machine learning experience Python data scientist"
        print(f"Searching for: {query}")
        first_hits = await retriever.get_triplets(query)
        print(f"Found {len(first_hits)} triplets")

        # Full context generation over the same query.
        context = await retriever.get_context(query)
        print(f"Generated context:\n{context}")

        # Second, education-focused query.
        query2 = "Stanford University education background"
        print(f"\nSecond search: {query2}")
        second_hits = await retriever.get_triplets(query2)
        print(f"Found {len(second_hits)} triplets")

    asyncio.run(_demo())
|
||||||
188
cognee/triplet_embedding_poc/triplet_embedding_poc_example.py
Normal file
188
cognee/triplet_embedding_poc/triplet_embedding_poc_example.py
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
import asyncio
|
||||||
|
import cognee
|
||||||
|
from cognee.shared.logging_utils import setup_logging, INFO
|
||||||
|
from cognee.triplet_embedding_poc.triplet_embedding_postprocessing import (
|
||||||
|
triplet_embedding_postprocessing,
|
||||||
|
)
|
||||||
|
|
||||||
|
job_1 = """
|
||||||
|
CV 1: Relevant
|
||||||
|
Name: Dr. Emily Carter
|
||||||
|
Contact Information:
|
||||||
|
|
||||||
|
Email: emily.carter@example.com
|
||||||
|
Phone: (555) 123-4567
|
||||||
|
Summary:
|
||||||
|
|
||||||
|
Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.
|
||||||
|
|
||||||
|
Education:
|
||||||
|
|
||||||
|
Ph.D. in Computer Science, Stanford University (2014)
|
||||||
|
B.S. in Mathematics, University of California, Berkeley (2010)
|
||||||
|
Experience:
|
||||||
|
|
||||||
|
Senior Data Scientist, InnovateAI Labs (2016 – Present)
|
||||||
|
Led a team in developing machine learning models for natural language processing applications.
|
||||||
|
Implemented deep learning algorithms that improved prediction accuracy by 25%.
|
||||||
|
Collaborated with cross-functional teams to integrate models into cloud-based platforms.
|
||||||
|
Data Scientist, DataWave Analytics (2014 – 2016)
|
||||||
|
Developed predictive models for customer segmentation and churn analysis.
|
||||||
|
Analyzed large datasets using Hadoop and Spark frameworks.
|
||||||
|
Skills:
|
||||||
|
|
||||||
|
Programming Languages: Python, R, SQL
|
||||||
|
Machine Learning: TensorFlow, Keras, Scikit-Learn
|
||||||
|
Big Data Technologies: Hadoop, Spark
|
||||||
|
Data Visualization: Tableau, Matplotlib
|
||||||
|
"""
|
||||||
|
|
||||||
|
job_2 = """
|
||||||
|
CV 2: Relevant
|
||||||
|
Name: Michael Rodriguez
|
||||||
|
Contact Information:
|
||||||
|
|
||||||
|
Email: michael.rodriguez@example.com
|
||||||
|
Phone: (555) 234-5678
|
||||||
|
Summary:
|
||||||
|
|
||||||
|
Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.
|
||||||
|
|
||||||
|
Education:
|
||||||
|
|
||||||
|
M.S. in Data Science, Carnegie Mellon University (2013)
|
||||||
|
B.S. in Computer Science, University of Michigan (2011)
|
||||||
|
Experience:
|
||||||
|
|
||||||
|
Senior Data Scientist, Alpha Analytics (2017 – Present)
|
||||||
|
Developed machine learning models to optimize marketing strategies.
|
||||||
|
Reduced customer acquisition cost by 15% through predictive modeling.
|
||||||
|
Data Scientist, TechInsights (2013 – 2017)
|
||||||
|
Analyzed user behavior data to improve product features.
|
||||||
|
Implemented A/B testing frameworks to evaluate product changes.
|
||||||
|
Skills:
|
||||||
|
|
||||||
|
Programming Languages: Python, Java, SQL
|
||||||
|
Machine Learning: Scikit-Learn, XGBoost
|
||||||
|
Data Visualization: Seaborn, Plotly
|
||||||
|
Databases: MySQL, MongoDB
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
job_3 = """
|
||||||
|
CV 3: Relevant
|
||||||
|
Name: Sarah Nguyen
|
||||||
|
Contact Information:
|
||||||
|
|
||||||
|
Email: sarah.nguyen@example.com
|
||||||
|
Phone: (555) 345-6789
|
||||||
|
Summary:
|
||||||
|
|
||||||
|
Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.
|
||||||
|
|
||||||
|
Education:
|
||||||
|
|
||||||
|
M.S. in Statistics, University of Washington (2014)
|
||||||
|
B.S. in Applied Mathematics, University of Texas at Austin (2012)
|
||||||
|
Experience:
|
||||||
|
|
||||||
|
Data Scientist, QuantumTech (2016 – Present)
|
||||||
|
Designed and implemented machine learning algorithms for financial forecasting.
|
||||||
|
Improved model efficiency by 20% through algorithm optimization.
|
||||||
|
Junior Data Scientist, DataCore Solutions (2014 – 2016)
|
||||||
|
Assisted in developing predictive models for supply chain optimization.
|
||||||
|
Conducted data cleaning and preprocessing on large datasets.
|
||||||
|
Skills:
|
||||||
|
|
||||||
|
Programming Languages: Python, R
|
||||||
|
Machine Learning Frameworks: PyTorch, Scikit-Learn
|
||||||
|
Statistical Analysis: SAS, SPSS
|
||||||
|
Cloud Platforms: AWS, Azure
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
job_4 = """
|
||||||
|
CV 4: Not Relevant
|
||||||
|
Name: David Thompson
|
||||||
|
Contact Information:
|
||||||
|
|
||||||
|
Email: david.thompson@example.com
|
||||||
|
Phone: (555) 456-7890
|
||||||
|
Summary:
|
||||||
|
|
||||||
|
Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.
|
||||||
|
|
||||||
|
Education:
|
||||||
|
|
||||||
|
B.F.A. in Graphic Design, Rhode Island School of Design (2012)
|
||||||
|
Experience:
|
||||||
|
|
||||||
|
Senior Graphic Designer, CreativeWorks Agency (2015 – Present)
|
||||||
|
Led design projects for clients in various industries.
|
||||||
|
Created branding materials that increased client engagement by 30%.
|
||||||
|
Graphic Designer, Visual Innovations (2012 – 2015)
|
||||||
|
Designed marketing collateral, including brochures, logos, and websites.
|
||||||
|
Collaborated with the marketing team to develop cohesive brand strategies.
|
||||||
|
Skills:
|
||||||
|
|
||||||
|
Design Software: Adobe Photoshop, Illustrator, InDesign
|
||||||
|
Web Design: HTML, CSS
|
||||||
|
Specialties: Branding and Identity, Typography
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
job_5 = """
|
||||||
|
CV 5: Not Relevant
|
||||||
|
Name: Jessica Miller
|
||||||
|
Contact Information:
|
||||||
|
|
||||||
|
Email: jessica.miller@example.com
|
||||||
|
Phone: (555) 567-8901
|
||||||
|
Summary:
|
||||||
|
|
||||||
|
Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.
|
||||||
|
|
||||||
|
Education:
|
||||||
|
|
||||||
|
B.A. in Business Administration, University of Southern California (2010)
|
||||||
|
Experience:
|
||||||
|
|
||||||
|
Sales Manager, Global Enterprises (2015 – Present)
|
||||||
|
Managed a sales team of 15 members, achieving a 20% increase in annual revenue.
|
||||||
|
Developed sales strategies that expanded customer base by 25%.
|
||||||
|
Sales Representative, Market Leaders Inc. (2010 – 2015)
|
||||||
|
Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.
|
||||||
|
Skills:
|
||||||
|
|
||||||
|
Sales Strategy and Planning
|
||||||
|
Team Leadership and Development
|
||||||
|
CRM Software: Salesforce, Zoho
|
||||||
|
Negotiation and Relationship Building
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Ingest the sample CVs into cognee, then embed their graph triplets."""
    # Flip to False to skip re-ingestion and only rerun the postprocessing.
    pre_graph_creation = True

    if pre_graph_creation:
        # Start from a clean slate: wipe existing data and system state.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        for text in (job_1, job_2, job_3, job_4, job_5):
            await cognee.add(text)
            print(f"Added text: {text[:35]}...")
        await cognee.cognify()

    await triplet_embedding_postprocessing()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    # Own the event loop explicitly so asyncgens can be shut down before exit.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        # Fix: the loop was previously never closed, leaking its resources;
        # close it once all async generators are finalized.
        loop.close()
|
||||||
136
cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py
Normal file
136
cognee/triplet_embedding_poc/triplet_embedding_postprocessing.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
from typing import Any, List
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger namespaced to this PoC.
logger = get_logger("triplet_embedding_poc")
|
||||||
|
|
||||||
|
|
||||||
|
def create_triplet_data_point(triplet: dict) -> "TripletDataPoint":
    """Convert one raw triplet dict into a TripletDataPoint for embedding.

    The embeddable text is a "-"-joined summary of the form
    ``start_content-start_type-relationship-end_content-end_type``; the full
    triplet is preserved as a JSON string payload. The data-point id is a
    deterministic UUIDv5 of that summary, so identical triplets always map to
    the same id.

    Args:
        triplet: Dict with optional keys "start_node" / "end_node" (each a
            dict that may carry "content"), "relationship",
            "start_node_type" and "end_node_type".

    Returns:
        TripletDataPoint carrying id, JSON payload and embeddable text.
    """
    # Treat a missing node dict as empty so the .get calls below are safe.
    start_node = triplet.get("start_node") or {}
    end_node = triplet.get("end_node") or {}

    # Bug fix: the previous .get("content", None) could yield None for a node
    # without content, which crashed the string join with a TypeError;
    # coerce missing/None content to "".
    start_node_string = start_node.get("content") or ""
    end_node_string = end_node.get("content") or ""

    relationship = triplet.get("relationship", "")
    start_node_type = triplet.get("start_node_type", "")
    end_node_type = triplet.get("end_node_type", "")

    triplet_str = "-".join(
        [start_node_string, start_node_type, relationship, end_node_string, end_node_type]
    )

    # Deterministic id derived from the triplet summary.
    triplet_uuid = uuid.uuid5(uuid.NAMESPACE_OID, name=triplet_str)

    return TripletDataPoint(id=triplet_uuid, payload=json.dumps(triplet), text=triplet_str)
|
||||||
|
|
||||||
|
|
||||||
|
class TripletDataPoint(DataPoint):
    """DataPoint for storing graph triplets with embedded text representation."""

    # Full triplet serialized as a JSON string (see create_triplet_data_point).
    payload: str
    # Human-readable "content-type-relationship-content-type" summary; this is
    # the field selected for embedding via metadata["index_fields"].
    text: str
    # Tells the vector store which field(s) to embed.
    # NOTE(review): class-level dict literal — assumes the DataPoint base
    # (pydantic-style) copies field defaults per instance; confirm it is not
    # shared mutable state.
    metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_node_data(node_dict):
    """Extract relevant data from a node dictionary.

    Always includes the node "id"; additionally includes "content" and
    "index_field_name" when the node's JSON "metadata" declares a usable
    index field that is present on the node itself.
    """
    extracted = {"id": node_dict["id"]}

    if "metadata" not in node_dict:
        return extracted

    # Node metadata is stored as a JSON string; decode it to find the
    # declared index fields.
    node_meta = json.loads(node_dict["metadata"])
    index_fields = node_meta.get("index_fields")
    if not index_fields:
        return extracted

    field_name = index_fields[0]  # only ever a single entry
    if field_name in node_dict:
        extracted["content"] = node_dict[field_name]
        extracted["index_field_name"] = field_name
    return extracted
|
||||||
|
|
||||||
|
|
||||||
|
async def get_triplets_from_graph_store(data, triplets_batch_size=10) -> Any:
    """Yield batches of raw triplet dicts read from the graph store.

    Async generator: pages through every (start)-[relationship]->(end) edge
    using SKIP/LIMIT and yields one list of triplet dicts per page.

    Args:
        data: Unused; present so the function fits the pipeline Task signature.
        triplets_batch_size: Page size for each graph query.

    Yields:
        Lists of dicts with keys "start_node", "start_node_type",
        "relationship", "end_node_type", "end_node".
    """
    graph_engine = await get_graph_engine()
    counter = 0  # running total across pages, for logging only
    offset = 0
    while True:
        # NOTE(review): SKIP/LIMIT pagination without an ORDER BY relies on
        # the store returning a stable ordering across successive queries —
        # confirm for the configured graph engine, otherwise rows may repeat
        # or be skipped.
        query = f"""
        MATCH (start_node)-[relationship]->(end_node)
        RETURN start_node, relationship, end_node
        SKIP {offset} LIMIT {triplets_batch_size}
        """
        results = await graph_engine.query(query=query)
        if not results:
            break
        payload = [
            {
                "start_node": extract_node_data(result["start_node"]),
                "start_node_type": result["start_node"]["type"],
                # presumably the adapter returns the relationship as a tuple
                # with its type/name at index 1 — TODO confirm.
                "relationship": result["relationship"][1],
                "end_node_type": result["end_node"]["type"],
                "end_node": extract_node_data(result["end_node"]),
            }
            for result in results
        ]

        counter += len(payload)
        logger.info("Processed %d triplets", counter)
        yield payload
        offset += triplets_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
async def add_triplets_to_collection(
    triplets_batch: List[dict], collection_name: str = "Triplets"
) -> None:
    """Embed triplet batches into the given vector collection.

    Args:
        triplets_batch: Batches of triplets as yielded by
            get_triplets_from_graph_store — each element is itself a list of
            triplet dicts. (The declared element type is kept as ``dict`` for
            pipeline-signature compatibility.)
        collection_name: Target vector collection.

    Raises:
        ValueError: If any triplet cannot be converted to a data point; the
            original exception is chained as the cause.
    """
    vector_adapter = get_vector_engine()

    for triplet_batch in triplets_batch:
        data_points = []
        for triplet in triplet_batch:
            try:
                data_points.append(create_triplet_data_point(triplet))
            except Exception as e:
                # Fix: chain the original exception (``from e``) so the root
                # cause and its traceback are not lost.
                raise ValueError(f"Malformed triplet: {triplet}. Error: {e}") from e

        await vector_adapter.create_data_points(collection_name, data_points)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_triplet_embedding_tasks() -> list[Task]:
    """Assemble the two-stage triplet embedding pipeline.

    Stage one streams triplet batches out of the graph store; stage two
    embeds each batch into the vector collection.
    """
    read_stage = Task(get_triplets_from_graph_store, triplets_batch_size=10)
    write_stage = Task(add_triplets_to_collection)
    return [read_stage, write_stage]
|
||||||
|
|
||||||
|
|
||||||
|
async def triplet_embedding_postprocessing():
    """Run the triplet embedding pipeline to completion.

    The pipeline is executed for its side effects (vector-store writes), so
    each yielded result is discarded.
    """
    pipeline_tasks = await get_triplet_embedding_tasks()

    async for _ in run_tasks_base(pipeline_tasks, user=await get_default_user(), data=[]):
        pass
|
||||||
Loading…
Add table
Reference in a new issue