"""Source code for cyto.runners.docker."""
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Optional

import cloudpickle
import docker
from docker.types import Mount

from cyto.runners.base import RunnerBase
# [docs]
class DockerRunner(RunnerBase):
    """Runs tasks in a Docker container using the docker-py library.

    The task and its input data are serialized with cloudpickle into a
    temporary directory that is bind-mounted into the container. A worker
    module (``cyto.runners.container_worker``) inside the image executes the
    task and writes the pickled result back into the shared directory, which
    is then deserialized and returned to the caller.
    """

    def _find_dockerfile_path(self, image_name: str) -> Optional[str]:
        """Find the Dockerfile directory for a given image name.

        Args:
            image_name (str): Docker image name (e.g., 'cellpose:latest').

        Returns:
            Optional[str]: Path to the directory containing the Dockerfile,
            or ``None`` if no Dockerfile is found in the common locations.
        """
        # Extract base name from image (remove tag).
        base_name = image_name.split(':')[0]

        # Common Dockerfile locations to check, relative to the project root.
        possible_paths = [
            f"containers/segmentation/{base_name}",
            f"containers/{base_name}",
            f"docker/{base_name}",
            f"dockerfiles/{base_name}",
        ]

        # Get the project root (assuming this module lives in cyto/runners/).
        project_root = Path(__file__).parent.parent.parent
        for path in possible_paths:
            dockerfile_dir = project_root / path
            if (dockerfile_dir / "Dockerfile").exists():
                return str(dockerfile_dir)
        return None

    def _build_image_if_needed(self, client: docker.DockerClient, image_name: str) -> None:
        """Build the Docker image if it doesn't exist and a Dockerfile is found.

        Args:
            client (docker.DockerClient): Docker client.
            image_name (str): Docker image name to build.

        Raises:
            ValueError: If the image is missing locally and no Dockerfile is
                found in the common locations.
            RuntimeError: If the build fails.
        """
        try:
            # Check if image exists locally; nothing to do if so.
            client.images.get(image_name)
            print(f"Docker image '{image_name}' found locally.")
            return
        except docker.errors.ImageNotFound:
            print(f"Docker image '{image_name}' not found locally. Searching for Dockerfile...")

        # Find a Dockerfile to build the missing image from.
        dockerfile_dir = self._find_dockerfile_path(image_name)
        if not dockerfile_dir:
            raise ValueError(f"Docker image '{image_name}' not found and no Dockerfile found in common locations.")
        print(f"Found Dockerfile at: {dockerfile_dir}")
        print(f"Building Docker image '{image_name}'...")
        try:
            # Build the image with real-time log streaming.
            print("Building Docker image... This may take a while.")
            print("=" * 50)
            # Use the project root as the build context so the Dockerfile can
            # COPY files from anywhere in the repository.
            project_root = Path(__file__).parent.parent.parent
            dockerfile_path = Path(dockerfile_dir) / "Dockerfile"
            build_logs = client.api.build(
                path=str(project_root),  # Use project root as build context
                dockerfile=str(dockerfile_path.relative_to(project_root)),  # Relative path to Dockerfile
                tag=image_name,
                rm=True,       # Remove intermediate containers
                forcerm=True,  # Always remove intermediate containers
                decode=True,   # Decode JSON logs for real-time streaming
                pull=True      # Pull base images if needed
            )
            # NOTE: the low-level APIClient.build() generator does NOT raise
            # BuildError (only the high-level images.build() does). Failures
            # arrive as {'error': ...} entries in the log stream, so we must
            # collect them and raise explicitly after the stream ends.
            build_errors = []
            # Stream build logs in real time.
            for log in build_logs:
                if 'stream' in log:
                    # Print each build step as it happens.
                    message = log['stream'].strip()
                    if message:  # Only print non-empty messages
                        print(message)
                elif 'status' in log:
                    # Print status updates (like pulling images).
                    status = log['status']
                    if 'id' in log:
                        print(f"{status}: {log['id']}")
                    else:
                        print(status)
                elif 'error' in log:
                    # Print any errors immediately and remember them so the
                    # build is reported as failed instead of successful.
                    print(f"ERROR: {log['error']}")
                    build_errors.append(log['error'])
            if build_errors:
                raise RuntimeError(
                    f"Failed to build Docker image '{image_name}':\n"
                    + "\n".join(build_errors)
                )
            print("=" * 50)
            print(f"Successfully built Docker image '{image_name}'")
        except docker.errors.BuildError as e:
            # Defensive: kept in case a docker-py version raises BuildError
            # from the low-level API; surfaces the captured build log.
            error_msg = f"Failed to build Docker image '{image_name}':\n"
            for log in e.build_log:
                if 'stream' in log:
                    error_msg += log['stream']
                elif 'error' in log:
                    error_msg += f"ERROR: {log['error']}\n"
            raise RuntimeError(error_msg) from e
        except RuntimeError:
            # Re-raise our own build-failure error untouched so the generic
            # handler below doesn't double-wrap it.
            raise
        except Exception as e:
            # Handle any other build-related errors.
            raise RuntimeError(f"Failed to build Docker image '{image_name}': {str(e)}") from e

    def run(self, task: Any, data: Any) -> Any:
        """Serializes task and data, runs them in a Docker container, and
        deserializes the result.

        Args:
            task (Any): Picklable callable/task to execute in the container.
            data (Any): Picklable input data passed to the task.

        Returns:
            Any: The deserialized result produced by the container worker.

        Raises:
            ValueError: If no image name is configured in ``execution_config``.
            RuntimeError: If the container exits with a non-zero status or the
                Docker API reports an error.
            FileNotFoundError: If the worker did not write a result file.
        """
        # execution_config presumably comes from RunnerBase — TODO confirm.
        image_name = self.execution_config.get("image")
        if not image_name:
            raise ValueError("Docker image name must be specified in execution_config")
        client = docker.from_env()

        # Build image if it doesn't exist.
        self._build_image_if_needed(client, image_name)

        with tempfile.TemporaryDirectory() as temp_dir:
            task_path = os.path.join(temp_dir, "task.pkl")
            data_path = os.path.join(temp_dir, "data.pkl")
            result_path = os.path.join(temp_dir, "result.pkl")

            # Serialize task and data for the container worker.
            with open(task_path, "wb") as f:
                cloudpickle.dump(task, f)
            with open(data_path, "wb") as f:
                cloudpickle.dump(data, f)

            # Define the mount to share the temp directory with the container.
            mount = Mount(target="/app/data", source=temp_dir, type="bind")

            # Define the command to be executed inside the container.
            command = [
                "python", "-m", "cyto.runners.container_worker",
                "/app/data/task.pkl",
                "/app/data/data.pkl",
                "/app/data/result.pkl"
            ]
            try:
                print(f"Running container with image '{image_name}'...")
                print("=" * 50)
                # Run detached to get a container object we can stream logs
                # from. (stdout/stderr kwargs are ignored when detach=True,
                # so they are not passed here.)
                container = client.containers.run(
                    image=image_name,
                    command=command,
                    mounts=[mount],
                    remove=False,  # Don't auto-remove so we can get logs
                    detach=True    # Run detached to stream logs
                )
                try:
                    # Stream logs in real time; this loop runs until the
                    # container's output stream closes.
                    for log in container.logs(stream=True, stdout=True, stderr=True):
                        log_line = log.decode('utf-8').rstrip()
                        if log_line:  # Only print non-empty lines
                            print(log_line)

                    # Wait for container to finish and get exit code.
                    result = container.wait()
                    exit_code = result['StatusCode']
                    if exit_code != 0:
                        # Get any remaining stderr if there was an error.
                        error_logs = container.logs(stdout=False, stderr=True).decode('utf-8')
                        raise RuntimeError(f"Container execution failed with exit code {exit_code}:\n{error_logs}")
                    print("=" * 50)
                    print("Container execution completed successfully")
                finally:
                    # Always remove the container when done; removal failure
                    # is non-fatal (best-effort cleanup).
                    try:
                        container.remove()
                    except Exception as cleanup_error:
                        print(f"Warning: Failed to remove container: {cleanup_error}")
            except docker.errors.ContainerError as e:
                raise RuntimeError(f"Container execution failed:\n{e.stderr.decode()}") from e
            except docker.errors.APIError as e:
                raise RuntimeError(f"Docker API error: {e}") from e

            # Deserialize the result written by the container worker.
            if not os.path.exists(result_path):
                raise FileNotFoundError("Result file was not created by the container worker.")
            with open(result_path, "rb") as f:
                result = cloudpickle.load(f)
        return result