Make graph label methods required in BaseGraphStorage interface

• Remove fallback compatibility code
• Add get_popular_labels to ABC
• Add search_labels to ABC
• Enforce consistent implementation
• Clean up error handling paths
This commit is contained in:
yangdx 2025-09-20 12:40:36 +08:00
parent 3296bcb553
commit 26c9ba4cb5
2 changed files with 25 additions and 31 deletions

View file

@ -61,16 +61,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of popular labels sorted by degree (highest first)
"""
try:
# Check if the storage has the get_popular_labels method
if hasattr(rag.chunk_entity_relation_graph, "get_popular_labels"):
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
else:
# Fallback to get_graph_labels for compatibility
logger.warning(
"Storage doesn't support get_popular_labels, falling back to get_graph_labels"
)
all_labels = await rag.get_graph_labels()
return all_labels[:limit]
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
except Exception as e:
logger.error(f"Error getting popular labels: {str(e)}")
logger.error(traceback.format_exc())
@ -96,27 +87,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
List[str]: List of matching labels sorted by relevance
"""
try:
# Check if the storage has the search_labels method
if hasattr(rag.chunk_entity_relation_graph, "search_labels"):
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
else:
# Fallback to client-side filtering for compatibility
logger.warning(
"Storage doesn't support search_labels, falling back to client-side filtering"
)
all_labels = await rag.get_graph_labels()
query_lower = q.lower().strip()
if not query_lower:
return []
# Simple client-side filtering
matches = []
for label in all_labels:
if query_lower in label.lower():
matches.append(label)
return matches[:limit]
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
except Exception as e:
logger.error(f"Error searching labels with query '{q}': {str(e)}")
logger.error(traceback.format_exc())

View file

@ -671,6 +671,29 @@ class BaseGraphStorage(StorageNameSpace, ABC):
A list of all edges, where each edge is a dictionary of its properties
"""
@abstractmethod
async def get_popular_labels(self, limit: int = 300) -> list[str]:
"""Get popular labels by node degree (most connected entities)
Args:
limit: Maximum number of labels to return
Returns:
List of labels sorted by degree (highest first)
"""
@abstractmethod
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
"""Search labels with fuzzy matching
Args:
query: Search query string
limit: Maximum number of results to return
Returns:
List of matching labels sorted by relevance
"""
class DocStatus(str, Enum):
"""Document processing status"""