fix: Fixes collection search limit in brute force triplet search (#814)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
34b95b687c
commit
a78fec3a91
4 changed files with 39 additions and 36 deletions
8
.github/workflows/test_memgraph.yml
vendored
8
.github/workflows/test_memgraph.yml
vendored
|
|
@ -1,9 +1,9 @@
|
||||||
name: test | memgraph
|
name: test | memgraph
|
||||||
|
|
||||||
on:
|
# on:
|
||||||
workflow_dispatch:
|
# workflow_dispatch:
|
||||||
pull_request:
|
# pull_request:
|
||||||
types: [labeled, synchronize]
|
# types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
||||||
logger = get_logger("MemgraphAdapter", level=ERROR)
|
logger = get_logger("MemgraphAdapter", level=ERROR)
|
||||||
|
|
||||||
|
|
||||||
class MemgraphAdapter(GraphDBInterface):
|
class MemgraphAdapter(GraphDBInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -536,7 +537,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
async def get_node_labels_string(self):
|
async def get_node_labels_string(self):
|
||||||
node_labels_query = f"""
|
node_labels_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WITH DISTINCT labels(n) AS labelList
|
WITH DISTINCT labels(n) AS labelList
|
||||||
UNWIND labelList AS label
|
UNWIND labelList AS label
|
||||||
|
|
@ -552,7 +553,9 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
return node_labels_str
|
return node_labels_str
|
||||||
|
|
||||||
async def get_relationship_labels_string(self):
|
async def get_relationship_labels_string(self):
|
||||||
relationship_types_query = "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
|
relationship_types_query = (
|
||||||
|
"MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
|
||||||
|
)
|
||||||
relationship_types_result = await self.query(relationship_types_query)
|
relationship_types_result = await self.query(relationship_types_query)
|
||||||
relationship_types = (
|
relationship_types = (
|
||||||
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,7 @@ async def brute_force_search(
|
||||||
async def search_in_collection(collection_name: str):
|
async def search_in_collection(collection_name: str):
|
||||||
try:
|
try:
|
||||||
return await vector_engine.search(
|
return await vector_engine.search(
|
||||||
collection_name=collection_name, query_text=query, limit=top_k
|
collection_name=collection_name, query_text=query, limit=0
|
||||||
)
|
)
|
||||||
except CollectionNotFoundError:
|
except CollectionNotFoundError:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue