reorganized test and formatting
This commit is contained in:
parent
1761ae3ea4
commit
c8af0b269c
4 changed files with 12 additions and 5 deletions
2
.github/workflows/test_memgraph.yml
vendored
2
.github/workflows/test_memgraph.yml
vendored
|
|
@ -56,4 +56,4 @@ jobs:
|
|||
GRAPH_DATABASE_URL: "bolt://localhost:7687"
|
||||
GRAPH_DATABASE_PASSWORD: "memgraph"
|
||||
GRAPH_DATABASE_USERNAME: "memgraph"
|
||||
run: poetry run python ./cognee/tests/test_memgraph.py
|
||||
run: poetry run python ./cognee/tests/unit/infrastructure/databases/vector/test_memgraph.py
|
||||
|
|
|
|||
|
|
@ -36,13 +36,17 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|||
config = get_llm_config()
|
||||
|
||||
# Use BAML's SummarizeContent function
|
||||
summary_result = await b.SummarizeContent(content, baml_options={"client_registry": config.baml_registry})
|
||||
summary_result = await b.SummarizeContent(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
|
||||
# Convert BAML result to the expected response model
|
||||
if response_model is SummarizedCode:
|
||||
# If it's asking for SummarizedCode but we got SummarizedContent,
|
||||
# we need to use SummarizeCode instead
|
||||
code_result = await b.SummarizeCode(content, baml_options={"client_registry": config.baml_registry})
|
||||
code_result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
return code_result
|
||||
else:
|
||||
# For other models, return the summary result
|
||||
|
|
@ -70,7 +74,9 @@ async def extract_code_summary(content: str):
|
|||
else:
|
||||
try:
|
||||
config = get_llm_config()
|
||||
result = await b.SummarizeCode(content, baml_options={"client_registry": config.baml_registry})
|
||||
result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
# Vector database tests module
|
||||
|
|
@ -32,7 +32,7 @@ async def main():
|
|||
dataset_name = "cs_explanations"
|
||||
|
||||
explanation_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
||||
pathlib.Path(__file__).parent, "../../../../../test_data/Natural_language_processing.txt"
|
||||
)
|
||||
await cognee.add([explanation_file_path], dataset_name)
|
||||
|
||||
Loading…
Add table
Reference in a new issue