Format entire codebase with ruff and add type hints across all modules: - Apply ruff formatting to all Python files (121 files, 17K insertions) - Add type hints to function signatures throughout lightrag core and API - Update test suite with improved type annotations and docstrings - Add pyrightconfig.json for static type checking configuration - Create prompt_optimized.py and test_extraction_prompt_ab.py test files - Update ruff.toml and .gitignore for improved linting configuration - Standardize code style across examples, reproduce scripts, and utilities
183 lines
6 KiB
Python
183 lines
6 KiB
Python
import json
|
|
import os
|
|
import xml.etree.ElementTree as ET
|
|
|
|
from neo4j import GraphDatabase
|
|
|
|
# Directory holding the GraphML export; the generated JSON is written here too.
WORKING_DIR = './dickens'

# Number of records sent to Neo4j per query execution.
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100

# Neo4j connection settings (placeholder password; replace before running).
NEO4J_URI = 'bolt://localhost:7687'
NEO4J_USERNAME = 'neo4j'
NEO4J_PASSWORD = 'your_password'
|
|
|
|
|
|
def _data_text(element, key, namespace, default=''):
    """Return the text of ``element``'s ``<data key="KEY">`` child.

    ``default`` is returned both when the child is missing and when it is
    present but empty (``<data key="d1"/>`` has ``text is None``), so callers
    never have to deal with ``None`` text.
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return default
    return child.text


def xml_to_json(xml_file):
    """Parse a GraphML document into a dict of node and edge records.

    Args:
        xml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        ``{'nodes': [...], 'edges': [...]}`` on success. Each node has the keys
        ``id``, ``entity_type`` (data key ``d1``), ``description`` (``d2``) and
        ``source_id`` (``d3``). Each edge has ``source``, ``target``,
        ``weight`` (``d5``, float, 0.0 when absent or empty), ``description``
        (``d6``), ``keywords`` (``d9``) and ``source_id`` (``d8``). Missing or
        empty text fields become ``''``. Returns ``None`` if the file cannot
        be parsed or any other error occurs; the error is printed.
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Print the root element's tag and attributes to confirm the file has been correctly loaded
        print(f'Root element: {root.tag}')
        print(f'Root attributes: {root.attrib}')

        # GraphML elements live in this default namespace; the empty prefix
        # lets the un-prefixed paths below match them.
        namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}

        data = {'nodes': [], 'edges': []}

        for node in root.findall('.//node', namespace):
            data['nodes'].append(
                {
                    # Ids and entity types may be wrapped in literal double quotes.
                    'id': node.get('id').strip('"'),
                    'entity_type': _data_text(node, 'd1', namespace).strip('"'),
                    'description': _data_text(node, 'd2', namespace),
                    'source_id': _data_text(node, 'd3', namespace),
                }
            )

        for edge in root.findall('.//edge', namespace):
            # Use None as the sentinel so an absent/empty weight maps to 0.0
            # instead of being passed to float().
            weight_text = _data_text(edge, 'd5', namespace, default=None)
            data['edges'].append(
                {
                    'source': edge.get('source').strip('"'),
                    'target': edge.get('target').strip('"'),
                    'weight': float(weight_text) if weight_text is not None else 0.0,
                    'description': _data_text(edge, 'd6', namespace),
                    'keywords': _data_text(edge, 'd9', namespace),
                    'source_id': _data_text(edge, 'd8', namespace),
                }
            )

        # Print the number of nodes and edges found
        print(f'Found {len(data["nodes"])} nodes and {len(data["edges"])} edges')

        return data
    except ET.ParseError as e:
        print(f'Error parsing XML file: {e}')
        return None
    except Exception as e:
        # Best-effort script: report anything else (I/O errors, malformed
        # attributes, non-numeric weights) and signal failure with None.
        print(f'An error occurred: {e}')
        return None
|
|
|
|
|
|
def convert_xml_to_json(xml_path, output_path):
    """Convert the XML file at ``xml_path`` to JSON saved at ``output_path``.

    Returns the converted data, or ``None`` (after printing a message) when
    the input file does not exist or the conversion produced no data.
    """
    if not os.path.exists(xml_path):
        print(f'Error: File not found - {xml_path}')
        return None

    graph = xml_to_json(xml_path)
    if not graph:
        print('Failed to create JSON data')
        return None

    with open(output_path, 'w', encoding='utf-8') as out:
        json.dump(graph, out, ensure_ascii=False, indent=2)
    print(f'JSON file created: {output_path}')
    return graph
|
|
|
|
|
|
def process_in_batches(tx, query, data, batch_size):
    """Run ``query`` once for every ``batch_size``-sized slice of ``data``.

    Args:
        tx: Object exposing ``run(query, parameters)``.
        query: Query text. The batch is bound to the ``$nodes`` parameter when
            the query references it, otherwise to ``$edges``.
        data: Sequence of records to send; an empty sequence runs nothing.
        batch_size: Maximum number of records per ``run`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently write nothing.
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    # Match the actual parameter reference rather than the bare word "nodes",
    # which may appear in an edge query's text without being a parameter.
    param_name = 'nodes' if '$nodes' in query else 'edges'

    for start in range(0, len(data), batch_size):
        tx.run(query, {param_name: data[start : start + batch_size]})
|
|
|
|
|
|
def main():
    """Convert the GraphML export in ``WORKING_DIR`` to JSON and load it into Neo4j.

    Returns early if the conversion fails. Errors raised while talking to
    Neo4j are printed, and the driver is always closed. The Cypher below calls
    ``apoc.*`` procedures.
    """
    graphml_path = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
    json_path = os.path.join(WORKING_DIR, 'graph_data.json')

    graph = convert_xml_to_json(graphml_path, json_path)
    if graph is None:
        return

    # Merge each node by id, copy its properties, then replace the Entity
    # label with a label named after the node id.
    create_nodes_query = """
    UNWIND $nodes AS node
    MERGE (e:Entity {id: node.id})
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
    RETURN count(*)
    """

    # Relationship type is a known keyword found in edge.keywords, otherwise
    # the first comma-separated keyword with double quotes removed.
    create_edges_query = """
    UNWIND $edges AS edge
    MATCH (source {id: edge.source})
    MATCH (target {id: edge.target})
    WITH source, target, edge,
         CASE
            WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
            WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
            WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
            WHEN edge.keywords CONTAINS 'located' THEN 'located'
            WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
           ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
         END AS relType
    CALL apoc.create.relationship(source, relType, {
      weight: edge.weight,
      description: edge.description,
      keywords: edge.keywords,
      source_id: edge.source_id
    }, target) YIELD rel
    RETURN count(*)
    """

    # Final pass over every node: displayName mirrors id and the labels are
    # replaced by the node's entity_type.
    set_displayname_and_labels_query = """
    MATCH (n)
    SET n.displayName = n.id
    WITH n
    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
    RETURN count(*)
    """

    # Batched writes, in execution order: nodes first so the edge query can
    # MATCH its endpoints.
    write_steps = (
        (create_nodes_query, graph.get('nodes', []), BATCH_SIZE_NODES),
        (create_edges_query, graph.get('edges', []), BATCH_SIZE_EDGES),
    )

    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
    try:
        with driver.session() as session:
            for cypher, rows, batch_size in write_steps:
                session.execute_write(process_in_batches, cypher, rows, batch_size)

            session.run(set_displayname_and_labels_query)
    except Exception as e:
        print(f'Error occurred: {e}')
    finally:
        driver.close()
|
|
|
|
|
|
# Run the import only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|