Compare commits
24 commits
main
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9590ef760 | ||
|
|
b4b55b820d | ||
|
|
d7d626698d | ||
|
|
aecdff0503 | ||
|
|
4e373cfee7 | ||
|
|
f1e254f357 | ||
|
|
8f5d5b9ac2 | ||
|
|
ce14a441af | ||
|
|
f825732eb2 | ||
|
|
ecdf624bda | ||
|
|
1267f6c1e7 | ||
|
|
d8fde4c527 | ||
|
|
96d1dd772c | ||
|
|
cc52df94b7 | ||
|
|
ad9abb8b76 | ||
|
|
5d4f82fdd4 | ||
|
|
8aae9f8dd8 | ||
|
|
cd813c5732 | ||
|
|
7456567597 | ||
|
|
b29ab72c50 | ||
|
|
5cbdbf3abf | ||
|
|
cc4fab9e75 | ||
|
|
0c1e515c8f | ||
|
|
fe83a25576 |
37 changed files with 2238 additions and 28 deletions
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Union
|
||||
from typing import Union, Optional, Type, List
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.exceptions import UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
|
||||
|
|
@ -13,6 +15,8 @@ async def search(
|
|||
datasets: Union[list[str], str, None] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
|
|
@ -28,6 +32,8 @@ async def search(
|
|||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
0
cognee/complex_demos/__init__.py
Normal file
0
cognee/complex_demos/__init__.py
Normal file
42
cognee/complex_demos/crewai_demo/README
Normal file
42
cognee/complex_demos/crewai_demo/README
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
README:
|
||||
|
||||
This is a demo project to showcase and test how cognee and CrewAI can work together:
|
||||
|
||||
Short description:
|
||||
|
||||
We simulate the hiring process for a technical role. These are the steps of the pipeline:
|
||||
|
||||
1. First we ingest github data including:
|
||||
-commits, comments and other soft skill related information for each of the candidates.
|
||||
-source code and other technical skill related information for each of the candidates.
|
||||
|
||||
2. We hire 3 agents to make the decision using cognee's memory engine
|
||||
|
||||
1 - HR Expert Agent focusing on soft skills:
|
||||
- Analyzes the communication skills, clarity, engagement a kindness based on the commits, comments and github communication of the candidates.
|
||||
-To analyze the soft skills of the candidates, the agent performs multiple searches using cognee.search
|
||||
-The subgraph that the agent can use is limited to the "soft" nodeset subgraph
|
||||
- Scores each candidate from 0 to 1 and gives reasoning
|
||||
|
||||
2 - Technical Expert Agent focusing on technical skills:
|
||||
- Analyzes strictly code related and technical skills based on github commits and pull requests of the candidates.
|
||||
- To analyze the technical skills of the candidates, the agent performs multiple searches using cognee.search
|
||||
- The subgraph that the agent can use is limited to the "techical" nodeset subgraph
|
||||
- Scores each candidate from 0 to 1 and gives reasoning
|
||||
|
||||
3 - CEO/CTO agent who makes the final decision:
|
||||
- Given the output of the HR expert and Technical expert agents, the decision maker agent makes the final decision about the hiring procedure.
|
||||
- The agent will choose the best candidate to hire, and will give reasoning for each of the candidates (why hire/no_hire).
|
||||
|
||||
|
||||
The following tools were implemented:
|
||||
- Cognee build: cognifies the added data (Preliminary task, therefore it is not performed by agents.)
|
||||
- Cognee search: searches the cognee memory, limiting the subgraph using the nodeset subgraph retriever (Used by many agents)
|
||||
- In the case of technical and soft skills agents the tool gets instantiated with the restricted nodeset search capability
|
||||
|
||||
|
||||
The three agents are working together to simulate a hiring process, evaluating soft and technical skills, while the CEO/CTO agent
|
||||
makes the final decision (HIRE/NOHIRE) based on the outputs of the evaluation agents.
|
||||
|
||||
|
||||
Works from IDE and not from CLI for now.
|
||||
0
cognee/complex_demos/crewai_demo/__init__.py
Normal file
0
cognee/complex_demos/crewai_demo/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
User name is John Doe.
|
||||
User is an AI Engineer.
|
||||
User is interested in AI Agents.
|
||||
User is based in San Francisco, California.
|
||||
19
cognee/complex_demos/crewai_demo/pyproject.toml
Normal file
19
cognee/complex_demos/crewai_demo/pyproject.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
[project]
|
||||
name = "crewai_demo"
|
||||
version = "0.1.0"
|
||||
description = "Cognee crewAI demo"
|
||||
authors = [{ name = "Laszlo Hajdu", email = "laszlo@topoteretes.com" }]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.114.0,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
run_crew = "association_layer_demo.main:run"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.crewai]
|
||||
type = "crew"
|
||||
0
cognee/complex_demos/crewai_demo/src/__init__.py
Normal file
0
cognee/complex_demos/crewai_demo/src/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from .github_dev_profile import GitHubDevProfile
|
||||
from .github_dev_comments import GitHubDevComments
|
||||
from .github_dev_commits import GitHubDevCommits
|
||||
|
||||
__all__ = ["GitHubDevProfile", "GitHubDevComments", "GitHubDevCommits"]
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
soft_skills_expert_agent:
|
||||
role: >
|
||||
Focused on communication, collaboration, and documentation excellence.
|
||||
goal: >
|
||||
Evaluate README clarity, issue discussions, and community engagement to score
|
||||
communication clarity and open-source culture participation.
|
||||
backstory: >
|
||||
You are an active OSS community manager who values clear writing, inclusive
|
||||
discussion, and strong documentation. You look for evidence of empathy,
|
||||
responsiveness, and collaborative spirit.
|
||||
|
||||
technical_expert_agent:
|
||||
role: >
|
||||
Specialized in evaluating technical skills and code quality.
|
||||
goal: >
|
||||
Analyze repository metadata and commit histories to score coding diversity,
|
||||
depth of contributions, and commit quality.
|
||||
backstory: >
|
||||
You are a seasoned software architect and open-source maintainer. You deeply
|
||||
understand python code structure, language ecosystems, and best practices.
|
||||
Your mission is to objectively rate each candidate’s technical excellence.
|
||||
|
||||
decision_maker_agent:
|
||||
role: >
|
||||
CTO/CEO-level decision maker who integrates expert feedback.
|
||||
goal: >
|
||||
Read the technical and soft-skills evaluations and decide whether to hire
|
||||
each candidate, justifying the decision.
|
||||
backstory: >
|
||||
You are the company’s CTO. You balance technical requirements, team culture,
|
||||
and long-term vision. You weigh both skill scores and communication ratings
|
||||
to make a final hire/no-hire call.
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
soft_skills_assessment_applicant1_task:
|
||||
description: >
|
||||
Search cognee for comments authored by '{applicant_1}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Evaluate their communication clarity, community engagement, and kindness.
|
||||
Ask multiple questions if needed to uncover diverse interactions.
|
||||
Return a complete and reasoned assessment of their soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Sarah Jennings
|
||||
|
||||
Output:
|
||||
- Name: Sarah Jennings
|
||||
- communication_clarity: 0.92
|
||||
- community_engagement: 0.88
|
||||
- kindness: 0.95
|
||||
- reasoning: >
|
||||
Sarah consistently communicates with clarity and structure. In several threads, her responses broke down complex issues into actionable steps,
|
||||
showing strong explanatory skills. She uses inclusive language like “let’s”, “we should”, and frequently thanks others for their input,
|
||||
which indicates a high degree of kindness. Sarah also initiates or joins collaborative threads, offering feedback or connecting people with
|
||||
relevant documentation. Her tone is encouraging and non-defensive, even when correcting misunderstandings. These patterns were observed across
|
||||
over 8 threads involving different team members over a 3-week span.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_1}
|
||||
- communication_clarity (0–1)
|
||||
- community_engagement (0–1)
|
||||
- kindness (0–1)
|
||||
- reasoning: (string)
|
||||
agent: soft_skills_expert_agent
|
||||
|
||||
soft_skills_assessment_applicant2_task:
|
||||
description: >
|
||||
Search cognee for comments authored by '{applicant_2}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Evaluate their communication clarity, community engagement, and kindness.
|
||||
Ask multiple questions if needed to uncover diverse interactions.
|
||||
Return a complete and reasoned assessment of their soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Sarah Jennings
|
||||
|
||||
Output:
|
||||
- Name: Sarah Jennings
|
||||
- communication_clarity: 0.92
|
||||
- community_engagement: 0.88
|
||||
- kindness: 0.95
|
||||
- reasoning: >
|
||||
Sarah consistently communicates with clarity and structure. In several threads, her responses broke down complex issues into actionable steps,
|
||||
showing strong explanatory skills. She uses inclusive language like “let’s”, “we should”, and frequently thanks others for their input,
|
||||
which indicates a high degree of kindness. Sarah also initiates or joins collaborative threads, offering feedback or connecting people with
|
||||
relevant documentation. Her tone is encouraging and non-defensive, even when correcting misunderstandings. These patterns were observed across
|
||||
over 8 threads involving different team members over a 3-week span.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_2}
|
||||
- communication_clarity (0–1)
|
||||
- community_engagement (0–1)
|
||||
- kindness (0–1)
|
||||
- reasoning: (string)
|
||||
agent: soft_skills_expert_agent
|
||||
|
||||
technical_assessment_applicant1_task:
|
||||
description: >
|
||||
Analyze the repository metadata and commit history associated with '{applicant_1}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Score their code_diversity, depth_of_contribution, and commit_quality.
|
||||
Base your assessment strictly on technical input—ignore soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Daniel Murphy
|
||||
|
||||
Output:
|
||||
- Name: Daniel Murphy
|
||||
- code_diversity: 0.87
|
||||
- depth_of_contribution: 0.91
|
||||
- commit_quality: 0.83
|
||||
- reasoning: >
|
||||
Daniel contributed to multiple areas of the codebase including frontend UI components, backend API endpoints, test coverage,
|
||||
and CI/CD configuration. His commit history spans over 6 weeks with consistent activity and includes thoughtful messages
|
||||
(e.g., “refactor auth flow to support multi-tenant login” or “add unit tests for pricing logic edge cases”).
|
||||
His pull requests often include both implementation and tests, showing technical completeness.
|
||||
Several commits show iterative problem-solving and cleanup after peer feedback, indicating thoughtful collaboration
|
||||
and improvement over time.
|
||||
expected_output: >
|
||||
- Name: {applicant_1}
|
||||
- code_diversity (0–1)
|
||||
- depth_of_contribution (0–1)
|
||||
- commit_quality (0–1)
|
||||
- reasoning: (string)
|
||||
agent: technical_expert_agent
|
||||
|
||||
technical_assessment_applicant2_task:
|
||||
description: >
|
||||
Analyze the repository metadata and commit history associated with '{applicant_2}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Score their code_diversity, depth_of_contribution, and commit_quality.
|
||||
Base your assessment strictly on technical input—ignore soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Daniel Murphy
|
||||
|
||||
Output:
|
||||
- Name: Daniel Murphy
|
||||
- code_diversity: 0.87
|
||||
- depth_of_contribution: 0.91
|
||||
- commit_quality: 0.83
|
||||
- reasoning: >
|
||||
Daniel contributed to multiple areas of the codebase including frontend UI components, backend API endpoints, test coverage,
|
||||
and CI/CD configuration. His commit history spans over 6 weeks with consistent activity and includes thoughtful messages
|
||||
(e.g., “refactor auth flow to support multi-tenant login” or “add unit tests for pricing logic edge cases”).
|
||||
His pull requests often include both implementation and tests, showing technical completeness.
|
||||
Several commits show iterative problem-solving and cleanup after peer feedback, indicating thoughtful collaboration
|
||||
and improvement over time.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_2}
|
||||
- code_diversity (0–1)
|
||||
- depth_of_contribution (0–1)
|
||||
- commit_quality (0–1)
|
||||
- reasoning: (string)
|
||||
agent: technical_expert_agent
|
||||
|
||||
hiring_decision_task:
|
||||
description: >
|
||||
Review the technical and soft skill assessment task outputs for candidates: -{applicant_1} and -{applicant_2},
|
||||
then decide HIRE or NO_HIRE for each candidate with a detailed reasoning.
|
||||
The people to evaluate are:
|
||||
-{applicant_1}
|
||||
-{applicant_2}
|
||||
We have to hire one of them.
|
||||
|
||||
Prepare the final output for the ingest_hiring_decision_task.
|
||||
|
||||
|
||||
expected_output: >
|
||||
A string strictly containing the following for each person:
|
||||
- Person
|
||||
- decision: "HIRE" or "NO_HIRE",
|
||||
- reasoning: (string)
|
||||
agent: decision_maker_agent
|
||||
|
||||
ingest_hiring_decision_task:
|
||||
description: >
|
||||
Take the final hiring decision from the hiring_decision_task report and ingest it into Cognee using the "ingest_report_to_cognee" tool.
|
||||
Do not re-evaluate—just save the result using the tool you have.
|
||||
expected_output: >
|
||||
- confirmation: string message confirming successful ingestion into Cognee
|
||||
agent: decision_maker_agent
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class CogneeBuild(BaseTool):
|
||||
name: str = "Cognee Build"
|
||||
description: str = "Creates a memory and builds a knowledge graph using cognee."
|
||||
|
||||
def _run(self, inputs) -> str:
|
||||
import cognee
|
||||
import asyncio
|
||||
|
||||
async def main():
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
for meta in inputs.values():
|
||||
text = meta["file_content"]
|
||||
node_set = meta["nodeset"]
|
||||
await cognee.add(text, node_set=node_set)
|
||||
|
||||
await cognee.cognify()
|
||||
|
||||
return "Knowledge Graph is done."
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
results = loop.run_until_complete(main())
|
||||
return results
|
||||
except Exception as e:
|
||||
return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from crewai.tools import BaseTool
|
||||
from typing import Type, List
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
import asyncio
|
||||
|
||||
|
||||
class CogneeIngestionInput(BaseModel):
|
||||
text: str = Field(
|
||||
"",
|
||||
description="The text of the report The format you should follow is {'text': 'your report'}",
|
||||
)
|
||||
|
||||
|
||||
class CogneeIngestion(BaseTool):
|
||||
name: str = "ingest_report_to_cognee"
|
||||
description: str = "This tool can be used to ingest the final hiring report into cognee"
|
||||
args_schema: Type[BaseModel] = CogneeIngestionInput
|
||||
_nodeset_name: str
|
||||
|
||||
def __init__(self, nodeset_name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._nodeset_name = nodeset_name
|
||||
|
||||
def _run(self, text: str) -> str:
|
||||
import cognee
|
||||
from secrets import choice
|
||||
from string import ascii_letters, digits
|
||||
|
||||
async def main():
|
||||
try:
|
||||
hash6 = "".join(choice(ascii_letters + digits) for _ in range(6))
|
||||
await cognee.add(text, node_set=[self._nodeset_name], dataset_name=hash6)
|
||||
await cognee.cognify(datasets=hash6)
|
||||
|
||||
return "Report ingested successfully into Cognee memory."
|
||||
except Exception as e:
|
||||
return f"Error during ingestion: {str(e)}"
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(main())
|
||||
except Exception as e:
|
||||
return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
from crewai.tools import BaseTool
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
|
||||
|
||||
class CogneeSearchInput(BaseModel):
|
||||
query: str = Field(
|
||||
"",
|
||||
description="The natural language question to ask the memory engine."
|
||||
"The format you should follow is {'query': 'your query'}",
|
||||
)
|
||||
|
||||
|
||||
class CogneeSearch(BaseTool):
|
||||
name: str = "search_from_cognee"
|
||||
description: str = (
|
||||
"Use this tool to search the Cognee memory graph. "
|
||||
"Provide a natural language query that describes the information you want to retrieve, "
|
||||
"such as comments authored or files changes by a specific person."
|
||||
)
|
||||
args_schema: Type[BaseModel] = CogneeSearchInput
|
||||
_nodeset_name: str = PrivateAttr()
|
||||
|
||||
def __init__(self, nodeset_name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._nodeset_name = nodeset_name
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
import asyncio
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
||||
async def main():
|
||||
try:
|
||||
print(query)
|
||||
|
||||
search_results = await GraphCompletionRetriever(
|
||||
top_k=5,
|
||||
node_type=NodeSet,
|
||||
node_name=[self._nodeset_name],
|
||||
).get_context(query=query)
|
||||
|
||||
return search_results
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
search_results = loop.run_until_complete(main())
|
||||
return search_results
|
||||
except Exception as e:
|
||||
return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from crewai.tools import BaseTool
|
||||
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from ..github_ingest_datapoints import cognify_github_data_from_username
|
||||
|
||||
|
||||
class GithubIngestion(BaseTool):
|
||||
name: str = "Github graph builder"
|
||||
description: str = "Ingests the github graph of a person into Cognee"
|
||||
|
||||
def _run(self, applicant_1, applicant_2) -> str:
|
||||
import asyncio
|
||||
import cognee
|
||||
import os
|
||||
from cognee.low_level import DataPoint, setup as cognee_setup
|
||||
|
||||
async def main():
|
||||
try:
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await cognee_setup()
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
|
||||
await cognify_github_data_from_username(applicant_1, token)
|
||||
await cognify_github_data_from_username(applicant_2, token)
|
||||
|
||||
return "Github ingestion finished"
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
results = loop.run_until_complete(main())
|
||||
return results
|
||||
except Exception as e:
|
||||
return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import requests
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
GITHUB_API_URL = "https://api.github.com/graphql"
|
||||
|
||||
logger = get_logger("github_comments")
|
||||
|
||||
|
||||
class GitHubCommentBase(ABC):
|
||||
"""Base class for GitHub comment providers."""
|
||||
|
||||
def __init__(self, token, username, limit=10):
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.limit = limit
|
||||
|
||||
def _run_query(self, query: str) -> dict:
|
||||
"""Executes a GraphQL query against GitHub's API."""
|
||||
headers = {"Authorization": f"Bearer {self.token}"}
|
||||
response = requests.post(GITHUB_API_URL, json={"query": query}, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Query failed: {response.status_code} - {response.text}")
|
||||
return response.json()["data"]
|
||||
|
||||
def get_comments(self):
|
||||
"""Template method that orchestrates the comment retrieval process."""
|
||||
try:
|
||||
query = self._build_query()
|
||||
data = self._run_query(query)
|
||||
raw_comments = self._extract_comments(data)
|
||||
return [self._format_comment(item) for item in raw_comments[: self.limit]]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching {self._get_comment_type()} comments: {e}")
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def _build_query(self) -> str:
|
||||
"""Builds the GraphQL query string."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _extract_comments(self, data) -> list:
|
||||
"""Extracts the comment data from the GraphQL response."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _format_comment(self, item) -> dict:
|
||||
"""Formats a single comment."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_comment_type(self) -> str:
|
||||
"""Returns the type of comment this provider handles."""
|
||||
pass
|
||||
|
|
@ -0,0 +1,298 @@
|
|||
from datetime import datetime, timedelta
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_base import (
|
||||
GitHubCommentBase,
|
||||
logger,
|
||||
)
|
||||
|
||||
|
||||
class IssueCommentsProvider(GitHubCommentBase):
|
||||
"""Provider for GitHub issue comments."""
|
||||
|
||||
QUERY_TEMPLATE = """
|
||||
{{
|
||||
user(login: "{username}") {{
|
||||
issueComments(first: {limit}, orderBy: {{field: UPDATED_AT, direction: DESC}}) {{
|
||||
nodes {{
|
||||
body
|
||||
createdAt
|
||||
updatedAt
|
||||
url
|
||||
issue {{
|
||||
number
|
||||
title
|
||||
url
|
||||
repository {{
|
||||
nameWithOwner
|
||||
}}
|
||||
state
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
def _build_query(self) -> str:
|
||||
"""Builds the GraphQL query for issue comments."""
|
||||
return self.QUERY_TEMPLATE.format(username=self.username, limit=self.limit)
|
||||
|
||||
def _extract_comments(self, data) -> list:
|
||||
"""Extracts issue comments from the GraphQL response."""
|
||||
return data["user"]["issueComments"]["nodes"]
|
||||
|
||||
def _format_comment(self, comment) -> dict:
|
||||
"""Formats an issue comment from GraphQL."""
|
||||
comment_id = comment["url"].split("/")[-1] if comment["url"] else None
|
||||
|
||||
return {
|
||||
"repo": comment["issue"]["repository"]["nameWithOwner"],
|
||||
"issue_number": comment["issue"]["number"],
|
||||
"comment_id": comment_id,
|
||||
"body": comment["body"],
|
||||
"text": comment["body"],
|
||||
"created_at": comment["createdAt"],
|
||||
"updated_at": comment["updatedAt"],
|
||||
"html_url": comment["url"],
|
||||
"issue_url": comment["issue"]["url"],
|
||||
"author_association": "COMMENTER",
|
||||
"issue_title": comment["issue"]["title"],
|
||||
"issue_state": comment["issue"]["state"],
|
||||
"login": self.username,
|
||||
"type": "issue_comment",
|
||||
}
|
||||
|
||||
def _get_comment_type(self) -> str:
|
||||
"""Returns the comment type for error messages."""
|
||||
return "issue"
|
||||
|
||||
|
||||
class PrReviewsProvider(GitHubCommentBase):
|
||||
"""Provider for GitHub PR reviews."""
|
||||
|
||||
QUERY_TEMPLATE = """
|
||||
{{
|
||||
user(login: "{username}") {{
|
||||
contributionsCollection {{
|
||||
pullRequestReviewContributions(first: {fetch_limit}) {{
|
||||
nodes {{
|
||||
pullRequestReview {{
|
||||
body
|
||||
createdAt
|
||||
updatedAt
|
||||
url
|
||||
state
|
||||
pullRequest {{
|
||||
number
|
||||
title
|
||||
url
|
||||
repository {{
|
||||
nameWithOwner
|
||||
}}
|
||||
state
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
def __init__(self, token, username, limit=10, fetch_limit=None):
|
||||
"""Initialize with token, username, and optional limits."""
|
||||
super().__init__(token, username, limit)
|
||||
self.fetch_limit = fetch_limit if fetch_limit is not None else 10 * limit
|
||||
|
||||
def _build_query(self) -> str:
|
||||
"""Builds the GraphQL query for PR reviews."""
|
||||
return self.QUERY_TEMPLATE.format(username=self.username, fetch_limit=self.fetch_limit)
|
||||
|
||||
def _extract_comments(self, data) -> list:
|
||||
"""Extracts PR reviews from the GraphQL response."""
|
||||
contributions = data["user"]["contributionsCollection"]["pullRequestReviewContributions"][
|
||||
"nodes"
|
||||
]
|
||||
return [
|
||||
node["pullRequestReview"] for node in contributions if node["pullRequestReview"]["body"]
|
||||
]
|
||||
|
||||
def _format_comment(self, review) -> dict:
|
||||
"""Formats a PR review from GraphQL."""
|
||||
review_id = review["url"].split("/")[-1] if review["url"] else None
|
||||
|
||||
return {
|
||||
"repo": review["pullRequest"]["repository"]["nameWithOwner"],
|
||||
"issue_number": review["pullRequest"]["number"],
|
||||
"comment_id": review_id,
|
||||
"body": review["body"],
|
||||
"text": review["body"],
|
||||
"created_at": review["createdAt"],
|
||||
"updated_at": review["updatedAt"],
|
||||
"html_url": review["url"],
|
||||
"issue_url": review["pullRequest"]["url"],
|
||||
"author_association": "COMMENTER",
|
||||
"issue_title": review["pullRequest"]["title"],
|
||||
"issue_state": review["pullRequest"]["state"],
|
||||
"login": self.username,
|
||||
"review_state": review["state"],
|
||||
"type": "pr_review",
|
||||
}
|
||||
|
||||
def _get_comment_type(self) -> str:
|
||||
"""Returns the comment type for error messages."""
|
||||
return "PR review"
|
||||
|
||||
|
||||
class PrReviewCommentsProvider(GitHubCommentBase):
|
||||
"""Provider for GitHub PR review comments (inline code comments)."""
|
||||
|
||||
PR_CONTRIBUTIONS_TEMPLATE = """
|
||||
{{
|
||||
user(login: "{username}") {{
|
||||
contributionsCollection {{
|
||||
pullRequestReviewContributions(first: {fetch_limit}) {{
|
||||
nodes {{
|
||||
pullRequestReview {{
|
||||
pullRequest {{
|
||||
number
|
||||
title
|
||||
url
|
||||
repository {{
|
||||
nameWithOwner
|
||||
}}
|
||||
state
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
PR_COMMENTS_TEMPLATE = """
|
||||
{{
|
||||
repository(owner: "{owner}", name: "{repo}") {{
|
||||
pullRequest(number: {pr_number}) {{
|
||||
reviews(first: {reviews_limit}, author: "{username}") {{
|
||||
nodes {{
|
||||
comments(first: {comments_limit}) {{
|
||||
nodes {{
|
||||
body
|
||||
createdAt
|
||||
updatedAt
|
||||
url
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token,
|
||||
username,
|
||||
limit=10,
|
||||
fetch_limit=None,
|
||||
reviews_limit=None,
|
||||
comments_limit=None,
|
||||
pr_limit=None,
|
||||
):
|
||||
"""Initialize with token, username, and optional limits."""
|
||||
super().__init__(token, username, limit)
|
||||
self.fetch_limit = fetch_limit if fetch_limit is not None else 4 * limit
|
||||
self.reviews_limit = reviews_limit if reviews_limit is not None else 2 * limit
|
||||
self.comments_limit = comments_limit if comments_limit is not None else 3 * limit
|
||||
self.pr_limit = pr_limit if pr_limit is not None else 2 * limit
|
||||
|
||||
def _build_query(self) -> str:
|
||||
"""Builds the GraphQL query for PR contributions."""
|
||||
return self.PR_CONTRIBUTIONS_TEMPLATE.format(
|
||||
username=self.username, fetch_limit=self.fetch_limit
|
||||
)
|
||||
|
||||
def _extract_comments(self, data) -> list:
|
||||
"""Extracts PR review comments using a two-step approach."""
|
||||
prs = self._get_reviewed_prs(data)
|
||||
return self._fetch_comments_for_prs(prs)
|
||||
|
||||
def _get_reviewed_prs(self, data) -> list:
|
||||
"""Gets a deduplicated list of PRs the user has reviewed."""
|
||||
contributions = data["user"]["contributionsCollection"]["pullRequestReviewContributions"][
|
||||
"nodes"
|
||||
]
|
||||
unique_prs = []
|
||||
|
||||
for node in contributions:
|
||||
pr = node["pullRequestReview"]["pullRequest"]
|
||||
if not any(existing_pr["url"] == pr["url"] for existing_pr in unique_prs):
|
||||
unique_prs.append(pr)
|
||||
|
||||
return unique_prs[: min(self.pr_limit, len(unique_prs))]
|
||||
|
||||
def _fetch_comments_for_prs(self, prs) -> list:
|
||||
"""Fetches inline comments for each PR in the list."""
|
||||
all_comments = []
|
||||
|
||||
for pr in prs:
|
||||
comments = self._get_comments_for_pr(pr)
|
||||
all_comments.extend(comments)
|
||||
|
||||
return all_comments
|
||||
|
||||
def _get_comments_for_pr(self, pr) -> list:
|
||||
"""Fetches the inline comments for a specific PR."""
|
||||
owner, repo = pr["repository"]["nameWithOwner"].split("/")
|
||||
|
||||
pr_query = self.PR_COMMENTS_TEMPLATE.format(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
pr_number=pr["number"],
|
||||
username=self.username,
|
||||
reviews_limit=self.reviews_limit,
|
||||
comments_limit=self.comments_limit,
|
||||
)
|
||||
|
||||
try:
|
||||
pr_comments = []
|
||||
pr_data = self._run_query(pr_query)
|
||||
reviews = pr_data["repository"]["pullRequest"]["reviews"]["nodes"]
|
||||
|
||||
for review in reviews:
|
||||
for comment in review["comments"]["nodes"]:
|
||||
comment["_pr_data"] = pr
|
||||
pr_comments.append(comment)
|
||||
|
||||
return pr_comments
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching comments for PR #{pr['number']}: {e}")
|
||||
return []
|
||||
|
||||
def _format_comment(self, comment) -> dict:
|
||||
"""Formats a PR review comment from GraphQL."""
|
||||
pr = comment["_pr_data"]
|
||||
comment_id = comment["url"].split("/")[-1] if comment["url"] else None
|
||||
|
||||
return {
|
||||
"repo": pr["repository"]["nameWithOwner"],
|
||||
"issue_number": pr["number"],
|
||||
"comment_id": comment_id,
|
||||
"body": comment["body"],
|
||||
"text": comment["body"],
|
||||
"created_at": comment["createdAt"],
|
||||
"updated_at": comment["updatedAt"],
|
||||
"html_url": comment["url"],
|
||||
"issue_url": pr["url"],
|
||||
"author_association": "COMMENTER",
|
||||
"issue_title": pr["title"],
|
||||
"issue_state": pr["state"],
|
||||
"login": self.username,
|
||||
"type": "pr_review_comment",
|
||||
}
|
||||
|
||||
def _get_comment_type(self) -> str:
|
||||
"""Returns the comment type for error messages."""
|
||||
return "PR review comment"
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
from cognee.low_level import DataPoint
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoints import (
|
||||
GitHubUser,
|
||||
Repository,
|
||||
File,
|
||||
FileChange,
|
||||
Comment,
|
||||
Issue,
|
||||
Commit,
|
||||
)
|
||||
|
||||
logger = get_logger("github_datapoints")
|
||||
|
||||
|
||||
def create_github_user_datapoint(user_data, nodesets: List[NodeSet]):
|
||||
"""Creates just the GitHubUser DataPoint object from the user data, with node sets."""
|
||||
if not user_data:
|
||||
return None
|
||||
|
||||
user_id = uuid5(NAMESPACE_OID, user_data.get("login", ""))
|
||||
|
||||
user = GitHubUser(
|
||||
id=user_id,
|
||||
name=user_data.get("login", ""),
|
||||
bio=user_data.get("bio"),
|
||||
company=user_data.get("company"),
|
||||
location=user_data.get("location"),
|
||||
public_repos=user_data.get("public_repos", 0),
|
||||
followers=user_data.get("followers", 0),
|
||||
following=user_data.get("following", 0),
|
||||
interacts_with=[],
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
|
||||
logger.debug(f"Created GitHubUser with ID: {user_id}")
|
||||
|
||||
return [user] + nodesets
|
||||
|
||||
|
||||
def create_repository_datapoint(repo_name: str, nodesets: List[NodeSet]) -> Repository:
|
||||
"""Creates a Repository DataPoint with a consistent ID."""
|
||||
repo_id = uuid5(NAMESPACE_OID, repo_name)
|
||||
repo = Repository(
|
||||
id=repo_id,
|
||||
name=repo_name,
|
||||
has_issue=[],
|
||||
has_commit=[],
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
logger.debug(f"Created Repository with ID: {repo_id} for {repo_name}")
|
||||
return repo
|
||||
|
||||
|
||||
def create_file_datapoint(filename: str, repo_name: str, nodesets: List[NodeSet]) -> File:
|
||||
"""Creates a File DataPoint with a consistent ID."""
|
||||
file_key = f"{repo_name}:{filename}"
|
||||
file_id = uuid5(NAMESPACE_OID, file_key)
|
||||
file = File(id=file_id, filename=filename, repo=repo_name, belongs_to_set=nodesets)
|
||||
logger.debug(f"Created File with ID: {file_id} for {filename}")
|
||||
return file
|
||||
|
||||
|
||||
def create_commit_datapoint(
|
||||
commit_data: Dict[str, Any], user: GitHubUser, nodesets: List[NodeSet]
|
||||
) -> Commit:
|
||||
"""Creates a Commit DataPoint with a consistent ID and connection to user."""
|
||||
commit_id = uuid5(NAMESPACE_OID, commit_data.get("commit_sha", ""))
|
||||
commit = Commit(
|
||||
id=commit_id,
|
||||
commit_sha=commit_data.get("commit_sha", ""),
|
||||
text="Commit message:" + (str)(commit_data.get("commit_message", "")),
|
||||
commit_date=commit_data.get("commit_date", ""),
|
||||
commit_url=commit_data.get("commit_url", ""),
|
||||
author_name=commit_data.get("login", ""),
|
||||
repo=commit_data.get("repo", ""),
|
||||
has_change=[],
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
logger.debug(f"Created Commit with ID: {commit_id} for {commit_data.get('commit_sha', '')}")
|
||||
return commit
|
||||
|
||||
|
||||
def create_file_change_datapoint(
|
||||
fc_data: Dict[str, Any], user: GitHubUser, file: File, nodesets: List[NodeSet]
|
||||
) -> FileChange:
|
||||
"""Creates a FileChange DataPoint with a consistent ID."""
|
||||
fc_key = (
|
||||
f"{fc_data.get('repo', '')}:{fc_data.get('commit_sha', '')}:{fc_data.get('filename', '')}"
|
||||
)
|
||||
fc_id = uuid5(NAMESPACE_OID, fc_key)
|
||||
|
||||
file_change = FileChange(
|
||||
id=fc_id,
|
||||
filename=fc_data.get("filename", ""),
|
||||
status=fc_data.get("status", ""),
|
||||
additions=fc_data.get("additions", 0),
|
||||
deletions=fc_data.get("deletions", 0),
|
||||
changes=fc_data.get("changes", 0),
|
||||
text=fc_data.get("diff", ""),
|
||||
commit_sha=fc_data.get("commit_sha", ""),
|
||||
repo=fc_data.get("repo", ""),
|
||||
modifies=file.filename,
|
||||
changed_by=user,
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
logger.debug(f"Created FileChange with ID: {fc_id} for {fc_data.get('filename', '')}")
|
||||
return file_change
|
||||
|
||||
|
||||
def create_issue_datapoint(
|
||||
issue_data: Dict[str, Any], repo_name: str, nodesets: List[NodeSet]
|
||||
) -> Issue:
|
||||
"""Creates an Issue DataPoint with a consistent ID."""
|
||||
issue_key = f"{repo_name}:{issue_data.get('issue_number', '')}"
|
||||
issue_id = uuid5(NAMESPACE_OID, issue_key)
|
||||
|
||||
issue = Issue(
|
||||
id=issue_id,
|
||||
number=issue_data.get("issue_number", 0),
|
||||
text=issue_data.get("issue_title", ""),
|
||||
state=issue_data.get("issue_state", ""),
|
||||
repository=repo_name,
|
||||
is_pr=False,
|
||||
has_comment=[],
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
logger.debug(f"Created Issue with ID: {issue_id} for {issue_data.get('issue_title', '')}")
|
||||
return issue
|
||||
|
||||
|
||||
def create_comment_datapoint(
|
||||
comment_data: Dict[str, Any], user: GitHubUser, nodesets: List[NodeSet]
|
||||
) -> Comment:
|
||||
"""Creates a Comment DataPoint with a consistent ID and connection to user."""
|
||||
comment_key = f"{comment_data.get('repo', '')}:{comment_data.get('issue_number', '')}:{comment_data.get('comment_id', '')}"
|
||||
comment_id = uuid5(NAMESPACE_OID, comment_key)
|
||||
|
||||
comment = Comment(
|
||||
id=comment_id,
|
||||
comment_id=str(comment_data.get("comment_id", "")),
|
||||
text=comment_data.get("body", ""),
|
||||
created_at=comment_data.get("created_at", ""),
|
||||
updated_at=comment_data.get("updated_at", ""),
|
||||
author_name=comment_data.get("login", ""),
|
||||
issue_number=comment_data.get("issue_number", 0),
|
||||
repo=comment_data.get("repo", ""),
|
||||
authored_by=user,
|
||||
belongs_to_set=nodesets,
|
||||
)
|
||||
logger.debug(f"Created Comment with ID: {comment_id}")
|
||||
return comment
|
||||
|
||||
|
||||
def create_github_datapoints(github_data, nodesets: List[NodeSet]):
|
||||
"""Creates DataPoint objects from GitHub data - simplified to just create user for now."""
|
||||
if not github_data:
|
||||
return None
|
||||
|
||||
return create_github_user_datapoint(github_data["user"], nodesets)
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, List
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class File(DataPoint):
|
||||
"""File is now a leaf node without any lists of other DataPoints"""
|
||||
|
||||
filename: str
|
||||
repo: str
|
||||
metadata: dict = {"index_fields": ["filename"]}
|
||||
|
||||
|
||||
class GitHubUser(DataPoint):
|
||||
name: Optional[str]
|
||||
bio: Optional[str]
|
||||
company: Optional[str]
|
||||
location: Optional[str]
|
||||
public_repos: int
|
||||
followers: int
|
||||
following: int
|
||||
interacts_with: List["Repository"] = []
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class FileChange(DataPoint):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
changes: int
|
||||
text: str
|
||||
commit_sha: str
|
||||
repo: str
|
||||
modifies: str
|
||||
changed_by: GitHubUser
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class Comment(DataPoint):
|
||||
comment_id: str
|
||||
text: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
author_name: str
|
||||
issue_number: int
|
||||
repo: str
|
||||
authored_by: GitHubUser
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class Issue(DataPoint):
|
||||
number: int
|
||||
text: str
|
||||
state: str
|
||||
repository: str
|
||||
is_pr: bool
|
||||
has_comment: List[Comment] = []
|
||||
|
||||
|
||||
class Commit(DataPoint):
|
||||
commit_sha: str
|
||||
text: str
|
||||
commit_date: str
|
||||
commit_url: str
|
||||
author_name: str
|
||||
repo: str
|
||||
has_change: List[FileChange] = []
|
||||
|
||||
|
||||
class Repository(DataPoint):
|
||||
name: str
|
||||
has_issue: List[Issue] = []
|
||||
has_commit: List[Commit] = []
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from github import Github
|
||||
from datetime import datetime
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_providers import (
|
||||
IssueCommentsProvider,
|
||||
PrReviewsProvider,
|
||||
PrReviewCommentsProvider,
|
||||
)
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_base import logger
|
||||
|
||||
|
||||
class GitHubDevComments:
|
||||
"""Facade class for working with a GitHub developer's comments."""
|
||||
|
||||
def __init__(self, profile, limit=10, include_issue_details=True):
|
||||
"""Initialize with a GitHubDevProfile instance and default parameters."""
|
||||
self.profile = profile
|
||||
self.limit = limit
|
||||
self.include_issue_details = include_issue_details
|
||||
|
||||
def get_issue_comments(self):
|
||||
"""Fetches the most recent comments made by the user on issues and PRs across repositories."""
|
||||
if not self.profile.user:
|
||||
logger.warning(f"No user found for profile {self.profile.username}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Fetching comments for {self.profile.username} with limit={self.limit}")
|
||||
|
||||
# Create providers with just the basic limit - they will handle their own multipliers
|
||||
issue_provider = IssueCommentsProvider(
|
||||
self.profile.token, self.profile.username, self.limit
|
||||
)
|
||||
pr_review_provider = PrReviewsProvider(
|
||||
self.profile.token, self.profile.username, self.limit
|
||||
)
|
||||
pr_comment_provider = PrReviewCommentsProvider(
|
||||
self.profile.token, self.profile.username, self.limit
|
||||
)
|
||||
|
||||
issue_comments = issue_provider.get_comments()
|
||||
pr_reviews = pr_review_provider.get_comments()
|
||||
pr_review_comments = pr_comment_provider.get_comments()
|
||||
|
||||
total_comments = issue_comments + pr_reviews + pr_review_comments
|
||||
logger.info(
|
||||
f"Retrieved {len(total_comments)} comments for {self.profile.username} "
|
||||
f"({len(issue_comments)} issue, {len(pr_reviews)} PR reviews, "
|
||||
f"{len(pr_review_comments)} PR review comments)"
|
||||
)
|
||||
|
||||
return total_comments
|
||||
|
||||
def set_limit(self, limit=None, include_issue_details=None):
|
||||
"""Sets the limit for comments to retrieve."""
|
||||
if limit is not None:
|
||||
self.limit = limit
|
||||
if include_issue_details is not None:
|
||||
self.include_issue_details = include_issue_details
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
from github import Github
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class GitHubDevCommits:
|
||||
"""Class for working with a GitHub developer's commits in pull requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
profile,
|
||||
days=30,
|
||||
prs_limit=10,
|
||||
commits_per_pr=5,
|
||||
include_files=False,
|
||||
skip_no_diff=False,
|
||||
):
|
||||
"""Initialize with a GitHubDevProfile instance and default parameters."""
|
||||
self.profile = profile
|
||||
self.days = days
|
||||
self.prs_limit = prs_limit
|
||||
self.commits_per_pr = commits_per_pr
|
||||
self.include_files = include_files
|
||||
self.skip_no_diff = skip_no_diff
|
||||
self.file_keys = ["filename", "status", "additions", "deletions", "changes", "diff"]
|
||||
|
||||
def get_user_commits(self):
|
||||
"""Fetches user's most recent commits from pull requests."""
|
||||
if not self.profile.user:
|
||||
return None
|
||||
|
||||
commits = self._collect_user_pr_commits()
|
||||
return {"user": self.profile.get_user_info(), "commits": commits}
|
||||
|
||||
def get_user_file_changes(self):
|
||||
"""Returns a flat list of file changes with associated commit information from PRs."""
|
||||
if not self.profile.user:
|
||||
return None
|
||||
|
||||
all_files = []
|
||||
commits = self._collect_user_pr_commits(include_files=True)
|
||||
|
||||
for commit in commits:
|
||||
if "files" not in commit:
|
||||
continue
|
||||
|
||||
commit_info = {
|
||||
"repo": commit["repo"],
|
||||
"commit_sha": commit["sha"],
|
||||
"commit_message": commit["message"],
|
||||
"commit_date": commit["date"],
|
||||
"commit_url": commit["url"],
|
||||
"pr_number": commit.get("pr_number"),
|
||||
"pr_title": commit.get("pr_title"),
|
||||
}
|
||||
|
||||
file_changes = []
|
||||
for file in commit["files"]:
|
||||
file_data = {key: file.get(key) for key in self.file_keys}
|
||||
file_changes.append({**file_data, **commit_info})
|
||||
|
||||
all_files.extend(file_changes)
|
||||
|
||||
return all_files
|
||||
|
||||
def set_options(
|
||||
self, days=None, prs_limit=None, commits_per_pr=None, include_files=None, skip_no_diff=None
|
||||
):
|
||||
"""Sets commit search parameters."""
|
||||
if days is not None:
|
||||
self.days = days
|
||||
if prs_limit is not None:
|
||||
self.prs_limit = prs_limit
|
||||
if commits_per_pr is not None:
|
||||
self.commits_per_pr = commits_per_pr
|
||||
if include_files is not None:
|
||||
self.include_files = include_files
|
||||
if skip_no_diff is not None:
|
||||
self.skip_no_diff = skip_no_diff
|
||||
|
||||
def _get_date_filter(self, days):
|
||||
"""Creates a date filter string for GitHub search queries."""
|
||||
if not days:
|
||||
return ""
|
||||
|
||||
date_limit = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
return f" created:>={date_limit}"
|
||||
|
||||
def _collect_user_pr_commits(self, include_files=None):
|
||||
"""Collects and sorts a user's recent commits from pull requests they authored."""
|
||||
include_files = include_files if include_files is not None else self.include_files
|
||||
|
||||
prs = self._get_user_prs()
|
||||
|
||||
if not prs:
|
||||
return []
|
||||
|
||||
all_commits = []
|
||||
for pr in prs[: self.prs_limit]:
|
||||
pr_commits = self._get_commits_from_pr(pr, include_files)
|
||||
all_commits.extend(pr_commits)
|
||||
|
||||
sorted_commits = sorted(all_commits, key=lambda x: x["date"], reverse=True)
|
||||
return sorted_commits
|
||||
|
||||
def _get_user_prs(self):
|
||||
"""Gets pull requests authored by the user."""
|
||||
date_filter = self._get_date_filter(self.days)
|
||||
query = f"author:{self.profile.username} is:pr is:merged{date_filter}"
|
||||
|
||||
try:
|
||||
return list(self.profile.github.search_issues(query))
|
||||
except Exception as e:
|
||||
print(f"Error searching for PRs: {e}")
|
||||
return []
|
||||
|
||||
def _get_commits_from_pr(self, pr_issue, include_files=None):
|
||||
"""Gets commits by the user from a specific PR."""
|
||||
include_files = include_files if include_files is not None else self.include_files
|
||||
|
||||
pr_info = self._get_pull_request_object(pr_issue)
|
||||
if not pr_info:
|
||||
return []
|
||||
|
||||
repo_name, pr = pr_info
|
||||
|
||||
all_commits = self._get_all_pr_commits(pr, pr_issue.number)
|
||||
if not all_commits:
|
||||
return []
|
||||
|
||||
user_commits = [
|
||||
c
|
||||
for c in all_commits
|
||||
if c.author and hasattr(c.author, "login") and c.author.login == self.profile.username
|
||||
]
|
||||
|
||||
commit_data = [
|
||||
self._extract_commit_data(commit, repo_name, pr_issue, include_files)
|
||||
for commit in user_commits[: self.commits_per_pr]
|
||||
]
|
||||
|
||||
return commit_data
|
||||
|
||||
def _get_pull_request_object(self, pr_issue):
|
||||
"""Gets repository and pull request objects from an issue."""
|
||||
try:
|
||||
repo_name = pr_issue.repository.full_name
|
||||
repo = self.profile.github.get_repo(repo_name)
|
||||
pr = repo.get_pull(pr_issue.number)
|
||||
return (repo_name, pr)
|
||||
except Exception as e:
|
||||
print(f"Error accessing PR #{pr_issue.number}: {e}")
|
||||
return None
|
||||
|
||||
def _get_all_pr_commits(self, pr, pr_number):
|
||||
"""Gets all commits from a pull request."""
|
||||
try:
|
||||
return list(pr.get_commits())
|
||||
except Exception as e:
|
||||
print(f"Error retrieving commits from PR #{pr_number}: {e}")
|
||||
return None
|
||||
|
||||
def _extract_commit_data(self, commit, repo_name, pr_issue, include_files=None):
|
||||
"""Extracts relevant data from a commit object within a PR context."""
|
||||
commit_data = {
|
||||
"repo": repo_name,
|
||||
"sha": commit.sha,
|
||||
"message": commit.commit.message,
|
||||
"date": commit.commit.author.date,
|
||||
"url": commit.html_url,
|
||||
"pr_number": pr_issue.number,
|
||||
"pr_title": pr_issue.title,
|
||||
"pr_url": pr_issue.html_url,
|
||||
}
|
||||
|
||||
include_files = include_files if include_files is not None else self.include_files
|
||||
|
||||
if include_files:
|
||||
commit_data["files"] = self._extract_commit_files(commit)
|
||||
|
||||
return commit_data
|
||||
|
||||
def _extract_commit_files(self, commit):
|
||||
"""Extracts files changed in a commit, including diffs."""
|
||||
files = []
|
||||
for file in commit.files:
|
||||
if self.skip_no_diff and not file.patch:
|
||||
continue
|
||||
|
||||
file_data = {key: getattr(file, key, None) for key in self.file_keys}
|
||||
|
||||
if "diff" in self.file_keys:
|
||||
file_data["diff"] = file.patch if file.patch else "No diff available for this file"
|
||||
|
||||
files.append(file_data)
|
||||
return files
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
from github import Github
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_comments import GitHubDevComments
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_commits import GitHubDevCommits
|
||||
|
||||
|
||||
class GitHubDevProfile:
|
||||
"""Class for working with a GitHub developer's profile, commits, and activity."""
|
||||
|
||||
def __init__(self, username, token):
|
||||
"""Initialize with a username and GitHub API token."""
|
||||
self.github = Github(token) if token else Github()
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.user = self._get_user(username)
|
||||
self.user_info = self._extract_user_info() if self.user else None
|
||||
self.comments = GitHubDevComments(self) if self.user else None
|
||||
self.commits = GitHubDevCommits(self) if self.user else None
|
||||
|
||||
def get_user_info(self):
|
||||
"""Returns the cached user information."""
|
||||
return self.user_info
|
||||
|
||||
def get_user_repos(self, limit=None):
|
||||
"""Returns a list of user's repositories with limit."""
|
||||
if not self.user:
|
||||
return []
|
||||
|
||||
repos = list(self.user.get_repos())
|
||||
if limit:
|
||||
repos = repos[:limit]
|
||||
return repos
|
||||
|
||||
def get_user_commits(self, days=30, prs_limit=5, commits_per_pr=3, include_files=False):
|
||||
"""Fetches user's most recent commits from pull requests."""
|
||||
if not self.commits:
|
||||
return None
|
||||
|
||||
self.commits.set_options(
|
||||
days=days,
|
||||
prs_limit=prs_limit,
|
||||
commits_per_pr=commits_per_pr,
|
||||
include_files=include_files,
|
||||
)
|
||||
|
||||
return self.commits.get_user_commits()
|
||||
|
||||
def get_user_file_changes(self, days=30, prs_limit=5, commits_per_pr=3, skip_no_diff=True):
|
||||
"""Returns a flat list of file changes from PRs with associated commit information."""
|
||||
if not self.commits:
|
||||
return None
|
||||
|
||||
self.commits.set_options(
|
||||
days=days,
|
||||
prs_limit=prs_limit,
|
||||
commits_per_pr=commits_per_pr,
|
||||
include_files=True,
|
||||
skip_no_diff=skip_no_diff,
|
||||
)
|
||||
|
||||
return self.commits.get_user_file_changes()
|
||||
|
||||
def get_issue_comments(self, limit=10, include_issue_details=True):
|
||||
"""Fetches the most recent comments made by the user on issues and PRs across repositories."""
|
||||
if not self.comments:
|
||||
return None
|
||||
|
||||
self.comments.set_limit(
|
||||
limit=limit,
|
||||
include_issue_details=include_issue_details,
|
||||
)
|
||||
|
||||
return self.comments.get_issue_comments()
|
||||
|
||||
def _get_user(self, username):
|
||||
"""Fetches a GitHub user object."""
|
||||
try:
|
||||
return self.github.get_user(username)
|
||||
except Exception as e:
|
||||
print(f"Error connecting to GitHub API: {e}")
|
||||
return None
|
||||
|
||||
def _extract_user_info(self):
|
||||
"""Extracts basic information from a GitHub user object."""
|
||||
return {
|
||||
"login": self.user.login,
|
||||
"name": self.user.name,
|
||||
"bio": self.user.bio,
|
||||
"company": self.user.company,
|
||||
"location": self.user.location,
|
||||
"public_repos": self.user.public_repos,
|
||||
"followers": self.user.followers,
|
||||
"following": self.user.following,
|
||||
}
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import asyncio
|
||||
import cognee
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_profile import GitHubDevProfile
|
||||
|
||||
|
||||
def get_github_profile_data(
|
||||
username, token=None, days=30, prs_limit=5, commits_per_pr=3, issues_limit=5, max_comments=3
|
||||
):
|
||||
"""Fetches comprehensive GitHub profile data including user info, commits from PRs, and comments."""
|
||||
token = token or ""
|
||||
profile = GitHubDevProfile(username, token)
|
||||
|
||||
if not profile.user:
|
||||
return None
|
||||
|
||||
commits_result = profile.get_user_commits(
|
||||
days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, include_files=True
|
||||
)
|
||||
comments = profile.get_issue_comments(limit=max_comments, include_issue_details=True)
|
||||
|
||||
return {
|
||||
"user": profile.get_user_info(),
|
||||
"commits": commits_result["commits"] if commits_result else [],
|
||||
"comments": comments or [],
|
||||
}
|
||||
|
||||
|
||||
def get_github_file_changes(
|
||||
username, token=None, days=30, prs_limit=5, commits_per_pr=3, skip_no_diff=True
|
||||
):
|
||||
"""Fetches a flat list of file changes from PRs with associated commit information for a GitHub user."""
|
||||
token = token or ""
|
||||
profile = GitHubDevProfile(username, token)
|
||||
|
||||
if not profile.user:
|
||||
return None
|
||||
|
||||
file_changes = profile.get_user_file_changes(
|
||||
days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, skip_no_diff=skip_no_diff
|
||||
)
|
||||
|
||||
return {"user": profile.get_user_info(), "file_changes": file_changes or []}
|
||||
|
||||
|
||||
def get_github_data_for_cognee(
|
||||
username,
|
||||
token=None,
|
||||
days=30,
|
||||
prs_limit=3,
|
||||
commits_per_pr=3,
|
||||
issues_limit=3,
|
||||
max_comments=3,
|
||||
skip_no_diff=True,
|
||||
):
|
||||
"""Fetches enriched GitHub data for a user with PR file changes and comments combined with user data."""
|
||||
token = token or ""
|
||||
profile = GitHubDevProfile(username, token)
|
||||
|
||||
if not profile.user:
|
||||
return None
|
||||
|
||||
user_info = profile.get_user_info()
|
||||
|
||||
file_changes = profile.get_user_file_changes(
|
||||
days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, skip_no_diff=skip_no_diff
|
||||
)
|
||||
|
||||
enriched_file_changes = []
|
||||
if file_changes:
|
||||
enriched_file_changes = [item | user_info for item in file_changes]
|
||||
|
||||
comments = profile.get_issue_comments(limit=max_comments, include_issue_details=True)
|
||||
|
||||
enriched_comments = []
|
||||
if comments:
|
||||
enriched_comments = []
|
||||
for comment in comments:
|
||||
safe_user_info = {k: v for k, v in user_info.items() if k not in comment}
|
||||
enriched_comments.append(comment | safe_user_info)
|
||||
|
||||
return {"user": user_info, "file_changes": enriched_file_changes, "comments": enriched_comments}
|
||||
|
||||
|
||||
async def cognify_github_profile(username, token=None):
|
||||
"""Ingests GitHub data into Cognee with soft and technical node sets."""
|
||||
github_data = get_github_data_for_cognee(username=username, token=token)
|
||||
if not github_data:
|
||||
return False
|
||||
|
||||
await cognee.add(
|
||||
json.dumps(github_data["user"], default=str), node_set=["soft", "technical", username]
|
||||
)
|
||||
|
||||
for comment in github_data["comments"]:
|
||||
await cognee.add(
|
||||
"Comment: " + json.dumps(comment, default=str), node_set=["soft", username]
|
||||
)
|
||||
|
||||
for file_change in github_data["file_changes"]:
|
||||
await cognee.add(
|
||||
"File Change: " + json.dumps(file_change, default=str), node_set=["technical", username]
|
||||
)
|
||||
|
||||
await cognee.cognify()
|
||||
return True
|
||||
|
||||
|
||||
async def main(username):
|
||||
"""Main function for testing Cognee ingest."""
|
||||
import os
|
||||
import dotenv
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
|
||||
dotenv.load_dotenv()
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
|
||||
await cognify_github_profile(username, token)
|
||||
|
||||
# success = await cognify_github_profile(username, token)
|
||||
|
||||
# if success:
|
||||
# visualization_path = os.path.join(os.path.dirname(__file__), "./.artifacts/github_graph.html")
|
||||
# await visualize_graph(visualization_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
username = ""
|
||||
asyncio.run(main(username))
|
||||
# token = os.getenv("GITHUB_TOKEN")
|
||||
# github_data = get_github_data_for_cognee(username=username, token=token)
|
||||
# print(json.dumps(github_data, indent=2, default=str))
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
import json
|
||||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
from cognee.api.v1.search import SearchType
|
||||
import cognee
|
||||
from cognee.low_level import DataPoint, setup as cognee_setup
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_ingest import (
|
||||
get_github_data_for_cognee,
|
||||
)
|
||||
|
||||
# Import DataPoint classes from github_datapoints.py
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoints import (
|
||||
GitHubUser,
|
||||
Repository,
|
||||
File,
|
||||
Commit,
|
||||
)
|
||||
|
||||
# Import creator functions from github_datapoint_creators.py
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoint_creators import (
|
||||
create_github_user_datapoint,
|
||||
create_repository_datapoint,
|
||||
create_file_datapoint,
|
||||
create_commit_datapoint,
|
||||
create_file_change_datapoint,
|
||||
create_issue_datapoint,
|
||||
create_comment_datapoint,
|
||||
)
|
||||
|
||||
logger = get_logger("github_ingest")
|
||||
|
||||
|
||||
def collect_repositories(
|
||||
section: List[Dict[str, Any]],
|
||||
repositories: Dict[str, Repository],
|
||||
user: GitHubUser,
|
||||
nodesets: List[NodeSet],
|
||||
) -> None:
|
||||
"""Collect unique repositories from a data section and register them to the user."""
|
||||
for entry in section:
|
||||
repo_name = entry.get("repo", "")
|
||||
if not repo_name or repo_name in repositories:
|
||||
continue
|
||||
repo = create_repository_datapoint(repo_name, nodesets)
|
||||
repositories[repo_name] = repo
|
||||
user.interacts_with.append(repo)
|
||||
|
||||
|
||||
def get_or_create_repository(
|
||||
repo_name: str, repositories: Dict[str, Repository], user: GitHubUser, nodesets: List[NodeSet]
|
||||
) -> Repository:
|
||||
if repo_name in repositories:
|
||||
return repositories[repo_name]
|
||||
repo = create_repository_datapoint(repo_name, nodesets)
|
||||
repositories[repo_name] = repo
|
||||
user.interacts_with.append(repo)
|
||||
return repo
|
||||
|
||||
|
||||
def get_or_create_file(
|
||||
filename: str,
|
||||
repo_name: str,
|
||||
files: Dict[str, File],
|
||||
technical_nodeset: NodeSet,
|
||||
) -> File:
|
||||
file_key = f"{repo_name}:{filename}"
|
||||
if file_key in files:
|
||||
return files[file_key]
|
||||
file = create_file_datapoint(filename, repo_name, [technical_nodeset])
|
||||
files[file_key] = file
|
||||
return file
|
||||
|
||||
|
||||
def get_or_create_commit(
|
||||
commit_data: Dict[str, Any],
|
||||
user: GitHubUser,
|
||||
commits: Dict[str, Commit],
|
||||
repository: Repository,
|
||||
technical_nodeset: NodeSet,
|
||||
) -> Commit:
|
||||
commit_sha = commit_data.get("commit_sha", "")
|
||||
if commit_sha in commits:
|
||||
return commits[commit_sha]
|
||||
commit = create_commit_datapoint(commit_data, user, [technical_nodeset])
|
||||
commits[commit_sha] = commit
|
||||
link_commit_to_repo(commit, repository)
|
||||
return commit
|
||||
|
||||
|
||||
def link_file_to_repo(file: File, repository: Repository):
|
||||
if file not in repository.contains:
|
||||
repository.contains.append(file)
|
||||
|
||||
|
||||
def link_commit_to_repo(commit: Commit, repository: Repository):
|
||||
if commit not in repository.has_commit:
|
||||
repository.has_commit.append(commit)
|
||||
|
||||
|
||||
def process_file_changes_data(
|
||||
github_data: Dict[str, Any],
|
||||
user: GitHubUser,
|
||||
repositories: Dict[str, Repository],
|
||||
technical_nodeset: NodeSet,
|
||||
) -> List[DataPoint]:
|
||||
"""Process file changes data and build the graph structure with stronger connections."""
|
||||
file_changes = github_data.get("file_changes", [])
|
||||
if not file_changes:
|
||||
return []
|
||||
|
||||
collect_repositories(file_changes, repositories, user, [technical_nodeset])
|
||||
|
||||
files = {}
|
||||
commits = {}
|
||||
file_changes_list = []
|
||||
for fc_data in file_changes:
|
||||
repo_name = fc_data.get("repo", "")
|
||||
filename = fc_data.get("filename", "")
|
||||
commit_sha = fc_data.get("commit_sha", "")
|
||||
if not repo_name or not filename or not commit_sha:
|
||||
continue
|
||||
repository = get_or_create_repository(repo_name, repositories, user, [technical_nodeset])
|
||||
file = get_or_create_file(filename, repo_name, files, technical_nodeset)
|
||||
commit = get_or_create_commit(fc_data, user, commits, repository, technical_nodeset)
|
||||
file_change = create_file_change_datapoint(fc_data, user, file, [technical_nodeset])
|
||||
file_changes_list.append(file_change)
|
||||
if file_change not in commit.has_change:
|
||||
commit.has_change.append(file_change)
|
||||
all_datapoints = list(commits.values()) + file_changes_list
|
||||
return all_datapoints
|
||||
|
||||
|
||||
def process_comments_data(
|
||||
github_data: Dict[str, Any],
|
||||
user: GitHubUser,
|
||||
repositories: Dict[str, Repository],
|
||||
technical_nodeset: NodeSet,
|
||||
soft_nodeset: NodeSet,
|
||||
) -> List[DataPoint]:
|
||||
"""Process comments data and build the graph structure with stronger connections."""
|
||||
comments_data = github_data.get("comments", [])
|
||||
if not comments_data:
|
||||
return []
|
||||
|
||||
collect_repositories(comments_data, repositories, user, [soft_nodeset])
|
||||
|
||||
issues = {}
|
||||
comments_list = []
|
||||
for comment_data in comments_data:
|
||||
repo_name = comment_data.get("repo", "")
|
||||
issue_number = comment_data.get("issue_number", 0)
|
||||
if not repo_name or not issue_number:
|
||||
continue
|
||||
repository = get_or_create_repository(repo_name, repositories, user, [soft_nodeset])
|
||||
issue_key = f"{repo_name}:{issue_number}"
|
||||
if issue_key not in issues:
|
||||
issue = create_issue_datapoint(comment_data, repo_name, [soft_nodeset])
|
||||
issues[issue_key] = issue
|
||||
if issue not in repository.has_issue:
|
||||
repository.has_issue.append(issue)
|
||||
comment = create_comment_datapoint(comment_data, user, [soft_nodeset])
|
||||
comments_list.append(comment)
|
||||
if comment not in issues[issue_key].has_comment:
|
||||
issues[issue_key].has_comment.append(comment)
|
||||
all_datapoints = list(issues.values()) + comments_list
|
||||
return all_datapoints
|
||||
|
||||
|
||||
def build_github_datapoints_from_dict(github_data: Dict[str, Any]):
|
||||
"""Builds all DataPoints from a GitHub data dictionary."""
|
||||
if not github_data or "user" not in github_data:
|
||||
return None
|
||||
|
||||
soft_nodeset = NodeSet(id=uuid5(NAMESPACE_OID, "NodeSet:soft"), name="soft")
|
||||
technical_nodeset = NodeSet(id=uuid5(NAMESPACE_OID, "NodeSet:technical"), name="technical")
|
||||
|
||||
datapoints = create_github_user_datapoint(
|
||||
github_data["user"], [soft_nodeset, technical_nodeset]
|
||||
)
|
||||
if not datapoints:
|
||||
return None
|
||||
user = datapoints[0]
|
||||
|
||||
repositories = {}
|
||||
|
||||
file_change_datapoints = process_file_changes_data(
|
||||
github_data, user, repositories, technical_nodeset
|
||||
)
|
||||
comment_datapoints = process_comments_data(
|
||||
github_data, user, repositories, technical_nodeset, soft_nodeset
|
||||
)
|
||||
|
||||
all_datapoints = (
|
||||
datapoints + list(repositories.values()) + file_change_datapoints + comment_datapoints
|
||||
)
|
||||
return all_datapoints
|
||||
|
||||
|
||||
async def cognify_github_data(github_data: dict):
|
||||
"""Process GitHub user, file changes, and comments data from a loaded dictionary."""
|
||||
all_datapoints = build_github_datapoints_from_dict(github_data)
|
||||
if not all_datapoints:
|
||||
logger.error("Failed to create datapoints")
|
||||
return False
|
||||
|
||||
dataset_id = uuid5(NAMESPACE_OID, "GitHub")
|
||||
|
||||
cognee_user = await get_default_user()
|
||||
tasks = [Task(add_data_points, task_config={"batch_size": 50})]
|
||||
results = run_tasks(
|
||||
tasks=tasks,
|
||||
data=all_datapoints,
|
||||
dataset_id=dataset_id,
|
||||
pipeline_name="github_pipeline",
|
||||
user=cognee_user,
|
||||
)
|
||||
async for result in results:
|
||||
print(result)
|
||||
|
||||
logger.info(f"Done processing {len(all_datapoints)} datapoints")
|
||||
return True
|
||||
|
||||
|
||||
async def cognify_github_data_from_username(
|
||||
username: str,
|
||||
token: Optional[str] = None,
|
||||
days: int = 30,
|
||||
prs_limit: int = 3,
|
||||
commits_per_pr: int = 3,
|
||||
issues_limit: int = 3,
|
||||
max_comments: int = 3,
|
||||
skip_no_diff: bool = True,
|
||||
):
|
||||
"""Fetches GitHub data for a username and processes it through the DataPoint pipeline."""
|
||||
|
||||
logger.info(f"Fetching GitHub data for user: {username}")
|
||||
|
||||
github_data = get_github_data_for_cognee(
|
||||
username=username,
|
||||
token=token,
|
||||
days=days,
|
||||
prs_limit=prs_limit,
|
||||
commits_per_pr=commits_per_pr,
|
||||
issues_limit=issues_limit,
|
||||
max_comments=max_comments,
|
||||
skip_no_diff=skip_no_diff,
|
||||
)
|
||||
|
||||
if not github_data:
|
||||
logger.error(f"Failed to fetch GitHub data for user: {username}")
|
||||
return False
|
||||
|
||||
github_data = json.loads(json.dumps(github_data, default=str))
|
||||
|
||||
await cognify_github_data(github_data)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def process_github_from_file(json_file_path: str):
|
||||
"""Process GitHub data from a JSON file."""
|
||||
logger.info(f"Processing GitHub data from file: {json_file_path}")
|
||||
try:
|
||||
with open(json_file_path, "r") as f:
|
||||
github_data = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JSON file: {e}")
|
||||
return False
|
||||
|
||||
return await cognify_github_data(github_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
|
||||
username = ""
|
||||
|
||||
async def cognify_from_username(username, token):
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
await cognify_github_data_from_username(username, token)
|
||||
|
||||
# Run it
|
||||
asyncio.run(cognify_from_username(username, token))
|
||||
179
cognee/complex_demos/crewai_demo/src/crewai_demo/hiring_crew.py
Normal file
179
cognee/complex_demos/crewai_demo/src/crewai_demo/hiring_crew.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.project import CrewBase, agent, crew, task, before_kickoff
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.cognee_ingestion import (
|
||||
CogneeIngestion,
|
||||
)
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.cognee_search import CogneeSearch
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
role: str
|
||||
goal: str
|
||||
backstory: str
|
||||
|
||||
|
||||
@CrewBase
|
||||
class HiringCrew:
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
self
|
||||
|
||||
@agent
|
||||
def soft_skills_expert_agent(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["soft_skills_expert_agent"],
|
||||
tools=[CogneeSearch(nodeset_name="soft")],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
@agent
|
||||
def technical_expert_agent(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["technical_expert_agent"],
|
||||
tools=[CogneeSearch(nodeset_name="technical")],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
@agent
|
||||
def decision_maker_agent(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["decision_maker_agent"],
|
||||
tools=[CogneeIngestion(nodeset_name="final_report")],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
@task
|
||||
def soft_skills_assessment_applicant1_task(self) -> Task:
|
||||
self.tasks_config["soft_skills_assessment_applicant1_task"]["description"] = (
|
||||
self.tasks_config["soft_skills_assessment_applicant1_task"]["description"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
self.tasks_config["soft_skills_assessment_applicant1_task"]["expected_output"] = (
|
||||
self.tasks_config["soft_skills_assessment_applicant1_task"]["expected_output"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
return Task(
|
||||
config=self.tasks_config["soft_skills_assessment_applicant1_task"],
|
||||
async_execution=False,
|
||||
)
|
||||
|
||||
@task
|
||||
def soft_skills_assessment_applicant2_task(self) -> Task:
|
||||
self.tasks_config["soft_skills_assessment_applicant2_task"]["description"] = (
|
||||
self.tasks_config["soft_skills_assessment_applicant2_task"]["description"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
self.tasks_config["soft_skills_assessment_applicant2_task"]["expected_output"] = (
|
||||
self.tasks_config["soft_skills_assessment_applicant2_task"]["expected_output"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
return Task(
|
||||
config=self.tasks_config["soft_skills_assessment_applicant2_task"],
|
||||
async_execution=False,
|
||||
)
|
||||
|
||||
@task
|
||||
def technical_assessment_applicant1_task(self) -> Task:
|
||||
self.tasks_config["technical_assessment_applicant1_task"]["description"] = (
|
||||
self.tasks_config["technical_assessment_applicant1_task"]["description"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
self.tasks_config["technical_assessment_applicant1_task"]["expected_output"] = (
|
||||
self.tasks_config["technical_assessment_applicant1_task"]["expected_output"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
return Task(
|
||||
config=self.tasks_config["technical_assessment_applicant1_task"], async_execution=False
|
||||
)
|
||||
|
||||
@task
|
||||
def technical_assessment_applicant2_task(self) -> Task:
|
||||
self.tasks_config["technical_assessment_applicant2_task"]["description"] = (
|
||||
self.tasks_config["technical_assessment_applicant2_task"]["description"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
self.tasks_config["technical_assessment_applicant2_task"]["expected_output"] = (
|
||||
self.tasks_config["technical_assessment_applicant2_task"]["expected_output"].format(
|
||||
**self.inputs
|
||||
)
|
||||
)
|
||||
return Task(
|
||||
config=self.tasks_config["technical_assessment_applicant2_task"], async_execution=False
|
||||
)
|
||||
|
||||
@task
|
||||
def hiring_decision_task(self) -> Task:
|
||||
self.tasks_config["hiring_decision_task"]["description"] = self.tasks_config[
|
||||
"hiring_decision_task"
|
||||
]["description"].format(**self.inputs)
|
||||
self.tasks_config["hiring_decision_task"]["expected_output"] = self.tasks_config[
|
||||
"hiring_decision_task"
|
||||
]["expected_output"].format(**self.inputs)
|
||||
return Task(config=self.tasks_config["hiring_decision_task"], async_execution=False)
|
||||
|
||||
@task
|
||||
def ingest_hiring_decision_task(self) -> Task:
|
||||
self.tasks_config["ingest_hiring_decision_task"]["description"] = self.tasks_config[
|
||||
"ingest_hiring_decision_task"
|
||||
]["description"].format(**self.inputs)
|
||||
self.tasks_config["ingest_hiring_decision_task"]["expected_output"] = self.tasks_config[
|
||||
"ingest_hiring_decision_task"
|
||||
]["expected_output"].format(**self.inputs)
|
||||
return Task(
|
||||
config=self.tasks_config["ingest_hiring_decision_task"],
|
||||
async_execution=False,
|
||||
)
|
||||
|
||||
def refine_agent_configs(self, agent_name: str = None):
|
||||
system_prompt = (
|
||||
"You are an expert in improving agent definitions for autonomous AI systems. "
|
||||
"Given an agent's role, goal, and backstory, refine them to be:\n"
|
||||
"- Concise and well-written\n"
|
||||
"- Aligned with the agent’s function\n"
|
||||
"- Clear and professional\n"
|
||||
"- Consistent with multi-agent teamwork\n\n"
|
||||
"Return the updated definition as a JSON object with keys: role, goal, backstory."
|
||||
)
|
||||
|
||||
agent_keys = [agent_name] if agent_name else self.agents_config.keys()
|
||||
|
||||
for name in agent_keys:
|
||||
agent_def = self.agents_config[name]
|
||||
|
||||
user_prompt = f"""Here is the current agent definition:
|
||||
role: {agent_def["role"]}
|
||||
goal: {agent_def["goal"]}
|
||||
backstory: {agent_def["backstory"]}
|
||||
|
||||
Please improve it."""
|
||||
llm_client = get_llm_client()
|
||||
improved = llm_client.create_structured_output(
|
||||
text_input=user_prompt, system_prompt=system_prompt, response_model=AgentConfig
|
||||
)
|
||||
|
||||
self.agents_config[name] = improved.dict()
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
return Crew(
|
||||
agents=self.agents,
|
||||
tasks=self.tasks,
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
share_crew=True,
|
||||
output_log_file="hiring_crew_log.txt",
|
||||
)
|
||||
53
cognee/complex_demos/crewai_demo/src/crewai_demo/main.py
Normal file
53
cognee/complex_demos/crewai_demo/src/crewai_demo/main.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import warnings
|
||||
import os
|
||||
from hiring_crew import HiringCrew
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.github_ingestion import (
|
||||
GithubIngestion,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
|
||||
def print_environment():
|
||||
for key in sorted(os.environ):
|
||||
print(f"{key}={os.environ[key]}")
|
||||
|
||||
|
||||
def run_github_ingestion(applicant_1, applicant_2):
|
||||
GithubIngestion().run(applicant_1=applicant_1, applicant_2=applicant_2)
|
||||
|
||||
|
||||
def run_hiring_crew(applicants: dict, number_of_rounds: int = 1, llm_client=None):
|
||||
for hiring_round in range(number_of_rounds):
|
||||
print(f"\nStarting hiring round {hiring_round + 1}...\n")
|
||||
crew = HiringCrew(inputs=applicants)
|
||||
if hiring_round > 0:
|
||||
print("Refining agent prompts for this round...")
|
||||
crew.refine_agent_configs(agent_name="soft_skills_expert_agent")
|
||||
crew.refine_agent_configs(agent_name="technical_expert_agent")
|
||||
crew.refine_agent_configs(agent_name="decision_maker_agent")
|
||||
|
||||
crew.crew().kickoff()
|
||||
|
||||
|
||||
def run(enable_ingestion=True, enable_crew=True):
|
||||
try:
|
||||
print_environment()
|
||||
|
||||
applicants = {"applicant_1": "hajdul88", "applicant_2": "lxobr"}
|
||||
|
||||
if enable_ingestion:
|
||||
run_github_ingestion(applicants["applicant_1"], applicants["applicant_2"])
|
||||
|
||||
if enable_crew:
|
||||
run_hiring_crew(applicants=applicants, number_of_rounds=5)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the process: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
enable_ingestion = True
|
||||
enable_crew = True
|
||||
|
||||
run(enable_ingestion=enable_ingestion, enable_crew=enable_crew)
|
||||
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
from functools import wraps
|
||||
from abc import abstractmethod, ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from typing import Protocol, Optional, Dict, Any, List, Type, Tuple
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -189,3 +189,6 @@ class GraphDBInterface(ABC):
|
|||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||
"""Get all nodes connected to a given node with their relationships."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_nodeset_subgraph(self, node_type, node_name):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -728,6 +728,66 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to get graph data: {e}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
primary_query = """
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:Node)
|
||||
WHERE n.type = $label AND n.name = wantedName
|
||||
RETURN DISTINCT n.id
|
||||
"""
|
||||
primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
|
||||
primary_ids = [row[0] for row in primary_rows]
|
||||
if not primary_ids:
|
||||
return [], []
|
||||
|
||||
neighbor_query = """
|
||||
MATCH (n:Node)-[:EDGE]-(nbr:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN DISTINCT nbr.id
|
||||
"""
|
||||
nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
|
||||
neighbor_ids = [row[0] for row in nbr_rows]
|
||||
|
||||
all_ids = list({*primary_ids, *neighbor_ids})
|
||||
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN n.id, n.name, n.type, n.properties
|
||||
"""
|
||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||
nodes: List[Tuple[str, dict]] = []
|
||||
for node_id, name, typ, props in node_rows:
|
||||
data = {"id": node_id, "name": name, "type": typ}
|
||||
if props:
|
||||
try:
|
||||
data.update(json.loads(props))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for node {node_id}")
|
||||
nodes.append((node_id, data))
|
||||
|
||||
edges_query = """
|
||||
MATCH (a:Node)-[r:EDGE]-(b:Node)
|
||||
WHERE a.id IN $ids AND b.id IN $ids
|
||||
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||
"""
|
||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||
edges: List[Tuple[str, str, str, dict]] = []
|
||||
for from_id, to_id, rel_type, props in edge_rows:
|
||||
data = {}
|
||||
if props:
|
||||
try:
|
||||
data = json.loads(props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||
|
||||
edges.append((from_id, to_id, rel_type, data))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import json
|
|||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -517,6 +517,58 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return (nodes, edges)
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
|
||||
query = f"""
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:`{label}`)
|
||||
WHERE n.name = wantedName
|
||||
WITH collect(DISTINCT n) AS primary
|
||||
|
||||
UNWIND primary AS p
|
||||
OPTIONAL MATCH (p)--(nbr)
|
||||
WITH primary, collect(DISTINCT nbr) AS nbrs
|
||||
WITH primary + nbrs AS nodelist
|
||||
|
||||
UNWIND nodelist AS node
|
||||
WITH collect(DISTINCT node) AS nodes
|
||||
|
||||
MATCH (a)-[r]-(b)
|
||||
WHERE a IN nodes AND b IN nodes
|
||||
WITH nodes, collect(DISTINCT r) AS rels
|
||||
|
||||
RETURN
|
||||
[n IN nodes |
|
||||
{{ id: n.id,
|
||||
properties: properties(n) }}] AS rawNodes,
|
||||
[r IN rels |
|
||||
{{ type: type(r),
|
||||
properties: properties(r) }}] AS rawRels
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"names": node_name})
|
||||
if not result:
|
||||
return [], []
|
||||
|
||||
raw_nodes = result[0]["rawNodes"]
|
||||
raw_rels = result[0]["rawRels"]
|
||||
|
||||
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
|
||||
edges = [
|
||||
(
|
||||
r["properties"]["source_node_id"],
|
||||
r["properties"]["target_node_id"],
|
||||
r["type"],
|
||||
r["properties"],
|
||||
)
|
||||
for r in raw_rels
|
||||
]
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters):
|
||||
"""
|
||||
Fetches nodes and relationships filtered by specified attribute values.
|
||||
|
|
|
|||
|
|
@ -250,14 +250,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if len(vector_list) == 0:
|
||||
return []
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("_distance"))
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -63,8 +63,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to
|
||||
extract information from the following input: {text_input}. """,
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -91,8 +90,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to
|
||||
extract information from the following input: {text_input}. """,
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
|
|
|
|||
|
|
@ -5,4 +5,3 @@ class NodeSet(DataPoint):
|
|||
"""NodeSet data point."""
|
||||
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
|
|
@ -61,22 +61,27 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node_dimension=1,
|
||||
edge_dimension=1,
|
||||
memory_fragment_filter=[],
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> None:
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise InvalidValueError(message="Dimensions must be positive integers")
|
||||
|
||||
try:
|
||||
if len(memory_fragment_filter) == 0:
|
||||
if node_type is not None and node_name is not None:
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
|
||||
if not nodes_data:
|
||||
raise EntityNotFoundError(message="No node data retrieved from the database.")
|
||||
if not edges_data:
|
||||
raise EntityNotFoundError(message="No edge data retrieved from the database.")
|
||||
if not nodes_data or not edges_data:
|
||||
logger.warning("Empty projected graph.")
|
||||
return None
|
||||
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
|
|
|
|||
|
|
@ -144,6 +144,7 @@ def expand_with_nodes_and_edges(
|
|||
is_a=type_node,
|
||||
description=node.description,
|
||||
ontology_valid=ontology_validated_source_ent,
|
||||
belongs_to_set=data_chunk.belongs_to_set,
|
||||
)
|
||||
|
||||
added_nodes_map[entity_node_key] = entity_node
|
||||
|
|
|
|||
|
|
@ -24,3 +24,13 @@ class CypherSearchError(CogneeApiError):
|
|||
|
||||
class NoDataError(CriticalError):
|
||||
message: str = "No data found in the system, please add data first."
|
||||
|
||||
|
||||
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "No collection distances found for the given query.",
|
||||
name: str = "CollectionDistancesNotFoundError",
|
||||
status_code: int = status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type, List
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
|
|
@ -8,6 +8,9 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class GraphCompletionRetriever(BaseRetriever):
|
||||
|
|
@ -18,11 +21,15 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
|
||||
def _get_nodes(self, retrieved_edges: list) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
|
|
@ -68,7 +75,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query, top_k=self.top_k, collections=vector_index_collections or None
|
||||
query,
|
||||
top_k=self.top_k,
|
||||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
)
|
||||
|
||||
return found_triplets
|
||||
|
|
@ -78,6 +89,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
triplets = await self.get_triplets(query)
|
||||
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
return ""
|
||||
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
|
@ -55,6 +55,8 @@ def format_triplets(edges):
|
|||
|
||||
async def get_memory_fragment(
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> CogneeGraph:
|
||||
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
@ -68,6 +70,8 @@ async def get_memory_fragment(
|
|||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
except EntityNotFoundError:
|
||||
pass
|
||||
|
|
@ -82,6 +86,8 @@ async def brute_force_triplet_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
@ -93,6 +99,8 @@ async def brute_force_triplet_search(
|
|||
collections=collections,
|
||||
properties_to_project=properties_to_project,
|
||||
memory_fragment=memory_fragment,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
return retrieved_results
|
||||
|
||||
|
|
@ -104,6 +112,8 @@ async def brute_force_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
|
@ -115,6 +125,8 @@ async def brute_force_search(
|
|||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
|
|
@ -125,7 +137,9 @@ async def brute_force_search(
|
|||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(properties_to_project)
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project, node_type=node_type, node_name=node_name
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional, Type, List
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
|
|
@ -11,6 +11,7 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionR
|
|||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
|
|
@ -29,12 +30,20 @@ async def search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(
|
||||
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
|
||||
query_type,
|
||||
query_text,
|
||||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
filtered_search_results = []
|
||||
|
|
@ -57,6 +66,8 @@ async def specific_search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
|
|
@ -69,6 +80,8 @@ async def specific_search(
|
|||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path, top_k=top_k
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue