Merge pull request #378 from topoteretes/COG-748
feat: Add versioning to the data point model
This commit is contained in:
commit
b61dfd0948
18 changed files with 186 additions and 124 deletions
|
|
@ -62,10 +62,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
async def add_node(self, node: DataPoint):
|
async def add_node(self, node: DataPoint):
|
||||||
serialized_properties = self.serialize_properties(node.model_dump())
|
serialized_properties = self.serialize_properties(node.model_dump())
|
||||||
|
|
||||||
query = dedent("""MERGE (node {id: $node_id})
|
query = dedent(
|
||||||
|
"""MERGE (node {id: $node_id})
|
||||||
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
||||||
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
||||||
RETURN ID(node) AS internal_id, node.id AS nodeId""")
|
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"node_id": str(node.id),
|
"node_id": str(node.id),
|
||||||
|
|
@ -182,13 +184,15 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
):
|
):
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
serialized_properties = self.serialize_properties(edge_properties)
|
||||||
|
|
||||||
query = dedent("""MATCH (from_node {id: $from_node}),
|
query = dedent(
|
||||||
|
"""MATCH (from_node {id: $from_node}),
|
||||||
(to_node {id: $to_node})
|
(to_node {id: $to_node})
|
||||||
MERGE (from_node)-[r]->(to_node)
|
MERGE (from_node)-[r]->(to_node)
|
||||||
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
||||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||||
RETURN r
|
RETURN r
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node": str(from_node),
|
"from_node": str(from_node),
|
||||||
|
|
|
||||||
|
|
@ -88,23 +88,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return dedent(f"""
|
return dedent(
|
||||||
|
f"""
|
||||||
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
||||||
ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
||||||
ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
|
||||||
""").strip()
|
"""
|
||||||
|
).strip()
|
||||||
|
|
||||||
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
|
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
|
||||||
properties = await self.stringify_properties(edge[3])
|
properties = await self.stringify_properties(edge[3])
|
||||||
properties = f"{{{properties}}}"
|
properties = f"{{{properties}}}"
|
||||||
|
|
||||||
return dedent(f"""
|
return dedent(
|
||||||
|
f"""
|
||||||
MERGE (source {{id:'{edge[0]}'}})
|
MERGE (source {{id:'{edge[0]}'}})
|
||||||
MERGE (target {{id: '{edge[1]}'}})
|
MERGE (target {{id: '{edge[1]}'}})
|
||||||
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
|
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
|
||||||
ON MATCH SET edge.updated_at = timestamp()
|
ON MATCH SET edge.updated_at = timestamp()
|
||||||
ON CREATE SET edge.updated_at = timestamp()
|
ON CREATE SET edge.updated_at = timestamp()
|
||||||
""").strip()
|
"""
|
||||||
|
).strip()
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str):
|
async def create_collection(self, collection_name: str):
|
||||||
pass
|
pass
|
||||||
|
|
@ -195,12 +199,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
self.query(query)
|
self.query(query)
|
||||||
|
|
||||||
async def has_edges(self, edges):
|
async def has_edges(self, edges):
|
||||||
query = dedent("""
|
query = dedent(
|
||||||
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (a)-[r]->(b)
|
MATCH (a)-[r]->(b)
|
||||||
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||||
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||||
""").strip()
|
"""
|
||||||
|
).strip()
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"edges": [
|
"edges": [
|
||||||
|
|
@ -279,14 +285,16 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
|
|
||||||
[label, attribute_name] = collection_name.split(".")
|
[label, attribute_name] = collection_name.split(".")
|
||||||
|
|
||||||
query = dedent(f"""
|
query = dedent(
|
||||||
|
f"""
|
||||||
CALL db.idx.vector.queryNodes(
|
CALL db.idx.vector.queryNodes(
|
||||||
'{label}',
|
'{label}',
|
||||||
'{attribute_name}',
|
'{attribute_name}',
|
||||||
{limit},
|
{limit},
|
||||||
vecf32({query_vector})
|
vecf32({query_vector})
|
||||||
) YIELD node, score
|
) YIELD node, score
|
||||||
""").strip()
|
"""
|
||||||
|
).strip()
|
||||||
|
|
||||||
result = self.query(query)
|
result = self.query(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -93,10 +93,12 @@ class SQLAlchemyAdapter:
|
||||||
if self.engine.dialect.name == "postgresql":
|
if self.engine.dialect.name == "postgresql":
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
result = await connection.execute(
|
result = await connection.execute(
|
||||||
text("""
|
text(
|
||||||
|
"""
|
||||||
SELECT schema_name FROM information_schema.schemata
|
SELECT schema_name FROM information_schema.schemata
|
||||||
WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
|
WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return [schema[0] for schema in result.fetchall()]
|
return [schema[0] for schema in result.fetchall()]
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,34 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional, Any, Dict
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
# Define metadata type
|
||||||
class MetaData(TypedDict):
|
class MetaData(TypedDict):
|
||||||
index_fields: list[str]
|
index_fields: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
# Updated DataPoint model with versioning and new fields
|
||||||
class DataPoint(BaseModel):
|
class DataPoint(BaseModel):
|
||||||
__tablename__ = "data_point"
|
__tablename__ = "data_point"
|
||||||
id: UUID = Field(default_factory=uuid4)
|
id: UUID = Field(default_factory=uuid4)
|
||||||
updated_at: Optional[datetime] = datetime.now(timezone.utc)
|
created_at: int = Field(
|
||||||
|
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||||
|
)
|
||||||
|
updated_at: int = Field(
|
||||||
|
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||||
|
)
|
||||||
|
version: int = 1 # Default version
|
||||||
topological_rank: Optional[int] = 0
|
topological_rank: Optional[int] = 0
|
||||||
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
|
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
|
||||||
|
|
||||||
# class Config:
|
# Override the Pydantic configuration
|
||||||
# underscore_attrs_are_private = True
|
class Config:
|
||||||
|
underscore_attrs_are_private = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_embeddable_data(self, data_point):
|
def get_embeddable_data(self, data_point):
|
||||||
|
|
@ -31,11 +41,11 @@ class DataPoint(BaseModel):
|
||||||
|
|
||||||
if isinstance(attribute, str):
|
if isinstance(attribute, str):
|
||||||
return attribute.strip()
|
return attribute.strip()
|
||||||
else:
|
return attribute
|
||||||
return attribute
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_embeddable_properties(self, data_point):
|
def get_embeddable_properties(self, data_point):
|
||||||
|
"""Retrieve all embeddable properties."""
|
||||||
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
|
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
|
||||||
return [
|
return [
|
||||||
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
|
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
|
||||||
|
|
@ -45,4 +55,40 @@ class DataPoint(BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_embeddable_property_names(self, data_point):
|
def get_embeddable_property_names(self, data_point):
|
||||||
|
"""Retrieve names of embeddable properties."""
|
||||||
return data_point._metadata["index_fields"] or []
|
return data_point._metadata["index_fields"] or []
|
||||||
|
|
||||||
|
def update_version(self):
|
||||||
|
"""Update the version and updated_at timestamp."""
|
||||||
|
self.version += 1
|
||||||
|
self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||||
|
|
||||||
|
# JSON Serialization
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Serialize the instance to a JSON string."""
|
||||||
|
return self.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(self, json_str: str):
|
||||||
|
"""Deserialize the instance from a JSON string."""
|
||||||
|
return self.model_validate_json(json_str)
|
||||||
|
|
||||||
|
# Pickle Serialization
|
||||||
|
def to_pickle(self) -> bytes:
|
||||||
|
"""Serialize the instance to pickle-compatible bytes."""
|
||||||
|
return pickle.dumps(self.dict())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pickle(self, pickled_data: bytes):
|
||||||
|
"""Deserialize the instance from pickled bytes."""
|
||||||
|
data = pickle.loads(pickled_data)
|
||||||
|
return self(**data)
|
||||||
|
|
||||||
|
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""Serialize model to a dictionary."""
|
||||||
|
return self.model_dump(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
|
||||||
|
"""Deserialize model from a dictionary."""
|
||||||
|
return cls.model_validate(data)
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
||||||
raise RuntimeError("Initialization error") from e
|
raise RuntimeError("Initialization error") from e
|
||||||
|
|
||||||
await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
|
await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
|
||||||
await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
|
await graph_engine.query(
|
||||||
|
"""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
|
||||||
r.target_node_id = target.id,
|
r.target_node_id = target.id,
|
||||||
r.relationship_name = type(r) RETURN r""")
|
r.relationship_name = type(r) RETURN r"""
|
||||||
|
)
|
||||||
await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
|
await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
|
||||||
|
|
||||||
nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
|
nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,12 @@ def test_AudioDocument():
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
|
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -25,12 +25,12 @@ def test_ImageDocument():
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
|
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,12 @@ def test_PdfDocument():
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
|
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
|
||||||
for ground_truth, paragraph_data in zip(
|
for ground_truth, paragraph_data in zip(
|
||||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
|
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
|
||||||
):
|
):
|
||||||
assert ground_truth["word_count"] == paragraph_data.word_count, (
|
assert (
|
||||||
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
ground_truth["word_count"] == paragraph_data.word_count
|
||||||
)
|
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
|
||||||
assert ground_truth["len_text"] == len(paragraph_data.text), (
|
assert ground_truth["len_text"] == len(
|
||||||
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
paragraph_data.text
|
||||||
)
|
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
|
||||||
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
|
assert (
|
||||||
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
ground_truth["cut_type"] == paragraph_data.cut_type
|
||||||
)
|
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
|
||||||
|
|
|
||||||
|
|
@ -71,32 +71,32 @@ def test_UnstructuredDocument():
|
||||||
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
|
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
||||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test DOCX
|
# Test DOCX
|
||||||
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
|
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
||||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_end" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_end != {paragraph_data.cut_type = }"
|
"sentence_end" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_end != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# TEST CSV
|
# TEST CSV
|
||||||
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
|
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
||||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
assert (
|
||||||
f"Read text doesn't match expected text: {paragraph_data.text}"
|
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
|
||||||
)
|
), f"Read text doesn't match expected text: {paragraph_data.text}"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test XLSX
|
# Test XLSX
|
||||||
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
|
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
||||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
assert (
|
||||||
f" sentence_cut != {paragraph_data.cut_type = }"
|
"sentence_cut" == paragraph_data.cut_type
|
||||||
)
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert result[0]["name"] == "Natural_language_processing_copy", (
|
assert (
|
||||||
"Result name does not match expected value."
|
result[0]["name"] == "Natural_language_processing_copy"
|
||||||
)
|
), "Result name does not match expected value."
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("datasets")
|
result = await relational_engine.get_all_data_from_table("datasets")
|
||||||
assert len(result) == 2, "Unexpected number of datasets found."
|
assert len(result) == 2, "Unexpected number of datasets found."
|
||||||
|
|
@ -61,9 +61,9 @@ async def test_deduplication():
|
||||||
|
|
||||||
result = await relational_engine.get_all_data_from_table("data")
|
result = await relational_engine.get_all_data_from_table("data")
|
||||||
assert len(result) == 1, "More than one data entity was found."
|
assert len(result) == 1, "More than one data entity was found."
|
||||||
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
|
assert (
|
||||||
"Content hash is not a part of file name."
|
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
|
||||||
)
|
), "Content hash is not a part of file name."
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
|
||||||
|
|
@ -85,9 +85,9 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
assert not os.path.exists(get_relational_engine().db_path), (
|
assert not os.path.exists(
|
||||||
"SQLite relational database is not empty"
|
get_relational_engine().db_path
|
||||||
)
|
), "SQLite relational database is not empty"
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,9 +82,9 @@ async def main():
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
assert not os.path.exists(get_relational_engine().db_path), (
|
assert not os.path.exists(
|
||||||
"SQLite relational database is not empty"
|
get_relational_engine().db_path
|
||||||
)
|
), "SQLite relational database is not empty"
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
|
||||||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||||
# Get data entry from database based on hash contents
|
# Get data entry from database based on hash contents
|
||||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test deletion of data along with local files created by cognee
|
# Test deletion of data along with local files created by cognee
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert not os.path.exists(data.raw_data_location), (
|
assert not os.path.exists(
|
||||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location still exists after deletion: {data.raw_data_location}"
|
||||||
|
|
||||||
async with engine.get_async_session() as session:
|
async with engine.get_async_session() as session:
|
||||||
# Get data entry from database based on file path
|
# Get data entry from database based on file path
|
||||||
data = (
|
data = (
|
||||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||||
).one()
|
).one()
|
||||||
assert os.path.isfile(data.raw_data_location), (
|
assert os.path.isfile(
|
||||||
f"Data location doesn't exist: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exist: {data.raw_data_location}"
|
||||||
# Test local files not created by cognee won't get deleted
|
# Test local files not created by cognee won't get deleted
|
||||||
await engine.delete_data_entity(data.id)
|
await engine.delete_data_entity(data.id)
|
||||||
assert os.path.exists(data.raw_data_location), (
|
assert os.path.exists(
|
||||||
f"Data location doesn't exists: {data.raw_data_location}"
|
data.raw_data_location
|
||||||
)
|
), f"Data location doesn't exists: {data.raw_data_location}"
|
||||||
|
|
||||||
|
|
||||||
async def test_getting_of_documents(dataset_name_1):
|
async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
|
||||||
|
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||||
assert len(document_ids) == 1, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
len(document_ids) == 1
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||||
|
|
||||||
# Test getting of documents for search when no dataset is provided
|
# Test getting of documents for search when no dataset is provided
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
document_ids = await get_document_ids_for_user(user.id)
|
document_ids = await get_document_ids_for_user(user.id)
|
||||||
assert len(document_ids) == 2, (
|
assert (
|
||||||
f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
len(document_ids) == 2
|
||||||
)
|
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
|
||||||
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
|
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
|
||||||
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
|
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
|
||||||
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -36,9 +36,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
|
||||||
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
|
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
|
larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
|
||||||
assert np.all(chunk_lengths <= paragraph_length), (
|
assert np.all(
|
||||||
f"{paragraph_length = }: {larger_chunks} are too large"
|
chunk_lengths <= paragraph_length
|
||||||
)
|
), f"{paragraph_length = }: {larger_chunks} are too large"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -50,6 +50,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_
|
||||||
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
|
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
|
||||||
)
|
)
|
||||||
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
||||||
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
|
assert np.all(
|
||||||
f"{chunk_indices = } are not monotonically increasing"
|
chunk_indices == np.arange(len(chunk_indices))
|
||||||
)
|
), f"{chunk_indices = } are not monotonically increasing"
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,9 @@ def run_chunking_test(test_text, expected_chunks):
|
||||||
|
|
||||||
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
|
||||||
for key in ["text", "word_count", "cut_type"]:
|
for key in ["text", "word_count", "cut_type"]:
|
||||||
assert chunk[key] == expected_chunks_item[key], (
|
assert (
|
||||||
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
chunk[key] == expected_chunks_item[key]
|
||||||
)
|
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
|
||||||
|
|
||||||
|
|
||||||
def test_chunking_whole_text():
|
def test_chunking_whole_text():
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
|
||||||
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
|
||||||
chunks = chunk_by_sentence(input_text, maximum_length)
|
chunks = chunk_by_sentence(input_text, maximum_length)
|
||||||
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
reconstructed_text = "".join([chunk[1] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
|
||||||
chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
|
chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
|
||||||
|
|
||||||
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
|
||||||
assert np.all(chunk_lengths <= maximum_length), (
|
assert np.all(
|
||||||
f"{maximum_length = }: {larger_chunks} are too large"
|
chunk_lengths <= maximum_length
|
||||||
)
|
), f"{maximum_length = }: {larger_chunks} are too large"
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
|
||||||
def test_chunk_by_word_isomorphism(input_text):
|
def test_chunk_by_word_isomorphism(input_text):
|
||||||
chunks = chunk_by_word(input_text)
|
chunks = chunk_by_word(input_text)
|
||||||
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
reconstructed_text = "".join([chunk[0] for chunk in chunks])
|
||||||
assert reconstructed_text == input_text, (
|
assert (
|
||||||
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
reconstructed_text == input_text
|
||||||
)
|
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue