Fix linter issues
parent 5b115594b7
commit 76a0aa7e8b
24 changed files with 57 additions and 78 deletions
@@ -6,6 +6,7 @@ from sqlalchemy.engine import Connection
 from sqlalchemy.ext.asyncio import async_engine_from_config
 from cognee.infrastructure.databases.relational import Base
 from alembic import context
+from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config

 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.

@@ -83,8 +84,6 @@ def run_migrations_online() -> None:
     asyncio.run(run_async_migrations())


-from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
-
 db_engine = get_relational_engine()

 if db_engine.engine.dialect.name == "sqlite":

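Note (illustration, not part of the diff): the two hunks above move a module-level import from the middle of the Alembic env script up to the import block, which is what linters report as E402 ("module level import not at top of file"). A minimal sketch of the rule with made-up names:

    import os
    import json

    settings = json.loads(os.environ.get("APP_SETTINGS", "{}"))  # executable module-level code
    # import logging  <- an import placed here would be flagged as E402
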
@@ -74,7 +74,6 @@ app.add_middleware(
 )


-
 @app.exception_handler(RequestValidationError)
 async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
     if request.url.path == "/api/v1/auth/login":

@@ -46,7 +46,7 @@ async def cognify(
         # If no datasets are provided, cognify all existing datasets.
         datasets = existing_datasets

-    if type(datasets[0]) == str:
+    if isinstance(datasets[0], str):
         datasets = await get_datasets_by_name(datasets, user.id)

     existing_datasets_map = {

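Note (illustration, not part of the diff): `type(x) == str` is the comparison Ruff/flake8 flag as E721; `isinstance` is preferred because it also accepts subclasses. A small example with stand-in classes:

    class Animal: ...
    class Dog(Animal): ...

    d = Dog()
    print(type(d) == Animal)      # False: exact-type comparison ignores inheritance
    print(isinstance(d, Animal))  # True: isinstance respects subclasses
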
@@ -107,7 +107,8 @@ def get_datasets_router() -> APIRouter:
                 status_code=200,
                 content=str(graph_url),
             )
-        except:
+        except Exception as error:
+            print(error)
             return JSONResponse(
                 status_code=409,
                 content="Graphistry credentials are not set. Please set them in your .env file.",

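Note (illustration, not part of the diff): the bare `except:` removed above is the E722 pattern; it also swallows `KeyboardInterrupt` and `SystemExit`. A minimal sketch of the narrower form used throughout this commit:

    def risky_operation():
        raise ValueError("credentials are not set")   # stand-in for the real call

    try:
        risky_operation()
    except Exception as error:   # narrower than a bare `except:`
        print(error)             # the failure stays visible instead of being silently swallowed
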
@@ -64,7 +64,7 @@ async def search(search_type: str, params: Dict[str, Any], user: User = None) ->

     for search_result in search_results:
         document_id = search_result["document_id"] if "document_id" in search_result else None
-        document_id = UUID(document_id) if type(document_id) == str else document_id
+        document_id = UUID(document_id) if isinstance(document_id, str) else document_id

         if document_id is None or document_id in own_document_ids:
             filtered_search_results.append(search_result)

@@ -49,7 +49,7 @@ async def search(

     for search_result in search_results:
         document_id = search_result["document_id"] if "document_id" in search_result else None
-        document_id = UUID(document_id) if type(document_id) == str else document_id
+        document_id = UUID(document_id) if isinstance(document_id, str) else document_id

         if document_id is None or document_id in own_document_ids:
             filtered_search_results.append(search_result)

@@ -266,7 +266,8 @@ class NetworkXAdapter(GraphDBInterface):
         for node in graph_data["nodes"]:
             try:
                 node["id"] = UUID(node["id"])
-            except:
+            except Exception as e:
+                print(e)
                 pass
             if "updated_at" in node:
                 node["updated_at"] = datetime.strptime(

@@ -282,7 +283,8 @@ class NetworkXAdapter(GraphDBInterface):
                 edge["target"] = target_id
                 edge["source_node_id"] = source_id
                 edge["target_node_id"] = target_id
-            except:
+            except Exception as e:
+                print(e)
                 pass

             if "updated_at" in edge:

@@ -4,6 +4,7 @@ import asyncio
 import json
 from textwrap import dedent
 from uuid import UUID
+from webbrowser import Error

 from falkordb import FalkorDB

@@ -167,7 +168,8 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
                     for index in indices.result_set
                 ]
             )
-        except:
+        except Error as e:
+            print(e)
             return False

     async def index_data_points(

@@ -32,7 +32,7 @@ class MilvusAdapter(VectorDBInterface):

         self.embedding_engine = embedding_engine

-    def get_milvus_client(self) -> "MilvusClient":
+    def get_milvus_client(self):
         from pymilvus import MilvusClient

         if self.api_key:

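Note (assumption, not part of the diff): the `-> "MilvusClient"` annotation was dropped above, presumably because the name only exists inside the method's lazy import. If the annotation were to be kept, the usual alternative is a `typing.TYPE_CHECKING` guard; a sketch, assuming `pymilvus` is installed and the names are illustrative:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
        from pymilvus import MilvusClient

    def get_milvus_client(url: str, api_key: str) -> "MilvusClient":
        from pymilvus import MilvusClient  # lazy runtime import, as in the adapter
        return MilvusClient(uri=url, token=api_key)
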
@@ -1,8 +1,8 @@
-class PineconeVectorDB(VectorDB):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.init_pinecone(self.index_name)
-
-    def init_pinecone(self, index_name):
-        # Pinecone initialization logic
-        pass
+# class PineconeVectorDB(VectorDB):
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.init_pinecone(self.index_name)
+#
+#     def init_pinecone(self, index_name):
+#         # Pinecone initialization logic
+#         pass

@@ -43,7 +43,6 @@ class OpenAIAdapter(LLMInterface):
         self.api_version = api_version
         self.streaming = streaming

-
     @observe(as_type="generation")
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]

@@ -37,7 +37,7 @@ async def add_model_class_to_graph(

         if hasattr(field_type, "model_fields"):  # Check if field type is a Pydantic model
             await add_model_class_to_graph(field_type, graph, model_name, field_name)
-        elif get_origin(field.annotation) == list:
+        elif isinstance(get_origin(field.annotation), list):
             list_types = get_args(field_type)
             for item_type in list_types:
                 await add_model_class_to_graph(item_type, graph, model_name, field_name)

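Note (not part of the diff): `typing.get_origin()` returns the origin class itself (e.g. `list`), not a list instance, so the lint-clean spelling of this check is usually `is`; whether `isinstance(get_origin(...), list)` preserves the original behaviour is worth verifying. For reference:

    from typing import List, get_origin

    print(get_origin(List[int]))                     # <class 'list'>
    print(get_origin(List[int]) is list)             # True
    print(isinstance(get_origin(List[int]), list))   # False: the origin is a class, not an instance
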
@@ -2,6 +2,7 @@ import json
 from uuid import UUID
 from datetime import datetime
 from pydantic_core import PydanticUndefined
+from pydantic import create_model

 from cognee.infrastructure.engine import DataPoint

@@ -16,9 +17,6 @@ class JSONEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)


-from pydantic import create_model
-
-
 def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list = []):
     fields = {
         name: (field.annotation, field.default if field.default is not None else PydanticUndefined)

@@ -17,6 +17,4 @@ async def get_user_db(session: AsyncSession = Depends(get_async_session)):
     yield SQLAlchemyUserDatabase(session, User)


-
-
 get_user_db_context = asynccontextmanager(get_user_db)

@@ -4,7 +4,7 @@ from typing import Optional
 from fastapi import Depends, Request
 from fastapi_users import BaseUserManager, UUIDIDMixin, models
 from fastapi_users.db import SQLAlchemyUserDatabase
-
+from contextlib import asynccontextmanager
 from .get_user_db import get_user_db
 from .models import User
 from .methods import get_user

@@ -50,6 +50,4 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
     yield UserManager(user_db)


-from contextlib import asynccontextmanager
-
 get_user_manager_context = asynccontextmanager(get_user_manager)

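Note (illustration, not part of the diff): the `asynccontextmanager` import consolidated above is what turns the dependency generators into `get_user_db_context` and `get_user_manager_context`, which can then be used outside FastAPI's dependency injection. A self-contained sketch of the same pattern:

    import asyncio
    from contextlib import asynccontextmanager

    async def get_resource():           # an async generator dependency, like get_user_db
        resource = "db-session"         # stand-in for a real session
        try:
            yield resource
        finally:
            pass                        # cleanup would go here

    get_resource_context = asynccontextmanager(get_resource)

    async def main():
        async with get_resource_context() as resource:
            print(resource)

    asyncio.run(main())
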
@@ -5,6 +5,7 @@ from fastapi_users.db import SQLAlchemyBaseUserTableUUID
 from .Principal import Principal
 from .UserGroup import UserGroup
 from .Group import Group
+from fastapi_users import schemas


 class User(SQLAlchemyBaseUserTableUUID, Principal):

@@ -23,7 +24,6 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):


 # Keep these schemas in sync with User model
-from fastapi_users import schemas


 class UserRead(schemas.BaseUser[uuid_UUID]):

@@ -274,8 +274,6 @@ def extract_pos_tags(sentence):
     return pos_tags


-
-
 logging.basicConfig(level=logging.INFO)


@@ -6,13 +6,13 @@ from typing import Union

 def get_data_from_llama_index(data_point: Union[Document, ImageDocument], dataset_name: str) -> str:
     # Specific type checking is used to ensure it's not a child class from Document
-    if type(data_point) == Document:
+    if isinstance(data_point, Document) and type(data_point) is Document:
         file_path = data_point.metadata.get("file_path")
         if file_path is None:
             file_path = save_data_to_file(data_point.text)
             return file_path
         return file_path
-    elif type(data_point) == ImageDocument:
+    elif isinstance(data_point, ImageDocument) and type(data_point) is ImageDocument:
         if data_point.image_path is None:
             file_path = save_data_to_file(data_point.text)
             return file_path

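Note (illustration, not part of the diff): the hunk above keeps the exact-class requirement stated in the comment while satisfying the linter: `isinstance` covers the general check and `type(x) is Cls` pins the exact class, so an `ImageDocument` does not fall into the plain `Document` branch. With stand-in classes:

    class Document: ...
    class ImageDocument(Document): ...

    img = ImageDocument()
    print(isinstance(img, Document))   # True, it is a subclass of Document
    print(type(img) is Document)       # False, the exact-class check excludes subclasses
    print(type(img) is ImageDocument)  # True
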
@@ -35,6 +35,4 @@ def test_chunk_by_word_splits(input_text):
     chunks = np.array(list(chunk_by_word(input_text)))
     space_test = np.array([" " not in chunk[0].strip() for chunk in chunks])

-    assert np.all(
-        space_test
-    ), f"These chunks contain spaces within them: {chunks[space_test == False]}"
+    assert np.all(space_test), f"These chunks contain spaces within them: {chunks[~space_test]}"

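Note (illustration, not part of the diff): besides collapsing the assert onto one line, the hunk above replaces `chunks[space_test == False]` with `chunks[~space_test]`; `== False` is the E712 pattern and `~` is the idiomatic element-wise negation of a NumPy boolean mask:

    import numpy as np

    chunks = np.array(["alpha", "beta gamma", "delta"])
    space_test = np.array([" " not in chunk for chunk in chunks])
    print(chunks[~space_test])   # elements that failed the check: ['beta gamma']
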
@@ -102,7 +102,6 @@ def test_prepare_nodes():
     assert len(nodes_df) == 1


-
 def test_create_cognee_style_network_with_logo():
     import networkx as nx
     from unittest.mock import patch

@@ -117,29 +116,23 @@ def test_create_cognee_style_network_with_logo():
     # Convert the graph to a tuple format for serialization
     graph_tuple = graph_to_tuple(graph)

-    print(graph_tuple)
-
-    # Define the output filename
-    output_filename = "test_network.html"
-
-    with patch("bokeh.plotting.from_networkx") as mock_from_networkx:
-        original_open = open
-
-        def mock_open_read_side_effect(*args, **kwargs):
-            if "cognee-logo.png" in args[0]:
-                return BytesIO(b"mock_png_data")
-            return original_open(*args, **kwargs)
-
-        with patch("builtins.open", side_effect=mock_open_read_side_effect):
-            result = create_cognee_style_network_with_logo(
-                graph_tuple,
-                title="Test Network",
-                node_attribute="group",
-                layout_func=nx.spring_layout,
-                layout_scale=3.0,
-                logo_alpha=0.5,
-            )
-
-        assert result is not None
-        assert isinstance(result, str)
-        assert len(result) > 0
+    original_open = open
+
+    def mock_open_read_side_effect(*args, **kwargs):
+        if "cognee-logo.png" in args[0]:
+            return BytesIO(b"mock_png_data")
+        return original_open(*args, **kwargs)
+
+    with patch("builtins.open", side_effect=mock_open_read_side_effect):
+        result = create_cognee_style_network_with_logo(
+            graph_tuple,
+            title="Test Network",
+            node_attribute="group",
+            layout_func=nx.spring_layout,
+            layout_scale=3.0,
+            logo_alpha=0.5,
+        )
+
+    assert result is not None
+    assert isinstance(result, str)
+    assert len(result) > 0

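Note (illustration, not part of the diff): the test above patches `builtins.open` with a `side_effect` so that only the logo read returns fake bytes while every other call falls through to the real `open`. The same pattern in isolation, with an illustrative file name:

    from io import BytesIO
    from unittest.mock import patch

    original_open = open

    def mock_open_read_side_effect(*args, **kwargs):
        if "cognee-logo.png" in args[0]:           # intercept just the logo read
            return BytesIO(b"mock_png_data")
        return original_open(*args, **kwargs)      # everything else uses the real open

    with patch("builtins.open", side_effect=mock_open_read_side_effect):
        print(open("cognee-logo.png").read())      # b'mock_png_data'
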
@@ -3,6 +3,10 @@ from deepeval.synthesizer import Synthesizer
 import dotenv
 from deepeval.test_case import LLMTestCase

+# import pytest
+# from deepeval import assert_test
+from deepeval.metrics import AnswerRelevancyMetric
+
 dotenv.load_dotenv()

 # synthesizer = Synthesizer()

@@ -33,11 +37,6 @@ print(dataset.goldens)
 print(dataset)


-# import pytest
-# from deepeval import assert_test
-from deepeval.metrics import AnswerRelevancyMetric
-
-
 answer_relevancy_metric = AnswerRelevancyMetric(threshold=0.5)

 # from deepeval import evaluate

@@ -44,8 +44,6 @@ print(dataset.goldens)
 print(dataset)


-
-
 class AnswerModel(BaseModel):
     response: str

@@ -76,10 +74,7 @@ async def run_cognify_base_rag():
     await add("data://test_datasets", "initial_test")

     graph = await cognify("initial_test")

-    pass
-
-
     return graph

 async def cognify_search_base_rag(content: str, context: str):

@@ -28,13 +28,10 @@ def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, An
         # Start memory tracking
         tracemalloc.start()

-
         # Measure execution time and CPU usage
         start_time = time.perf_counter()
         start_cpu_time = process.cpu_times()

-
-
         end_cpu_time = process.cpu_times()
         end_time = time.perf_counter()

@@ -45,8 +42,6 @@ def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, An
     )
     current, peak = tracemalloc.get_traced_memory()

-
-
     # Store results
     execution_times.append(execution_time)
     peak_memory_usages.append(peak / 1024 / 1024)  # Convert to MB

@@ -119,7 +119,12 @@ line-length = 100
 exclude = [
     "migrations/", # Ignore migrations directory
     "notebooks/", # Ignore notebook files
     "build/", # Ignore build directory
+    "cognee/pipelines.py",
+    "cognee/modules/users/models/Group.py",
+    "cognee/modules/users/models/ACL.py",
+    "cognee/modules/pipelines/models/Task.py",
+    "cognee/modules/data/models/Dataset.py"
 ]

 [tool.ruff.lint]

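Note (assumption, not part of the diff): the pyproject.toml hunk above extends Ruff's `exclude` list with individual files, assuming Ruff is the linter behind this commit (which the `[tool.ruff.lint]` table suggests). With this configuration, `ruff check .` lints the project while skipping the excluded paths, and `ruff check . --fix` applies automatic fixes for the rules that support them.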