fix: memify coding agent pipeline

This commit is contained in:
Boris Arzentar 2025-10-31 12:25:46 +01:00
parent 737b6dc634
commit 94d2ca01a7
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
3 changed files with 16 additions and 13 deletions

View file

@ -35,11 +35,13 @@ async def handle_task(
[key == "context" for key in inspect.signature(running_task.executable).parameters.keys()]
)
kwargs = {}
if has_context:
args.append(context)
kwargs["context"] = context
try:
async for result_data in running_task.execute(args, next_task_batch_size):
async for result_data in running_task.execute(args, kwargs, next_task_batch_size):
async for result in run_tasks_base(leftover_tasks, result_data, user, context):
yield result

View file

@ -49,10 +49,10 @@ class Task:
return self.executable(*combined_args, **combined_kwargs)
async def execute_async_generator(self, args):
async def execute_async_generator(self, args, kwargs):
"""Execute async generator task and collect results in batches."""
results = []
async_iterator = self.run(*args)
async_iterator = self.run(*args, **kwargs)
async for partial_result in async_iterator:
results.append(partial_result)
@ -64,11 +64,11 @@ class Task:
if results:
yield results
async def execute_generator(self, args):
async def execute_generator(self, args, kwargs):
"""Execute generator task and collect results in batches."""
results = []
for partial_result in self.run(*args):
for partial_result in self.run(*args, **kwargs):
results.append(partial_result)
if len(results) == self._next_batch_size:
@ -78,20 +78,20 @@ class Task:
if results:
yield results
async def execute_coroutine(self, args):
async def execute_coroutine(self, args, kwargs):
"""Execute coroutine task and yield the result."""
task_result = await self.run(*args)
task_result = await self.run(*args, **kwargs)
yield task_result
async def execute_function(self, args):
async def execute_function(self, args, kwargs):
"""Execute function task and yield the result."""
task_result = self.run(*args)
task_result = self.run(*args, **kwargs)
yield task_result
async def execute(self, args, next_batch_size=None):
async def execute(self, args, kwargs, next_batch_size=None):
"""Execute the task based on its type and yield results with the next task's batch size."""
if next_batch_size is not None:
self._next_batch_size = next_batch_size
async for result in self._execute_method(args):
async for result in self._execute_method(args, kwargs):
yield result

View file

@ -126,7 +126,8 @@ async def add_rule_associations(
if len(edges_to_save) > 0:
await graph_engine.add_edges(edges_to_save)
if context:
if context and hasattr(context["data"], "id"):
await upsert_edges(
edges_to_save,
user_id=context["user"].id,