Pagination for get by group_id (#218)

* add pagination to subgraphs

* update pagination

* update LiteralString import

* cleanup

* cleanup

* update embedding dims
This commit is contained in:
Preston Rasmussen 2024-12-02 11:17:37 -05:00 committed by GitHub
parent 397291de4b
commit 0fbe5c0704
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 123 additions and 35 deletions

View file

@ -1,7 +1,6 @@
import os import os
import re import re
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List
from pydantic import BaseModel from pydantic import BaseModel
@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
return timedelta() # Return 0 duration if parsing fails return timedelta() # Return 0 duration if parsing fails
def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]: def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
with open(file_path) as file: with open(file_path) as file:
content = file.read() content = file.read()

View file

@ -15,7 +15,6 @@ limitations under the License.
""" """
import asyncio import asyncio
from typing import List, Tuple
from sentence_transformers import CrossEncoder from sentence_transformers import CrossEncoder
@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
def __init__(self): def __init__(self):
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3') self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]: async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
if not passages: if not passages:
return [] return []

View file

@ -15,7 +15,6 @@ limitations under the License.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple
class CrossEncoderClient(ABC): class CrossEncoderClient(ABC):
@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
""" """
@abstractmethod @abstractmethod
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]: async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
""" """
Rank the given passages based on their relevance to the query. Rank the given passages based on their relevance to the query.
Args: Args:
query (str): The query string. query (str): The query string.
passages (List[str]): A list of passages to rank. passages (list[str]): A list of passages to rank.
Returns: Returns:
List[Tuple[str, float]]: A list of tuples containing the passage and its score, list[tuple[str, float]]: A list of tuples containing the passage and its score,
sorted in descending order of relevance. sorted in descending order of relevance.
""" """
pass pass

View file

@ -23,10 +23,11 @@ from uuid import uuid4
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
from graphiti_core.models.edges.edge_db_queries import ( from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE, COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE, ENTITY_EDGE_SAVE,
@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
async def delete(self, driver: AsyncDriver): async def delete(self, driver: AsyncDriver):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MATCH (n)-[e {uuid: $uuid}]->(m) MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e DELETE e
""", """,
uuid=self.uuid, uuid=self.uuid,
@ -137,19 +138,34 @@ class EpisodicEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN RETURN
e.uuid As uuid, e.uuid As uuid,
e.group_id AS group_id, e.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
e.created_at AS created_at e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )
@ -274,11 +290,22 @@ class EntityEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN RETURN
e.uuid AS uuid, e.uuid AS uuid,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
@ -292,8 +319,12 @@ class EntityEdge(Edge):
e.expired_at AS expired_at, e.expired_at AS expired_at,
e.valid_at AS valid_at, e.valid_at AS valid_at,
e.invalid_at AS invalid_at e.invalid_at AS invalid_at
ORDER BY e.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )
@ -365,19 +396,34 @@ class CommunityEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN RETURN
e.uuid As uuid, e.uuid As uuid,
e.group_id AS group_id, e.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
e.created_at AS created_at e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )

View file

@ -15,7 +15,7 @@ limitations under the License.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterable, List, Literal from collections.abc import Iterable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
class EmbedderConfig(BaseModel): class EmbedderConfig(BaseModel):
embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True) embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
class EmbedderClient(ABC): class EmbedderClient(ABC):
@abstractmethod @abstractmethod
async def create( async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]: ) -> list[float]:
pass pass

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from typing import Iterable, List from collections.abc import Iterable
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types import EmbeddingModel from openai.types import EmbeddingModel
@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
async def create( async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]: ) -> list[float]:
result = await self.client.embeddings.create( result = await self.client.embeddings.create(
input=input_data, model=self.config.embedding_model input=input_data, model=self.config.embedding_model

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from typing import Iterable, List from collections.abc import Iterable
import voyageai # type: ignore import voyageai # type: ignore
from pydantic import Field from pydantic import Field
@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
self.client = voyageai.AsyncClient(api_key=config.api_key) self.client = voyageai.AsyncClient(api_key=config.api_key)
async def create( async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]] self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]: ) -> list[float]:
if isinstance(input_data, str): if isinstance(input_data, str):
input_list = [input_data] input_list = [input_data]
elif isinstance(input_data, List): elif isinstance(input_data, list):
input_list = [str(i) for i in input_data if i] input_list = [str(i) for i in input_data if i]
else: else:
input_list = [str(i) for i in input_data if i is not None] input_list = [str(i) for i in input_data if i is not None]

View file

@ -26,6 +26,7 @@ load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
MAX_REFLEXION_ITERATIONS = 2 MAX_REFLEXION_ITERATIONS = 2
DEFAULT_PAGE_LIMIT = 20
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:

View file

@ -24,10 +24,11 @@ from uuid import uuid4
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
from graphiti_core.models.nodes.node_db_queries import ( from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE, COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE, ENTITY_NODE_SAVE,
@ -207,10 +208,21 @@ class EpisodicNode(Node):
return episodes return episodes
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (e:Episodic) WHERE e.group_id IN $group_ids MATCH (e:Episodic) WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN DISTINCT RETURN DISTINCT
e.content AS content, e.content AS content,
e.created_at AS created_at, e.created_at AS created_at,
@ -220,8 +232,12 @@ class EpisodicNode(Node):
e.group_id AS group_id, e.group_id AS group_id,
e.source_description AS source_description, e.source_description AS source_description,
e.source AS source e.source AS source
ORDER BY e.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )
@ -308,10 +324,21 @@ class EntityNode(Node):
return nodes return nodes
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity) WHERE n.group_id IN $group_ids MATCH (n:Entity) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.name AS name, n.name AS name,
@ -319,8 +346,12 @@ class EntityNode(Node):
n.group_id AS group_id, n.group_id AS group_id,
n.created_at AS created_at, n.created_at AS created_at,
n.summary AS summary n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )
@ -407,10 +438,21 @@ class CommunityNode(Node):
return communities return communities
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]): async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Community) WHERE n.group_id IN $group_ids MATCH (n:Community) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.name AS name, n.name AS name,
@ -418,8 +460,12 @@ class CommunityNode(Node):
n.group_id AS group_id, n.group_id AS group_id,
n.created_at AS created_at, n.created_at AS created_at,
n.summary AS summary n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""", """,
group_ids=group_ids, group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
routing_='r', routing_='r',
) )

View file

@ -40,7 +40,7 @@ from graphiti_core.nodes import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3 RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6 DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5 DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3 MAX_SEARCH_DEPTH = 3

View file

@ -18,7 +18,6 @@ import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from time import time from time import time
from typing import List
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
@ -34,11 +33,11 @@ logger = logging.getLogger(__name__)
def build_episodic_edges( def build_episodic_edges(
entity_nodes: List[EntityNode], entity_nodes: list[EntityNode],
episode: EpisodicNode, episode: EpisodicNode,
created_at: datetime, created_at: datetime,
) -> List[EpisodicEdge]: ) -> list[EpisodicEdge]:
edges: List[EpisodicEdge] = [ edges: list[EpisodicEdge] = [
EpisodicEdge( EpisodicEdge(
source_node_uuid=episode.uuid, source_node_uuid=episode.uuid,
target_node_uuid=node.uuid, target_node_uuid=node.uuid,
@ -52,11 +51,11 @@ def build_episodic_edges(
def build_community_edges( def build_community_edges(
entity_nodes: List[EntityNode], entity_nodes: list[EntityNode],
community_node: CommunityNode, community_node: CommunityNode,
created_at: datetime, created_at: datetime,
) -> List[CommunityEdge]: ) -> list[CommunityEdge]:
edges: List[CommunityEdge] = [ edges: list[CommunityEdge] = [
CommunityEdge( CommunityEdge(
source_node_uuid=community_node.uuid, source_node_uuid=community_node.uuid,
target_node_uuid=node.uuid, target_node_uuid=node.uuid,

View file

@ -17,7 +17,6 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from typing import List
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
@ -31,7 +30,7 @@ async def extract_edge_dates(
llm_client: LLMClient, llm_client: LLMClient,
edge: EntityEdge, edge: EntityEdge,
current_episode: EpisodicNode, current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode], previous_episodes: list[EpisodicNode],
) -> tuple[datetime | None, datetime | None]: ) -> tuple[datetime | None, datetime | None]:
context = { context = {
'edge_fact': edge.fact, 'edge_fact': edge.fact,