# Source code for cyto.tasks.manager
import importlib
import networkx as nx
import yaml
import os
from typing import Any, Dict, List, Optional
from cyto.tasks.graph import get_execution_order, get_task_by_name
from cyto.tasks.definitions import Task
class TaskManager:
    """
    Manages the execution of a pipeline of tasks based on their dependency graph.

    Tasks are instantiated lazily from their dotted module paths, executed in
    dependency order, and each task's output is cached in ``self.results`` so
    that downstream tasks can consume the results of their dependencies.
    """
def __init__(self, graph: nx.DiGraph, resources_config: Optional[Dict] = None, profile: str = "default", verbose: bool = True):
    """
    Set up the manager for a task dependency graph.

    Args:
        graph (nx.DiGraph): The task dependency graph
        resources_config (Optional[Dict]): Resources configuration for execution profiles
        profile (str): Execution profile to use from resources config
        verbose (bool): Whether to print execution progress
    """
    self.graph = graph
    self.profile = profile
    self.verbose = verbose
    # Fall back to an empty mapping so later lookups never have to handle None.
    self.resources_config = resources_config if resources_config else {}
    # Completed task outputs, keyed by task name; filled in during execute().
    self.results: Dict[str, Any] = {}
def _log_message(self, message: str) -> None:
"""Log message if verbose is enabled"""
if self.verbose:
print(f"[TaskManager] {message}")
def _import_and_instantiate_task(self, task_def: Task) -> Any:
    """
    Import the task class named by ``task_def.module`` and instantiate it.

    The module path is everything before the final dot of ``task_def.module``;
    the class name is the final component.  Any execution config resolved for
    the task is passed to the constructor as the ``execution_config`` keyword.

    Args:
        task_def (Task): Task definition containing module, params, etc.

    Returns:
        Any: The instantiated task object

    Raises:
        ImportError: If the module path is malformed or cannot be imported.
        AttributeError: If the class is missing from the imported module.
        RuntimeError: If instantiation itself fails.
    """
    # rpartition never raises on a dot-free path; detect that case explicitly
    # instead of letting a ValueError be mislabeled as an instantiation error.
    module_path, _, class_name = task_def.module.rpartition('.')
    if not module_path:
        raise ImportError(
            f"Task module path '{task_def.module}' must be of the form "
            f"'package.module.ClassName'")
    # Each try guards only the statement that can raise, so a failure inside
    # an unrelated call (e.g. the config lookup) can no longer be reported
    # with the wrong error message.
    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError(f"Could not import module '{task_def.module}': {e}") from e
    try:
        task_class = getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(f"Class '{class_name}' not found in module '{module_path}': {e}") from e
    # Copy so the task definition's params are never mutated.
    params = task_def.params.copy()
    execution_config = self._get_execution_config(task_def.name)
    if execution_config:
        params['execution_config'] = execution_config
    try:
        return task_class(**params)
    except Exception as e:
        raise RuntimeError(f"Failed to instantiate task '{task_def.name}': {e}") from e

def _get_execution_config(self, task_name: str) -> Optional[Dict]:
    """
    Resolve the execution config for *task_name* from the active profile.

    NOTE(review): this method is called by ``_import_and_instantiate_task``
    but was missing from the visible source (previously an AttributeError at
    runtime).  The lookup below assumes ``resources_config`` maps profile
    names to per-task config dicts — confirm against the intended
    resources-file schema.
    """
    profile_config = self.resources_config.get(self.profile, {})
    if not isinstance(profile_config, dict):
        return None
    return profile_config.get(task_name)
def _prepare_task_data(self, task_def: Task, initial_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the input payload for a task.

    Starts from a copy of *initial_data* and layers in the stored result of
    every dependency that has already run: dict results are merged
    key-by-key (later dependencies win on key conflicts), while non-dict
    results are stored under the dependency's task name.

    Args:
        task_def (Task): The task definition
        initial_data (Dict[str, Any]): The initial input data

    Returns:
        Dict[str, Any]: The prepared data for the task
    """
    merged = dict(initial_data)
    for dep_name in task_def.dependencies:
        # Dependencies without a stored result contribute nothing.
        try:
            dep_result = self.results[dep_name]
        except KeyError:
            continue
        if isinstance(dep_result, dict):
            merged.update(dep_result)
        else:
            merged[dep_name] = dep_result
    return merged
def execute(self, initial_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute all tasks in the pipeline according to their dependency order.

    Each task is instantiated on first use, fed *initial_data* merged with
    its dependencies' results, and its output is cached in ``self.results``.

    Args:
        initial_data (Dict[str, Any]): Initial data to pass to the first tasks

    Returns:
        Dict[str, Any]: Results from the final task(s) in the pipeline

    Raises:
        RuntimeError: If any task raises during execution.
    """
    self._log_message("Starting pipeline execution")
    execution_order = get_execution_order(self.graph)
    self._log_message(f"Execution order: {execution_order}")
    for task_name in execution_order:
        self._log_message(f"Executing task: {task_name}")
        task_def = get_task_by_name(self.graph, task_name)
        # Lazily build the task object the first time it is needed.
        if task_def.instance is None:
            task_def.instance = self._import_and_instantiate_task(task_def)
        payload = self._prepare_task_data(task_def, initial_data)
        try:
            self.results[task_name] = task_def.instance(payload)
        except Exception as e:
            error_msg = f"Task '{task_name}' failed: {e}"
            self._log_message(error_msg)
            raise RuntimeError(error_msg) from e
        self._log_message(f"Task '{task_name}' completed successfully")
    self._log_message("Pipeline execution completed")
    # A single-task pipeline returns that task's result directly.
    if len(execution_order) == 1:
        return self.results[execution_order[-1]]
    # Otherwise report the sink tasks: those with no successors in the graph.
    final_tasks = [t for t in execution_order if not list(self.graph.successors(t))]
    if len(final_tasks) == 1:
        return self.results[final_tasks[0]]
    return {t: self.results[t] for t in final_tasks}
[docs]
def get_task_results(self) -> Dict[str, Any]:
"""
Get all task results.
Returns:
Dict[str, Any]: Dictionary mapping task names to their results
"""
return self.results.copy()
[docs]
def get_task_result(self, task_name: str) -> Any:
"""
Get the result from a specific task.
Args:
task_name (str): Name of the task
Returns:
Any: The task result
Raises:
KeyError: If the task hasn't been executed or doesn't exist
"""
if task_name not in self.results:
raise KeyError(f"No result found for task '{task_name}'. "
"Either the task hasn't been executed or doesn't exist.")
return self.results[task_name]