parent
83c7640d9c
commit
b728ff0f68
7 changed files with 181 additions and 55 deletions
|
|
@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
uuid: Field(default_factory=lambda: uuid1().hex)
|
uuid: str = Field(default_factory=lambda: uuid1().hex)
|
||||||
source_node: Node
|
source_node: Node
|
||||||
target_node: Node
|
target_node: Node
|
||||||
transaction_from: datetime
|
created_at: datetime
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
@ -25,14 +25,14 @@ class EpisodicEdge(Edge):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||||
MATCH (node:Semantic {uuid: $semantic_uuid})
|
MATCH (node:Entity {uuid: $entity_uuid})
|
||||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||||
SET r = {uuid: $uuid, transaction_from: $transaction_from}
|
SET r = {uuid: $uuid, created_at: $created_at}
|
||||||
RETURN r.uuid AS uuid""",
|
RETURN r.uuid AS uuid""",
|
||||||
episode_uuid=self.source_node.uuid,
|
episode_uuid=self.source_node.uuid,
|
||||||
semantic_uuid=self.target_node.uuid,
|
entity_uuid=self.target_node.uuid,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
transaction_from=self.transaction_from,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Saved edge to neo4j: {self.uuid}")
|
logger.info(f"Saved edge to neo4j: {self.uuid}")
|
||||||
|
|
@ -44,14 +44,14 @@ class EpisodicEdge(Edge):
|
||||||
# Right now we have all edge nodes as type RELATES_TO
|
# Right now we have all edge nodes as type RELATES_TO
|
||||||
|
|
||||||
|
|
||||||
class SemanticEdge(Edge):
|
class EntityEdge(Edge):
|
||||||
name: str
|
name: str
|
||||||
fact: str
|
fact: str
|
||||||
fact_embedding: list[int] = None
|
fact_embedding: list[float] = None
|
||||||
episodes: list[str] = None # list of episodes that reference these semantic edges
|
episodes: list[str] = None # list of episode ids that reference these entity edges
|
||||||
transaction_to: datetime = None # datetime of when the node was invalidated
|
expired_at: datetime = None # datetime of when the node was invalidated
|
||||||
valid_from: datetime = None # datetime of when the fact became true
|
valid_at: datetime = None # datetime of when the fact became true
|
||||||
valid_to: datetime = None # datetime of when the fact stopped being true
|
invalid_at: datetime = None # datetime of when the fact stopped being true
|
||||||
|
|
||||||
def generate_embedding(self, embedder, model="text-embedding-3-large"):
|
def generate_embedding(self, embedder, model="text-embedding-3-large"):
|
||||||
text = self.fact.replace("\n", " ")
|
text = self.fact.replace("\n", " ")
|
||||||
|
|
@ -63,12 +63,12 @@ class SemanticEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (source:Semantic {uuid: $source_uuid})
|
MATCH (source:Entity {uuid: $source_uuid})
|
||||||
MATCH (target:Semantic {uuid: $target_uuid})
|
MATCH (target:Entity {uuid: $target_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||||
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
||||||
episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
|
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
||||||
valid_from: $valid_from, valid_to: $valid_to}
|
valid_at: $valid_at, invalid_at: $invalid_at}
|
||||||
RETURN r.uuid AS uuid""",
|
RETURN r.uuid AS uuid""",
|
||||||
source_uuid=self.source_node.uuid,
|
source_uuid=self.source_node.uuid,
|
||||||
target_uuid=self.target_node.uuid,
|
target_uuid=self.target_node.uuid,
|
||||||
|
|
@ -77,10 +77,10 @@ class SemanticEdge(Edge):
|
||||||
fact=self.fact,
|
fact=self.fact,
|
||||||
fact_embedding=self.fact_embedding,
|
fact_embedding=self.fact_embedding,
|
||||||
episodes=self.episodes,
|
episodes=self.episodes,
|
||||||
transaction_from=self.transaction_from,
|
created_at=self.created_at,
|
||||||
transaction_to=self.transaction_to,
|
expired_at=self.expired_at,
|
||||||
valid_from=self.valid_from,
|
valid_at=self.valid_at,
|
||||||
valid_to=self.valid_to,
|
invalid_at=self.invalid_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple, LiteralString
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
|
|
||||||
from core.nodes import SemanticNode, EpisodicNode, Node
|
from core.nodes import EntityNode, EpisodicNode, Node
|
||||||
from core.edges import SemanticEdge, Edge
|
from core.edges import EntityEdge, Edge
|
||||||
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
|
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -31,6 +31,9 @@ class Graphiti:
|
||||||
):
|
):
|
||||||
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
||||||
self.database = "neo4j"
|
self.database = "neo4j"
|
||||||
|
|
||||||
|
self.build_indices()
|
||||||
|
|
||||||
if llm_config:
|
if llm_config:
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
else:
|
else:
|
||||||
|
|
@ -39,6 +42,40 @@ class Graphiti:
|
||||||
def close(self):
|
def close(self):
|
||||||
self.driver.close()
|
self.driver.close()
|
||||||
|
|
||||||
|
async def build_indices(self):
|
||||||
|
index_queries: list[LiteralString] = [
|
||||||
|
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
|
||||||
|
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
|
||||||
|
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
|
||||||
|
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
|
||||||
|
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
|
||||||
|
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
|
||||||
|
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
|
||||||
|
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
|
||||||
|
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
|
||||||
|
]
|
||||||
|
# Add the range indices
|
||||||
|
for query in index_queries:
|
||||||
|
await self.driver.execute_query(query)
|
||||||
|
|
||||||
|
# Add the entity indices
|
||||||
|
await self.driver.execute_query(
|
||||||
|
"""
|
||||||
|
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.driver.execute_query(
|
||||||
|
"""
|
||||||
|
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
|
||||||
|
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
self, last_n: int, sources: list[str] | None = "messages"
|
self, last_n: int, sources: list[str] | None = "messages"
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
|
|
@ -48,8 +85,9 @@ class Graphiti:
|
||||||
# Utility function, to be removed from this class
|
# Utility function, to be removed from this class
|
||||||
async def clear_data(self): ...
|
async def clear_data(self): ...
|
||||||
|
|
||||||
async def search(self, query: str, config) -> (
|
async def search(
|
||||||
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
|
self, query: str, config
|
||||||
|
) -> (list)[Tuple[EntityNode, list[EntityEdge]]]:
|
||||||
(vec_nodes, vec_edges) = similarity_search(query, embedder)
|
(vec_nodes, vec_edges) = similarity_search(query, embedder)
|
||||||
(text_nodes, text_edges) = fulltext_search(query)
|
(text_nodes, text_edges) = fulltext_search(query)
|
||||||
|
|
||||||
|
|
@ -64,8 +102,9 @@ class Graphiti:
|
||||||
|
|
||||||
return [(node, edges)], episodes
|
return [(node, edges)], episodes
|
||||||
|
|
||||||
async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
|
async def get_relevant_schema(
|
||||||
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
|
self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
|
||||||
|
) -> list[Tuple[EntityNode, list[EntityEdge]]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Call llm with the specified messages, and return the response
|
# Call llm with the specified messages, and return the response
|
||||||
|
|
@ -76,10 +115,10 @@ class Graphiti:
|
||||||
async def extract_new_edges(
|
async def extract_new_edges(
|
||||||
self,
|
self,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
new_nodes: list[SemanticNode],
|
new_nodes: list[EntityNode],
|
||||||
relevant_schema: dict[str, any],
|
relevant_schema: dict[str, any],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
) -> list[SemanticEdge]: ...
|
) -> list[EntityEdge]: ...
|
||||||
|
|
||||||
# Extract new nodes from the episode
|
# Extract new nodes from the episode
|
||||||
async def extract_new_nodes(
|
async def extract_new_nodes(
|
||||||
|
|
@ -87,14 +126,14 @@ class Graphiti:
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
relevant_schema: dict[str, any],
|
relevant_schema: dict[str, any],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
) -> list[SemanticNode]: ...
|
) -> list[EntityNode]: ...
|
||||||
|
|
||||||
# Invalidate edges that are no longer valid
|
# Invalidate edges that are no longer valid
|
||||||
async def invalidate_edges(
|
async def invalidate_edges(
|
||||||
self,
|
self,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
new_nodes: list[SemanticNode],
|
new_nodes: list[EntityNode],
|
||||||
new_edges: list[SemanticEdge],
|
new_edges: list[EntityEdge],
|
||||||
relevant_schema: dict[str, any],
|
relevant_schema: dict[str, any],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
): ...
|
): ...
|
||||||
|
|
@ -137,7 +176,7 @@ class Graphiti:
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if isinstance(node, SemanticNode):
|
if isinstance(node, EntityNode):
|
||||||
await node.update_summary(self.driver)
|
await node.update_summary(self.driver)
|
||||||
if success_callback:
|
if success_callback:
|
||||||
await success_callback(episode)
|
await success_callback(episode)
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel, ABC):
|
class Node(BaseModel, ABC):
|
||||||
uuid: Field(default_factory=lambda: uuid1().hex)
|
uuid: str = Field(default_factory=lambda: uuid1().hex)
|
||||||
name: str
|
name: str
|
||||||
labels: list[str]
|
labels: list[str]
|
||||||
transaction_from: datetime
|
created_at: datetime
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
@ -24,23 +24,23 @@ class EpisodicNode(Node):
|
||||||
source: str # source type
|
source: str # source type
|
||||||
source_description: str # description of the data source
|
source_description: str # description of the data source
|
||||||
content: str # raw episode data
|
content: str # raw episode data
|
||||||
semantic_edges: list[str] # list of semantic edges referenced in this episode
|
entity_edges: list[str] # list of entity edge ids referenced in this episode
|
||||||
valid_from: datetime = None # datetime of when the original document was created
|
valid_at: datetime = None # datetime of when the original document was created
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
|
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
|
||||||
semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
source_description=self.source_description,
|
source_description=self.source_description,
|
||||||
content=self.content,
|
content=self.content,
|
||||||
semantic_edges=self.semantic_edges,
|
entity_edges=self.entity_edges,
|
||||||
transaction_from=self.transaction_from,
|
created_at=self.created_at,
|
||||||
valid_from=self.valid_from,
|
valid_at=self.valid_at,
|
||||||
_database="neo4j",
|
_database="neo4j",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -50,7 +50,7 @@ class EpisodicNode(Node):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class SemanticNode(Node):
|
class EntityNode(Node):
|
||||||
summary: str # regional summary of surrounding edges
|
summary: str # regional summary of surrounding edges
|
||||||
|
|
||||||
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
||||||
|
|
@ -58,13 +58,13 @@ class SemanticNode(Node):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MERGE (n:Semantic {uuid: $uuid})
|
MERGE (n:Entity {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
|
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
summary=self.summary,
|
summary=self.summary,
|
||||||
transaction_from=self.transaction_from,
|
created_at=self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from core.edges import EpisodicEdge, SemanticEdge, Edge
|
from core.edges import EpisodicEdge, EntityEdge, Edge
|
||||||
from core.nodes import SemanticNode, EpisodicNode, Node
|
from core.nodes import EntityNode, EpisodicNode, Node
|
||||||
|
|
||||||
|
|
||||||
async def bfs(
|
async def bfs(
|
||||||
nodes: list[Node], edges: list[Edge], k: int
|
nodes: list[Node], edges: list[Edge], k: int
|
||||||
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
# Breadth first search over nodes and edges with desired depth
|
# Breadth first search over nodes and edges with desired depth
|
||||||
|
|
@ -14,7 +14,7 @@ async def bfs(
|
||||||
|
|
||||||
async def similarity_search(
|
async def similarity_search(
|
||||||
query: str, embedder
|
query: str, embedder
|
||||||
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
|
|
@ -22,23 +22,23 @@ async def similarity_search(
|
||||||
|
|
||||||
async def fulltext_search(
|
async def fulltext_search(
|
||||||
query: str,
|
query: str,
|
||||||
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
# fulltext search over names and summary
|
# fulltext search over names and summary
|
||||||
|
|
||||||
|
|
||||||
def build_episodic_edges(
|
def build_episodic_edges(
|
||||||
semantic_nodes: list[SemanticNode], episode: EpisodicNode
|
entity_nodes: list[EntityNode], episode: EpisodicNode
|
||||||
) -> list[EpisodicEdge]:
|
) -> list[EpisodicEdge]:
|
||||||
edges: list[EpisodicEdge] = []
|
edges: list[EpisodicEdge] = []
|
||||||
|
|
||||||
for node in semantic_nodes:
|
for node in entity_nodes:
|
||||||
edges.append(
|
edges.append(
|
||||||
EpisodicEdge(
|
EpisodicEdge(
|
||||||
source_node=episode,
|
source_node=episode,
|
||||||
target_node=node,
|
target_node=node,
|
||||||
transaction_from=episode.transaction_from,
|
created_at=episode.created_at,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
5
poetry.lock
generated
5
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohappyeyeballs"
|
name = "aiohappyeyeballs"
|
||||||
|
|
@ -2186,6 +2186,7 @@ description = "Nvidia JIT LTO Library"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3"
|
python-versions = ">=3"
|
||||||
files = [
|
files = [
|
||||||
|
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"},
|
||||||
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"},
|
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"},
|
||||||
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"},
|
{file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"},
|
||||||
]
|
]
|
||||||
|
|
@ -4937,4 +4938,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "cefc4469afc33f38b93547ee72ed623000f15faae3889d432a12ddcb33643848"
|
content-hash = "142d26cbdbf9c07019dfdb8599b70e8efb9c3842a3c95588d1f59b9c187e44ba"
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ python-dotenv = "^1.0.1"
|
||||||
pandas = "^2.2.2"
|
pandas = "^2.2.2"
|
||||||
pytest-asyncio = "^0.23.8"
|
pytest-asyncio = "^0.23.8"
|
||||||
pytest-xdist = "^3.6.1"
|
pytest-xdist = "^3.6.1"
|
||||||
|
pytest = "^8.3.2"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
|
||||||
85
tests/graphiti_int_tests.py
Normal file
85
tests/graphiti_int_tests.py
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from neo4j import AsyncGraphDatabase
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from core.edges import EpisodicEdge, EntityEdge
|
||||||
|
from core.graphiti import Graphiti
|
||||||
|
from core.nodes import EpisodicNode, EntityNode
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
pytest_plugins = ("pytest_asyncio",)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
NEO4J_URI = os.getenv("NEO4J_URI")
|
||||||
|
NEO4j_USER = os.getenv("NEO4J_USER")
|
||||||
|
NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graphiti_init():
|
||||||
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
|
||||||
|
await graphiti.build_indices()
|
||||||
|
graphiti.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graph_integration():
|
||||||
|
driver = AsyncGraphDatabase.driver(
|
||||||
|
NEO4J_URI,
|
||||||
|
auth=(NEO4j_USER, NEO4j_PASSWORD),
|
||||||
|
)
|
||||||
|
embedder = OpenAI().embeddings
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
episode = EpisodicNode(
|
||||||
|
name="test_episode",
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
source="message",
|
||||||
|
source_description="conversation message",
|
||||||
|
content="Alice likes Bob",
|
||||||
|
entity_edges=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
alice_node = EntityNode(
|
||||||
|
name="Alice",
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
summary="Alice summary",
|
||||||
|
)
|
||||||
|
|
||||||
|
bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")
|
||||||
|
|
||||||
|
episodic_edge_1 = EpisodicEdge(
|
||||||
|
source_node=episode, target_node=alice_node, created_at=now
|
||||||
|
)
|
||||||
|
|
||||||
|
episodic_edge_2 = EpisodicEdge(
|
||||||
|
source_node=episode, target_node=bob_node, created_at=now
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_edge = EntityEdge(
|
||||||
|
source_node=alice_node,
|
||||||
|
target_node=bob_node,
|
||||||
|
created_at=now,
|
||||||
|
name="likes",
|
||||||
|
fact="Alice likes Bob",
|
||||||
|
episodes=[],
|
||||||
|
expired_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
invalid_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_edge.generate_embedding(embedder)
|
||||||
|
|
||||||
|
nodes = [episode, alice_node, bob_node]
|
||||||
|
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||||
|
|
||||||
|
await asyncio.gather(*[node.save(driver) for node in nodes])
|
||||||
|
await asyncio.gather(*[edge.save(driver) for edge in edges])
|
||||||
Loading…
Add table
Reference in a new issue