fix: memify coding agent pipeline

This commit is contained in:
Boris Arzentar 2025-10-31 12:25:46 +01:00
parent 737b6dc634
commit 94d2ca01a7
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
3 changed files with 16 additions and 13 deletions

View file

@@ -35,11 +35,13 @@ async def handle_task(
[key == "context" for key in inspect.signature(running_task.executable).parameters.keys()] [key == "context" for key in inspect.signature(running_task.executable).parameters.keys()]
) )
kwargs = {}
if has_context: if has_context:
args.append(context) kwargs["context"] = context
try: try:
async for result_data in running_task.execute(args, next_task_batch_size): async for result_data in running_task.execute(args, kwargs, next_task_batch_size):
async for result in run_tasks_base(leftover_tasks, result_data, user, context): async for result in run_tasks_base(leftover_tasks, result_data, user, context):
yield result yield result

View file

@@ -49,10 +49,10 @@ class Task:
return self.executable(*combined_args, **combined_kwargs) return self.executable(*combined_args, **combined_kwargs)
async def execute_async_generator(self, args): async def execute_async_generator(self, args, kwargs):
"""Execute async generator task and collect results in batches.""" """Execute async generator task and collect results in batches."""
results = [] results = []
async_iterator = self.run(*args) async_iterator = self.run(*args, **kwargs)
async for partial_result in async_iterator: async for partial_result in async_iterator:
results.append(partial_result) results.append(partial_result)
@@ -64,11 +64,11 @@ class Task:
if results: if results:
yield results yield results
async def execute_generator(self, args): async def execute_generator(self, args, kwargs):
"""Execute generator task and collect results in batches.""" """Execute generator task and collect results in batches."""
results = [] results = []
for partial_result in self.run(*args): for partial_result in self.run(*args, **kwargs):
results.append(partial_result) results.append(partial_result)
if len(results) == self._next_batch_size: if len(results) == self._next_batch_size:
@@ -78,20 +78,20 @@ class Task:
if results: if results:
yield results yield results
async def execute_coroutine(self, args): async def execute_coroutine(self, args, kwargs):
"""Execute coroutine task and yield the result.""" """Execute coroutine task and yield the result."""
task_result = await self.run(*args) task_result = await self.run(*args, **kwargs)
yield task_result yield task_result
async def execute_function(self, args): async def execute_function(self, args, kwargs):
"""Execute function task and yield the result.""" """Execute function task and yield the result."""
task_result = self.run(*args) task_result = self.run(*args, **kwargs)
yield task_result yield task_result
async def execute(self, args, next_batch_size=None): async def execute(self, args, kwargs, next_batch_size=None):
"""Execute the task based on its type and yield results with the next task's batch size.""" """Execute the task based on its type and yield results with the next task's batch size."""
if next_batch_size is not None: if next_batch_size is not None:
self._next_batch_size = next_batch_size self._next_batch_size = next_batch_size
async for result in self._execute_method(args): async for result in self._execute_method(args, kwargs):
yield result yield result

View file

@@ -126,7 +126,8 @@ async def add_rule_associations(
if len(edges_to_save) > 0: if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save) await graph_engine.add_edges(edges_to_save)
if context:
if context and hasattr(context["data"], "id"):
await upsert_edges( await upsert_edges(
edges_to_save, edges_to_save,
user_id=context["user"].id, user_id=context["user"].id,