update bulk interfae handling
This commit is contained in:
parent
6338378614
commit
a42437a856
5 changed files with 23 additions and 17 deletions
|
|
@ -14,8 +14,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
|
@ -231,17 +231,17 @@ class FalkorDriver(GraphDriver):
|
||||||
"""
|
"""
|
||||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
async def health_check(self) -> None:
|
async def health_check(self) -> None:
|
||||||
"""Check FalkorDB connectivity by running a simple query."""
|
"""Check FalkorDB connectivity by running a simple query."""
|
||||||
try:
|
try:
|
||||||
await self.execute_query("MATCH (n) RETURN 1 LIMIT 1")
|
await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"FalkorDB health check failed: {e}")
|
print(f'FalkorDB health check failed: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_datetimes_to_strings(obj):
|
def convert_datetimes_to_strings(obj):
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
|
|
|
||||||
|
|
@ -72,12 +72,12 @@ class Neo4jDriver(GraphDriver):
|
||||||
return self.client.execute_query(
|
return self.client.execute_query(
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
)
|
)
|
||||||
|
|
||||||
async def health_check(self) -> None:
|
async def health_check(self) -> None:
|
||||||
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
||||||
try:
|
try:
|
||||||
await self.client.verify_connectivity()
|
await self.client.verify_connectivity()
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Neo4j health check failed: {e}")
|
print(f'Neo4j health check failed: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -214,12 +214,10 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
edges.append(edge_data)
|
edges.append(edge_data)
|
||||||
|
|
||||||
if driver.graph_operations_interface:
|
if driver.graph_operations_interface:
|
||||||
await driver.graph_operations_interface.episodic_node_save_bulk(
|
await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
|
||||||
None, driver, tx, episodic_nodes
|
|
||||||
)
|
|
||||||
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
|
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
|
||||||
await driver.graph_operations_interface.episodic_edge_save_bulk(
|
await driver.graph_operations_interface.episodic_edge_save_bulk(
|
||||||
None, driver, tx, episodic_edges
|
None, driver, tx, [edge.model_dump() for edge in episodic_edges]
|
||||||
)
|
)
|
||||||
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
|
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -467,6 +467,7 @@ class Neo4jConfig(BaseModel):
|
||||||
password=os.environ.get('NEO4J_PASSWORD', 'password'),
|
password=os.environ.get('NEO4J_PASSWORD', 'password'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FalkorConfig(BaseModel):
|
class FalkorConfig(BaseModel):
|
||||||
"""Configuration for FalkorDB database connection."""
|
"""Configuration for FalkorDB database connection."""
|
||||||
|
|
||||||
|
|
@ -483,6 +484,7 @@ class FalkorConfig(BaseModel):
|
||||||
password = os.environ.get('FALKORDB_PASSWORD', '')
|
password = os.environ.get('FALKORDB_PASSWORD', '')
|
||||||
return cls(host=host, port=port, user=user, password=password)
|
return cls(host=host, port=port, user=user, password=password)
|
||||||
|
|
||||||
|
|
||||||
class GraphitiConfig(BaseModel):
|
class GraphitiConfig(BaseModel):
|
||||||
"""Configuration for Graphiti client.
|
"""Configuration for Graphiti client.
|
||||||
|
|
||||||
|
|
@ -504,7 +506,9 @@ class GraphitiConfig(BaseModel):
|
||||||
"""Create a configuration instance from environment variables."""
|
"""Create a configuration instance from environment variables."""
|
||||||
db_type = os.environ.get('DATABASE_TYPE')
|
db_type = os.environ.get('DATABASE_TYPE')
|
||||||
if not db_type:
|
if not db_type:
|
||||||
raise ValueError('DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")')
|
raise ValueError(
|
||||||
|
'DATABASE_TYPE environment variable must be set (e.g., "neo4j" or "falkordb")'
|
||||||
|
)
|
||||||
if db_type == 'neo4j':
|
if db_type == 'neo4j':
|
||||||
return cls(
|
return cls(
|
||||||
llm=GraphitiLLMConfig.from_env(),
|
llm=GraphitiLLMConfig.from_env(),
|
||||||
|
|
@ -622,7 +626,9 @@ async def initialize_graphiti():
|
||||||
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
|
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
|
||||||
|
|
||||||
# Validate FalkorDB configuration
|
# Validate FalkorDB configuration
|
||||||
if config.database_type == 'falkordb' and (not config.falkordb.host or not config.falkordb.port):
|
if config.database_type == 'falkordb' and (
|
||||||
|
not config.falkordb.host or not config.falkordb.port
|
||||||
|
):
|
||||||
raise ValueError('FALKORDB_HOST and FALKORDB_PORT must be set for FalkorDB')
|
raise ValueError('FALKORDB_HOST and FALKORDB_PORT must be set for FalkorDB')
|
||||||
|
|
||||||
embedder_client = config.embedder.create_client()
|
embedder_client = config.embedder.create_client()
|
||||||
|
|
@ -637,6 +643,7 @@ async def initialize_graphiti():
|
||||||
)
|
)
|
||||||
elif config.database_type == 'falkordb':
|
elif config.database_type == 'falkordb':
|
||||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||||
|
|
||||||
host = config.falkordb.host if hasattr(config.falkordb, 'host') else 'localhost'
|
host = config.falkordb.host if hasattr(config.falkordb, 'host') else 'localhost'
|
||||||
port = int(config.falkordb.port) if hasattr(config.falkordb, 'port') else 6379
|
port = int(config.falkordb.port) if hasattr(config.falkordb, 'port') else 6379
|
||||||
username = config.falkordb.user or None
|
username = config.falkordb.user or None
|
||||||
|
|
@ -1205,10 +1212,11 @@ async def get_status() -> StatusResponse:
|
||||||
client = cast(Graphiti, graphiti_client)
|
client = cast(Graphiti, graphiti_client)
|
||||||
|
|
||||||
# Test database connection
|
# Test database connection
|
||||||
await client.driver.health_check() # type: ignore # type: ignore
|
await client.driver.health_check() # type: ignore # type: ignore
|
||||||
|
|
||||||
return StatusResponse(
|
return StatusResponse(
|
||||||
status='ok', message=f'Graphiti MCP server is running and connected to {config.database_type}'
|
status='ok',
|
||||||
|
message=f'Graphiti MCP server is running and connected to {config.database_type}',
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
|
|
|
||||||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 3
|
revision = 2
|
||||||
requires-python = ">=3.10, <4"
|
requires-python = ">=3.10, <4"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14'",
|
"python_full_version >= '3.14'",
|
||||||
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.22.0rc5"
|
version = "0.22.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue