Make graph label methods required in BaseGraphStorage interface
• Remove fallback compatibility code • Add get_popular_labels to ABC • Add search_labels to ABC • Enforce consistent implementation • Clean up error handling paths
This commit is contained in:
parent
3296bcb553
commit
26c9ba4cb5
2 changed files with 25 additions and 31 deletions
|
|
@ -61,16 +61,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||
List[str]: List of popular labels sorted by degree (highest first)
|
||||
"""
|
||||
try:
|
||||
# Check if the storage has the get_popular_labels method
|
||||
if hasattr(rag.chunk_entity_relation_graph, "get_popular_labels"):
|
||||
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
|
||||
else:
|
||||
# Fallback to get_graph_labels for compatibility
|
||||
logger.warning(
|
||||
"Storage doesn't support get_popular_labels, falling back to get_graph_labels"
|
||||
)
|
||||
all_labels = await rag.get_graph_labels()
|
||||
return all_labels[:limit]
|
||||
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting popular labels: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
@ -96,27 +87,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||
List[str]: List of matching labels sorted by relevance
|
||||
"""
|
||||
try:
|
||||
# Check if the storage has the search_labels method
|
||||
if hasattr(rag.chunk_entity_relation_graph, "search_labels"):
|
||||
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
|
||||
else:
|
||||
# Fallback to client-side filtering for compatibility
|
||||
logger.warning(
|
||||
"Storage doesn't support search_labels, falling back to client-side filtering"
|
||||
)
|
||||
all_labels = await rag.get_graph_labels()
|
||||
query_lower = q.lower().strip()
|
||||
|
||||
if not query_lower:
|
||||
return []
|
||||
|
||||
# Simple client-side filtering
|
||||
matches = []
|
||||
for label in all_labels:
|
||||
if query_lower in label.lower():
|
||||
matches.append(label)
|
||||
|
||||
return matches[:limit]
|
||||
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching labels with query '{q}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
|
|||
|
|
@ -671,6 +671,29 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||
A list of all edges, where each edge is a dictionary of its properties
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_popular_labels(self, limit: int = 300) -> list[str]:
|
||||
"""Get popular labels by node degree (most connected entities)
|
||||
|
||||
Args:
|
||||
limit: Maximum number of labels to return
|
||||
|
||||
Returns:
|
||||
List of labels sorted by degree (highest first)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||
"""Search labels with fuzzy matching
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching labels sorted by relevance
|
||||
"""
|
||||
|
||||
|
||||
class DocStatus(str, Enum):
|
||||
"""Document processing status"""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue