update bulk interfae handling

This commit is contained in:
prestonrasmussen 2025-10-26 22:03:16 -04:00
parent 6338378614
commit a42437a856
5 changed files with 23 additions and 17 deletions

View file

@ -14,8 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import datetime
import asyncio import asyncio
import datetime
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -236,10 +236,10 @@ class FalkorDriver(GraphDriver):
async def health_check(self) -> None: async def health_check(self) -> None:
"""Check FalkorDB connectivity by running a simple query.""" """Check FalkorDB connectivity by running a simple query."""
try: try:
await self.execute_query("MATCH (n) RETURN 1 LIMIT 1") await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
return None return None
except Exception as e: except Exception as e:
print(f"FalkorDB health check failed: {e}") print(f'FalkorDB health check failed: {e}')
raise raise
@staticmethod @staticmethod

View file

@ -79,5 +79,5 @@ class Neo4jDriver(GraphDriver):
await self.client.verify_connectivity() await self.client.verify_connectivity()
return None return None
except Exception as e: except Exception as e:
print(f"Neo4j health check failed: {e}") print(f'Neo4j health check failed: {e}')
raise raise

View file

@ -214,12 +214,10 @@ async def add_nodes_and_edges_bulk_tx(
edges.append(edge_data) edges.append(edge_data)
if driver.graph_operations_interface: if driver.graph_operations_interface:
await driver.graph_operations_interface.episodic_node_save_bulk( await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
None, driver, tx, episodic_nodes
)
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes) await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
await driver.graph_operations_interface.episodic_edge_save_bulk( await driver.graph_operations_interface.episodic_edge_save_bulk(
None, driver, tx, episodic_edges None, driver, tx, [edge.model_dump() for edge in episodic_edges]
) )
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges) await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)

View file

@ -467,6 +467,7 @@ class Neo4jConfig(BaseModel):
password=os.environ.get('NEO4J_PASSWORD', 'password'), password=os.environ.get('NEO4J_PASSWORD', 'password'),
) )
class FalkorConfig(BaseModel): class FalkorConfig(BaseModel):
"""Configuration for FalkorDB database connection.""" """Configuration for FalkorDB database connection."""
@ -483,6 +484,7 @@ class FalkorConfig(BaseModel):
password = os.environ.get('FALKORDB_PASSWORD', '') password = os.environ.get('FALKORDB_PASSWORD', '')
return cls(host=host, port=port, user=user, password=password) return cls(host=host, port=port, user=user, password=password)
class GraphitiConfig(BaseModel): class GraphitiConfig(BaseModel):
"""Configuration for Graphiti client. """Configuration for Graphiti client.
@ -504,7 +506,9 @@ class GraphitiConfig(BaseModel):
"""Create a configuration instance from environment variables.""" """Create a configuration instance from environment variables."""
db_type = os.environ.get('DATABASE_TYPE') db_type = os.environ.get('DATABASE_TYPE')
if not db_type: if not db_type:
raise ValueError('DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")') raise ValueError(
'DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")'
)
if db_type == 'neo4j': if db_type == 'neo4j':
return cls( return cls(
llm=GraphitiLLMConfig.from_env(), llm=GraphitiLLMConfig.from_env(),
@ -622,7 +626,9 @@ async def initialize_graphiti():
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set') raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
# Validate FalkorDB configuration # Validate FalkorDB configuration
if config.database_type == 'falkordb' and (not config.falkordb.host or not config.falkordb.port): if config.database_type == 'falkordb' and (
not config.falkordb.host or not config.falkordb.port
):
raise ValueError('FALKORDB_HOST and FALKORDB_PORT must be set for FalkorDB') raise ValueError('FALKORDB_HOST and FALKORDB_PORT must be set for FalkorDB')
embedder_client = config.embedder.create_client() embedder_client = config.embedder.create_client()
@ -637,6 +643,7 @@ async def initialize_graphiti():
) )
elif config.database_type == 'falkordb': elif config.database_type == 'falkordb':
from graphiti_core.driver.falkordb_driver import FalkorDriver from graphiti_core.driver.falkordb_driver import FalkorDriver
host = config.falkordb.host if hasattr(config.falkordb, 'host') else 'localhost' host = config.falkordb.host if hasattr(config.falkordb, 'host') else 'localhost'
port = int(config.falkordb.port) if hasattr(config.falkordb, 'port') else 6379 port = int(config.falkordb.port) if hasattr(config.falkordb, 'port') else 6379
username = config.falkordb.user or None username = config.falkordb.user or None
@ -1205,10 +1212,11 @@ async def get_status() -> StatusResponse:
client = cast(Graphiti, graphiti_client) client = cast(Graphiti, graphiti_client)
# Test database connection # Test database connection
await client.driver.health_check() # type: ignore # type: ignore await client.driver.health_check() # type: ignore # type: ignore
return StatusResponse( return StatusResponse(
status='ok', message=f'Graphiti MCP server is running and connected to {config.database_type}' status='ok',
message=f'Graphiti MCP server is running and connected to {config.database_type}',
) )
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)

4
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1 version = 1
revision = 3 revision = 2
requires-python = ">=3.10, <4" requires-python = ">=3.10, <4"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.14'", "python_full_version >= '3.14'",
@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.22.0rc5" version = "0.22.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },