Merge remote-tracking branch 'origin/dev' into feature/cog-2078-cognee-ui-refactor
|
|
@ -1,28 +0,0 @@
|
||||||
'''
|
|
||||||
Given a string, find the length of the longest substring without repeating characters.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
Given "abcabcbb", the answer is "abc", which the length is 3.
|
|
||||||
|
|
||||||
Given "bbbbb", the answer is "b", with the length of 1.
|
|
||||||
|
|
||||||
Given "pwwkew", the answer is "wke", with the length of 3. Note that the answer must be a substring, "pwke" is a subsequence and not a substring.
|
|
||||||
'''
|
|
||||||
|
|
||||||
class Solution(object):
|
|
||||||
def lengthOfLongestSubstring(self, s):
|
|
||||||
"""
|
|
||||||
:type s: str
|
|
||||||
:rtype: int
|
|
||||||
"""
|
|
||||||
mapSet = {}
|
|
||||||
start, result = 0, 0
|
|
||||||
|
|
||||||
for end in range(len(s)):
|
|
||||||
if s[end] in mapSet:
|
|
||||||
start = max(mapSet[s[end]], start)
|
|
||||||
result = max(result, end-start+1)
|
|
||||||
mapSet[s[end]] = end+1
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
@ -35,11 +35,11 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
|
||||||
<p align="center">
|
<p align="center">
|
||||||
🌐 Available Languages
|
🌐 Available Languages
|
||||||
:
|
:
|
||||||
<a href="community/README.pt.md">🇵🇹 Português</a>
|
<a href="assets/community/README.pt.md">🇵🇹 Português</a>
|
||||||
·
|
·
|
||||||
<a href="community/README.zh.md">🇨🇳 [中文]</a>
|
<a href="assets/community/README.zh.md">🇨🇳 [中文]</a>
|
||||||
·
|
·
|
||||||
<a href="community/README.ru.md">🇷🇺 Русский</a>
|
<a href="assets/community/README.ru.md">🇷🇺 Русский</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
Before Width: | Height: | Size: 262 KiB After Width: | Height: | Size: 262 KiB |
|
Before Width: | Height: | Size: 181 KiB After Width: | Height: | Size: 181 KiB |
|
Before Width: | Height: | Size: 603 KiB After Width: | Height: | Size: 603 KiB |
|
Before Width: | Height: | Size: 890 KiB After Width: | Height: | Size: 890 KiB |
|
|
@ -8,7 +8,7 @@ requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# For local cognee repo usage remove comment bellow and add absolute path to cognee
|
# For local cognee repo usage remove comment bellow and add absolute path to cognee
|
||||||
#"cognee[postgres,codegraph,gemini,huggingface] @ file:/Users/<username>/Desktop/cognee",
|
#"cognee[postgres,codegraph,gemini,huggingface] @ file:/Users/<username>/Desktop/cognee",
|
||||||
"cognee[postgres,codegraph,gemini,huggingface,docs]==0.1.40",
|
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.1.40",
|
||||||
"fastmcp>=1.0",
|
"fastmcp>=1.0",
|
||||||
"mcp==1.5.0",
|
"mcp==1.5.0",
|
||||||
"uv>=0.6.3",
|
"uv>=0.6.3",
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,9 @@
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Any
|
||||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
GraphSummaryCompletionRetriever,
|
GraphSummaryCompletionRetriever,
|
||||||
|
|
@ -8,8 +12,10 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
|
|
||||||
|
|
||||||
retriever_options: Dict[str, BaseRetriever] = {
|
retriever_options: Dict[str, Any] = {
|
||||||
"cognee_graph_completion": GraphCompletionRetriever,
|
"cognee_graph_completion": GraphCompletionRetriever,
|
||||||
|
"cognee_graph_completion_cot": GraphCompletionCotRetriever,
|
||||||
|
"cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever,
|
||||||
"cognee_completion": CompletionRetriever,
|
"cognee_completion": CompletionRetriever,
|
||||||
"graph_summary_completion": GraphSummaryCompletionRetriever,
|
"graph_summary_completion": GraphSummaryCompletionRetriever,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,7 @@ class EvalConfig(BaseSettings):
|
||||||
|
|
||||||
# Question answering params
|
# Question answering params
|
||||||
answering_questions: bool = True
|
answering_questions: bool = True
|
||||||
qa_engine: str = (
|
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||||
"cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Evaluation params
|
# Evaluation params
|
||||||
evaluating_answers: bool = True
|
evaluating_answers: bool = True
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,8 @@ def run():
|
||||||
# Streamlit Dashboard Application Logic
|
# Streamlit Dashboard Application Logic
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
def main():
|
def main():
|
||||||
|
metrics_volume.reload()
|
||||||
|
|
||||||
st.set_page_config(page_title="Metrics Dashboard", layout="wide")
|
st.set_page_config(page_title="Metrics Dashboard", layout="wide")
|
||||||
st.title("📊 Cognee Evaluations Dashboard")
|
st.title("📊 Cognee Evaluations Dashboard")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,
|
||||||
|
to collect the missing piece of information needed to fully answer the user’s original query.
|
||||||
|
Respond with the question only (no extra text, no punctuation beyond what’s needed).
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.
|
||||||
|
Think in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks
|
||||||
|
|
||||||
|
<QUERY>
|
||||||
|
`{{ query}}`
|
||||||
|
</QUERY>
|
||||||
|
|
||||||
|
<ANSWER>
|
||||||
|
`{{ answer }}`
|
||||||
|
</ANSWER>
|
||||||
|
|
||||||
|
<REASONING>
|
||||||
|
`{{ reasoning }}`
|
||||||
|
</REASONING>
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
You are a helpful agent who are allowed to use only the provided question answer and context.
|
||||||
|
I want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context.
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
<QUESTION>
|
||||||
|
`{{ query}}`
|
||||||
|
</QUESTION>
|
||||||
|
|
||||||
|
<ANSWER>
|
||||||
|
`{{ answer }}`
|
||||||
|
</ANSWER>
|
||||||
|
|
||||||
|
<CONTEXT>
|
||||||
|
`{{ context }}`
|
||||||
|
</CONTEXT>
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
top_k: Optional[int] = 5,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
user_prompt_path=user_prompt_path,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_completion(
|
||||||
|
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
||||||
|
) -> List[str]:
|
||||||
|
triplets = []
|
||||||
|
|
||||||
|
if context is None:
|
||||||
|
triplets += await self.get_triplets(query)
|
||||||
|
context = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
|
round_idx = 1
|
||||||
|
|
||||||
|
while round_idx <= context_extension_rounds:
|
||||||
|
prev_size = len(triplets)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Context extension: round {round_idx} - generating next graph locational query."
|
||||||
|
)
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
triplets += await self.get_triplets(completion)
|
||||||
|
triplets = list(set(triplets))
|
||||||
|
context = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
|
num_triplets = len(triplets)
|
||||||
|
|
||||||
|
if num_triplets == prev_size:
|
||||||
|
logger.info(
|
||||||
|
f"Context extension: round {round_idx} – no new triplets found; stopping early."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Context extension: round {round_idx} - "
|
||||||
|
f"number of unique retrieved triplets: {num_triplets}"
|
||||||
|
)
|
||||||
|
|
||||||
|
round_idx += 1
|
||||||
|
|
||||||
|
answer = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [answer]
|
||||||
84
cognee/modules/retrieval/graph_completion_cot_retriever.py
Normal file
|
|
@ -0,0 +1,84 @@
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
|
||||||
|
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||||
|
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||||
|
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||||
|
top_k: Optional[int] = 5,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
user_prompt_path=user_prompt_path,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
self.validation_system_prompt_path = validation_system_prompt_path
|
||||||
|
self.validation_user_prompt_path = validation_user_prompt_path
|
||||||
|
self.followup_system_prompt_path = followup_system_prompt_path
|
||||||
|
self.followup_user_prompt_path = followup_user_prompt_path
|
||||||
|
|
||||||
|
async def get_completion(
|
||||||
|
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||||
|
) -> List[str]:
|
||||||
|
llm_client = get_llm_client()
|
||||||
|
followup_question = ""
|
||||||
|
triplets = []
|
||||||
|
answer = [""]
|
||||||
|
|
||||||
|
for round_idx in range(max_iter + 1):
|
||||||
|
if round_idx == 0:
|
||||||
|
if context is None:
|
||||||
|
context = await self.get_context(query)
|
||||||
|
else:
|
||||||
|
triplets += await self.get_triplets(followup_question)
|
||||||
|
context = await self.resolve_edges_to_text(list(set(triplets)))
|
||||||
|
|
||||||
|
answer = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
|
||||||
|
if round_idx < max_iter:
|
||||||
|
valid_args = {"query": query, "answer": answer, "context": context}
|
||||||
|
valid_user_prompt = render_prompt(
|
||||||
|
filename=self.validation_user_prompt_path, context=valid_args
|
||||||
|
)
|
||||||
|
valid_system_prompt = read_query_prompt(
|
||||||
|
prompt_file_name=self.validation_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning = await llm_client.acreate_structured_output(
|
||||||
|
text_input=valid_user_prompt,
|
||||||
|
system_prompt=valid_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
|
||||||
|
followup_prompt = render_prompt(
|
||||||
|
filename=self.followup_user_prompt_path, context=followup_args
|
||||||
|
)
|
||||||
|
followup_system = read_query_prompt(
|
||||||
|
prompt_file_name=self.followup_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
followup_question = await llm_client.acreate_structured_output(
|
||||||
|
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [answer]
|
||||||
|
|
@ -11,6 +11,10 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionR
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
GraphSummaryCompletionRetriever,
|
GraphSummaryCompletionRetriever,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||||
|
|
@ -19,7 +23,7 @@ from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from ..operations import log_query, log_result
|
from cognee.modules.search.operations import log_query, log_result
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
|
|
@ -70,6 +74,14 @@ async def specific_search(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
|
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
).get_completion,
|
||||||
|
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
).get_completion,
|
||||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path, top_k=top_k
|
system_prompt_path=system_prompt_path, top_k=top_k
|
||||||
).get_completion,
|
).get_completion,
|
||||||
|
|
|
||||||
|
|
@ -11,3 +11,5 @@ class SearchType(Enum):
|
||||||
CODE = "CODE"
|
CODE = "CODE"
|
||||||
CYPHER = "CYPHER"
|
CYPHER = "CYPHER"
|
||||||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||||
|
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||||
|
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,185 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphCompletionRetriever:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_extension_context_simple(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionContextExtensionRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Canva?")
|
||||||
|
|
||||||
|
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||||
|
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_extension_context_complex(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class Car(DataPoint):
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
year: int
|
||||||
|
|
||||||
|
class Location(DataPoint):
|
||||||
|
country: str
|
||||||
|
city: str
|
||||||
|
|
||||||
|
class Home(DataPoint):
|
||||||
|
location: Location
|
||||||
|
rooms: int
|
||||||
|
sqm: int
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
owns: Optional[list[Union[Car, Home]]] = None
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
|
||||||
|
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||||
|
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||||
|
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person2.owns = [
|
||||||
|
Car(brand="Tesla", model="Model S", year=2021),
|
||||||
|
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||||
|
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
print(context)
|
||||||
|
|
||||||
|
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
retriever = GraphCompletionContextExtensionRetriever()
|
||||||
|
|
||||||
|
with pytest.raises(DatabaseNotCreatedError):
|
||||||
|
await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
assert context == "", "Context should be empty on an empty graph"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from asyncio import run
|
||||||
|
|
||||||
|
test = TestGraphCompletionRetriever()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await test.test_graph_completion_context_simple()
|
||||||
|
await test.test_graph_completion_context_complex()
|
||||||
|
await test.test_get_graph_completion_context_on_empty_graph()
|
||||||
|
|
||||||
|
run(main())
|
||||||
|
|
@ -0,0 +1,183 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphCompletionRetriever:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_cot_context_simple(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Canva?")
|
||||||
|
|
||||||
|
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||||
|
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_completion_cot_context_complex(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
|
class Car(DataPoint):
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
year: int
|
||||||
|
|
||||||
|
class Location(DataPoint):
|
||||||
|
country: str
|
||||||
|
city: str
|
||||||
|
|
||||||
|
class Home(DataPoint):
|
||||||
|
location: Location
|
||||||
|
rooms: int
|
||||||
|
sqm: int
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
owns: Optional[list[Union[Car, Home]]] = None
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
company2 = Company(name="Canva")
|
||||||
|
|
||||||
|
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||||
|
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||||
|
|
||||||
|
person2 = Person(name="Ike Loma", works_for=company1)
|
||||||
|
person2.owns = [
|
||||||
|
Car(brand="Tesla", model="Model S", year=2021),
|
||||||
|
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
person3 = Person(name="Jason Statham", works_for=company1)
|
||||||
|
|
||||||
|
person4 = Person(name="Mike Broski", works_for=company2)
|
||||||
|
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||||
|
|
||||||
|
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||||
|
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||||
|
|
||||||
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever(top_k=20)
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
print(context)
|
||||||
|
|
||||||
|
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
with pytest.raises(DatabaseNotCreatedError):
|
||||||
|
await retriever.get_context("Who works at Figma?")
|
||||||
|
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
|
assert context == "", "Context should be empty on an empty graph"
|
||||||
|
|
||||||
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
|
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||||
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
|
"Answer must contain only non-empty strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from asyncio import run
|
||||||
|
|
||||||
|
test = TestGraphCompletionRetriever()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await test.test_graph_completion_context_simple()
|
||||||
|
await test.test_graph_completion_context_complex()
|
||||||
|
await test.test_get_graph_completion_context_on_empty_graph()
|
||||||
|
|
||||||
|
run(main())
|
||||||
|
Before Width: | Height: | Size: 10 KiB After Width: | Height: | Size: 10 KiB |
|
|
@ -21,11 +21,11 @@ async def main():
|
||||||
# and description of these files
|
# and description of these files
|
||||||
mp3_file_path = os.path.join(
|
mp3_file_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent.parent.parent,
|
pathlib.Path(__file__).parent.parent.parent,
|
||||||
".data/multimedia/text_to_speech.mp3",
|
"examples/data/multimedia/text_to_speech.mp3",
|
||||||
)
|
)
|
||||||
png_file_path = os.path.join(
|
png_file_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent.parent.parent,
|
pathlib.Path(__file__).parent.parent.parent,
|
||||||
".data/multimedia/example.png",
|
"examples/data/multimedia/example.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the files, and make it available for cognify
|
# Add the files, and make it available for cognify
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"cell_type": "code",
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import pathlib\n",
|
"import pathlib\n",
|
||||||
|
|
@ -34,12 +34,12 @@
|
||||||
"mp3_file_path = os.path.join(\n",
|
"mp3_file_path = os.path.join(\n",
|
||||||
" os.path.abspath(\"\"),\n",
|
" os.path.abspath(\"\"),\n",
|
||||||
" \"../\",\n",
|
" \"../\",\n",
|
||||||
" \".data/multimedia/text_to_speech.mp3\",\n",
|
" \"examples/data/multimedia/text_to_speech.mp3\",\n",
|
||||||
")\n",
|
")\n",
|
||||||
"png_file_path = os.path.join(\n",
|
"png_file_path = os.path.join(\n",
|
||||||
" os.path.abspath(\"\"),\n",
|
" os.path.abspath(\"\"),\n",
|
||||||
" \"../\",\n",
|
" \"../\",\n",
|
||||||
" \".data/multimedia/example.png\",\n",
|
" \"examples/data/multimedia/example.png\",\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
import statistics
|
|
||||||
import time
|
|
||||||
import tracemalloc
|
|
||||||
from typing import Any, Callable, Dict
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Benchmark a function for memory usage and computational performance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Function to benchmark
|
|
||||||
*args: Arguments to pass to the function
|
|
||||||
num_runs: Number of times to run the benchmark
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing benchmark metrics
|
|
||||||
"""
|
|
||||||
execution_times = []
|
|
||||||
peak_memory_usages = []
|
|
||||||
cpu_percentages = []
|
|
||||||
|
|
||||||
process = psutil.Process()
|
|
||||||
|
|
||||||
for _ in range(num_runs):
|
|
||||||
# Start memory tracking
|
|
||||||
tracemalloc.start()
|
|
||||||
|
|
||||||
# Measure execution time and CPU usage
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
start_cpu_time = process.cpu_times()
|
|
||||||
|
|
||||||
end_cpu_time = process.cpu_times()
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
|
|
||||||
# Calculate metrics
|
|
||||||
execution_time = end_time - start_time
|
|
||||||
cpu_time = (end_cpu_time.user + end_cpu_time.system) - (
|
|
||||||
start_cpu_time.user + start_cpu_time.system
|
|
||||||
)
|
|
||||||
current, peak = tracemalloc.get_traced_memory()
|
|
||||||
|
|
||||||
# Store results
|
|
||||||
execution_times.append(execution_time)
|
|
||||||
peak_memory_usages.append(peak / 1024 / 1024) # Convert to MB
|
|
||||||
cpu_percentages.append((cpu_time / execution_time) * 100)
|
|
||||||
|
|
||||||
tracemalloc.stop()
|
|
||||||
|
|
||||||
analysis = {
|
|
||||||
"mean_execution_time": statistics.mean(execution_times),
|
|
||||||
"mean_peak_memory_mb": statistics.mean(peak_memory_usages),
|
|
||||||
"mean_cpu_percent": statistics.mean(cpu_percentages),
|
|
||||||
"num_runs": num_runs,
|
|
||||||
}
|
|
||||||
|
|
||||||
if num_runs > 1:
|
|
||||||
analysis["std_execution_time"] = statistics.stdev(execution_times)
|
|
||||||
|
|
||||||
return analysis
|
|
||||||
|
|
@ -1,63 +0,0 @@
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from .benchmark_function import benchmark_function
|
|
||||||
|
|
||||||
from cognee.modules.graph.utils import get_graph_from_model
|
|
||||||
from cognee.tests.unit.interfaces.graph.util import (
|
|
||||||
PERSON_NAMES,
|
|
||||||
create_organization_recursive,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Example usage:
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Benchmark graph model with configurable recursive depth"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--recursive-depth",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Recursive depth for graph generation (default: 3)",
|
|
||||||
)
|
|
||||||
parser.add_argument("--runs", type=int, default=5, help="Number of benchmark runs (default: 5)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
society = create_organization_recursive(
|
|
||||||
"society", "Society", PERSON_NAMES, args.recursive_depth
|
|
||||||
)
|
|
||||||
added_nodes = {}
|
|
||||||
added_edges = {}
|
|
||||||
visited_properties = {}
|
|
||||||
nodes, edges = asyncio.run(
|
|
||||||
get_graph_from_model(
|
|
||||||
society,
|
|
||||||
added_nodes=added_nodes,
|
|
||||||
added_edges=added_edges,
|
|
||||||
visited_properties=visited_properties,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_graph_from_model_sync(model):
|
|
||||||
added_nodes = {}
|
|
||||||
added_edges = {}
|
|
||||||
visited_properties = {}
|
|
||||||
|
|
||||||
return asyncio.run(
|
|
||||||
get_graph_from_model(
|
|
||||||
model,
|
|
||||||
added_nodes=added_nodes,
|
|
||||||
added_edges=added_edges,
|
|
||||||
visited_properties=visited_properties,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
|
|
||||||
print("\nBenchmark Results:")
|
|
||||||
print(f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}")
|
|
||||||
print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB")
|
|
||||||
print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%")
|
|
||||||
print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds")
|
|
||||||
|
|
||||||
if "std_execution_time" in results:
|
|
||||||
print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds")
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
|
||||||
|
|
||||||
|
|
||||||
class DummyEmbeddingEngine(EmbeddingEngine):
|
|
||||||
async def embed_text(self, text: list[str]) -> list[list[float]]:
|
|
||||||
return list(list(np.random.randn(3072)))
|
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
|
||||||
return 3072
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
from typing import Type
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import spacy
|
|
||||||
import textacy
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
|
||||||
from cognee.shared.data_models import Edge, KnowledgeGraph, Node, SummarizedContent
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLLMAdapter(LLMInterface):
|
|
||||||
nlp = spacy.load("en_core_web_sm")
|
|
||||||
|
|
||||||
async def acreate_structured_output(
|
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
||||||
) -> BaseModel:
|
|
||||||
if str(response_model) == "<class 'cognee.shared.data_models.SummarizedContent'>":
|
|
||||||
return dummy_summarize_content(text_input)
|
|
||||||
elif str(response_model) == "<class 'cognee.shared.data_models.KnowledgeGraph'>":
|
|
||||||
return dummy_extract_knowledge_graph(text_input, self.nlp)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"Currently dummy acreate_structured_input is only implemented for SummarizedContent and KnowledgeGraph"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_extract_knowledge_graph(text, nlp):
|
|
||||||
doc = nlp(text)
|
|
||||||
triples = list(textacy.extract.subject_verb_object_triples(doc))
|
|
||||||
|
|
||||||
nodes = {}
|
|
||||||
edges = []
|
|
||||||
for triple in triples:
|
|
||||||
source = "_".join([str(e) for e in triple.subject])
|
|
||||||
target = "_".join([str(e) for e in triple.object])
|
|
||||||
nodes[source] = nodes.get(
|
|
||||||
source, Node(id=str(uuid4()), name=source, type="object", description="")
|
|
||||||
)
|
|
||||||
nodes[target] = nodes.get(
|
|
||||||
target, Node(id=str(uuid4()), name=target, type="object", description="")
|
|
||||||
)
|
|
||||||
edge_type = "_".join([str(e) for e in triple.verb])
|
|
||||||
edges.append(
|
|
||||||
Edge(
|
|
||||||
source_node_id=nodes[source].id,
|
|
||||||
target_node_id=nodes[target].id,
|
|
||||||
relationship_name=edge_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return KnowledgeGraph(nodes=list(nodes.values()), edges=edges)
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_summarize_content(text):
|
|
||||||
words = [(word, len(word)) for word in set(text.split(" "))]
|
|
||||||
words = sorted(words, key=lambda x: x[1], reverse=True)
|
|
||||||
summary = " ".join([word for word, _ in words[:50]])
|
|
||||||
description = " ".join([word for word, _ in words[:10]])
|
|
||||||
return SummarizedContent(summary=summary, description=description)
|
|
||||||