Pagination for get by group_id (#218)
* add pagination to subgraphs * update pagination * update LiteralString import * cleanup * cleanup * update embedding dims
This commit is contained in:
parent
397291de4b
commit
0fbe5c0704
12 changed files with 123 additions and 35 deletions
|
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
|
||||||
return timedelta() # Return 0 duration if parsing fails
|
return timedelta() # Return 0 duration if parsing fails
|
||||||
|
|
||||||
|
|
||||||
def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]:
|
def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
|
||||||
with open(file_path) as file:
|
with open(file_path) as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
|
@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
|
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
|
||||||
|
|
||||||
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
||||||
if not passages:
|
if not passages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderClient(ABC):
|
class CrossEncoderClient(ABC):
|
||||||
|
|
@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
||||||
"""
|
"""
|
||||||
Rank the given passages based on their relevance to the query.
|
Rank the given passages based on their relevance to the query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): The query string.
|
query (str): The query string.
|
||||||
passages (List[str]): A list of passages to rank.
|
passages (list[str]): A list of passages to rank.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, float]]: A list of tuples containing the passage and its score,
|
List[tuple[str, float]]: A list of tuples containing the passage and its score,
|
||||||
sorted in descending order of relevance.
|
sorted in descending order of relevance.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,11 @@ from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
COMMUNITY_EDGE_SAVE,
|
COMMUNITY_EDGE_SAVE,
|
||||||
ENTITY_EDGE_SAVE,
|
ENTITY_EDGE_SAVE,
|
||||||
|
|
@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
|
||||||
async def delete(self, driver: AsyncDriver):
|
async def delete(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n)-[e {uuid: $uuid}]->(m)
|
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
||||||
DELETE e
|
DELETE e
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
|
@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
e.uuid As uuid,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
m.uuid AS target_node_uuid,
|
m.uuid AS target_node_uuid,
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
|
ORDER BY e.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -274,11 +290,22 @@ class EntityEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid AS uuid,
|
e.uuid AS uuid,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
|
|
@ -292,8 +319,12 @@ class EntityEdge(Edge):
|
||||||
e.expired_at AS expired_at,
|
e.expired_at AS expired_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
|
ORDER BY e.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -365,19 +396,34 @@ class CommunityEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
e.uuid As uuid,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
m.uuid AS target_node_uuid,
|
m.uuid AS target_node_uuid,
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
|
ORDER BY e.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Iterable, List, Literal
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
|
||||||
|
|
||||||
|
|
||||||
class EmbedderConfig(BaseModel):
|
class EmbedderConfig(BaseModel):
|
||||||
embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
|
embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
|
||||||
|
|
||||||
|
|
||||||
class EmbedderClient(ABC):
|
class EmbedderClient(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create(
|
async def create(
|
||||||
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Iterable, List
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types import EmbeddingModel
|
from openai.types import EmbeddingModel
|
||||||
|
|
@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
result = await self.client.embeddings.create(
|
result = await self.client.embeddings.create(
|
||||||
input=input_data, model=self.config.embedding_model
|
input=input_data, model=self.config.embedding_model
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Iterable, List
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import voyageai # type: ignore
|
import voyageai # type: ignore
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
|
||||||
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
if isinstance(input_data, str):
|
if isinstance(input_data, str):
|
||||||
input_list = [input_data]
|
input_list = [input_data]
|
||||||
elif isinstance(input_data, List):
|
elif isinstance(input_data, list):
|
||||||
input_list = [str(i) for i in input_data if i]
|
input_list = [str(i) for i in input_data if i]
|
||||||
else:
|
else:
|
||||||
input_list = [str(i) for i in input_data if i is not None]
|
input_list = [str(i) for i in input_data if i is not None]
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ load_dotenv()
|
||||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||||
MAX_REFLEXION_ITERATIONS = 2
|
MAX_REFLEXION_ITERATIONS = 2
|
||||||
|
DEFAULT_PAGE_LIMIT = 20
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,11 @@ from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
COMMUNITY_NODE_SAVE,
|
COMMUNITY_NODE_SAVE,
|
||||||
ENTITY_NODE_SAVE,
|
ENTITY_NODE_SAVE,
|
||||||
|
|
@ -207,10 +208,21 @@ class EpisodicNode(Node):
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
e.content AS content,
|
e.content AS content,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
|
|
@ -220,8 +232,12 @@ class EpisodicNode(Node):
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
e.source_description AS source_description,
|
e.source_description AS source_description,
|
||||||
e.source AS source
|
e.source AS source
|
||||||
|
ORDER BY e.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -308,10 +324,21 @@ class EntityNode(Node):
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
|
|
@ -319,8 +346,12 @@ class EntityNode(Node):
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
|
ORDER BY n.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -407,10 +438,21 @@ class CommunityNode(Node):
|
||||||
return communities
|
return communities
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
async def get_by_group_ids(
|
||||||
|
cls,
|
||||||
|
driver: AsyncDriver,
|
||||||
|
group_ids: list[str],
|
||||||
|
limit: int = DEFAULT_PAGE_LIMIT,
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
):
|
||||||
|
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
||||||
|
"""
|
||||||
|
+ cursor_query
|
||||||
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
|
|
@ -418,8 +460,12 @@ class CommunityNode(Node):
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
|
ORDER BY n.uuid DESC
|
||||||
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
created_at=created_at,
|
||||||
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ from graphiti_core.nodes import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 10
|
||||||
DEFAULT_MIN_SCORE = 0.6
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
DEFAULT_MMR_LAMBDA = 0.5
|
DEFAULT_MMR_LAMBDA = 0.5
|
||||||
MAX_SEARCH_DEPTH = 3
|
MAX_SEARCH_DEPTH = 3
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
||||||
|
|
@ -34,11 +33,11 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def build_episodic_edges(
|
def build_episodic_edges(
|
||||||
entity_nodes: List[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
created_at: datetime,
|
created_at: datetime,
|
||||||
) -> List[EpisodicEdge]:
|
) -> list[EpisodicEdge]:
|
||||||
edges: List[EpisodicEdge] = [
|
edges: list[EpisodicEdge] = [
|
||||||
EpisodicEdge(
|
EpisodicEdge(
|
||||||
source_node_uuid=episode.uuid,
|
source_node_uuid=episode.uuid,
|
||||||
target_node_uuid=node.uuid,
|
target_node_uuid=node.uuid,
|
||||||
|
|
@ -52,11 +51,11 @@ def build_episodic_edges(
|
||||||
|
|
||||||
|
|
||||||
def build_community_edges(
|
def build_community_edges(
|
||||||
entity_nodes: List[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
community_node: CommunityNode,
|
community_node: CommunityNode,
|
||||||
created_at: datetime,
|
created_at: datetime,
|
||||||
) -> List[CommunityEdge]:
|
) -> list[CommunityEdge]:
|
||||||
edges: List[CommunityEdge] = [
|
edges: list[CommunityEdge] = [
|
||||||
CommunityEdge(
|
CommunityEdge(
|
||||||
source_node_uuid=community_node.uuid,
|
source_node_uuid=community_node.uuid,
|
||||||
target_node_uuid=node.uuid,
|
target_node_uuid=node.uuid,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
@ -31,7 +30,7 @@ async def extract_edge_dates(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
edge: EntityEdge,
|
edge: EntityEdge,
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: List[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
) -> tuple[datetime | None, datetime | None]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
context = {
|
context = {
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue