Source code for roastcoffea.visualization.plots.throughput

"""Throughput and data rate plots.

Visualizations for data processing rates and event throughput.
"""

from __future__ import annotations

import datetime
from pathlib import Path
from typing import Any

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np

from roastcoffea.visualization.utils import (
    add_worker_count_annotation,
    finalize_timeline_plot,
    setup_timeline_axes,
    validate_tracking_data,
)


def plot_worker_activity_timeline(
    tracking_data: dict[str, Any] | None,
    output_path: Path | None = None,
    figsize: tuple[int, int] = (12, 6),
    title: str = "Worker Activity Over Time",
    max_legend_entries: int = 5,
) -> tuple[plt.Figure, plt.Axes]:
    """Draw one line per worker showing its active task count over time.

    The "active" count is the number of processing plus queued tasks on a
    worker, so the plot gives a view of how work is spread across workers.

    Parameters
    ----------
    tracking_data : dict or None
        Tracking data; the ``"worker_active_tasks"`` entry is used and is
        iterated as a mapping of worker id to a sequence of
        ``(timestamp, value)`` pairs.
    output_path : Path, optional
        Save path, forwarded to ``finalize_timeline_plot``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    title : str
        Plot title.
    max_legend_entries : int, optional
        Largest number of workers for which a legend is drawn. With more
        workers a worker-count annotation is added instead. Default is 5.

    Returns
    -------
    fig, ax : Figure and Axes
        Matplotlib figure and axes.

    Raises
    ------
    ValueError
        If tracking_data is None or lacks active-task data (expected to be
        raised by ``validate_tracking_data``).
    """
    per_worker = validate_tracking_data(
        tracking_data, "worker_active_tasks", "No worker active tasks data available"
    )

    fig, ax = plt.subplots(figsize=figsize)

    for worker_id, samples in per_worker.items():
        # Workers without any recorded samples contribute no line.
        if not samples:
            continue
        xs = [stamp for stamp, _ in samples]
        ys = [count for _, count in samples]
        ax.plot(xs, ys, label=worker_id, alpha=0.7, linewidth=2)

    setup_timeline_axes(ax, ylabel="Number of Active Tasks", title=title)

    # A legend is only readable for a handful of workers; beyond that,
    # summarise with a count annotation.
    worker_total = len(per_worker)
    if worker_total > max_legend_entries:
        add_worker_count_annotation(ax, worker_total)
    else:
        ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), fontsize=8)

    finalize_timeline_plot(fig, ax, output_path)

    return fig, ax
def plot_total_active_tasks_timeline(
    tracking_data: dict[str, Any] | None,
    output_path: Path | None = None,
    figsize: tuple[int, int] = (10, 5),
    title: str = "Total Active Tasks Over Time",
) -> tuple[plt.Figure, plt.Axes]:
    """Draw the cluster-wide active task count over time.

    Task counts from every worker are summed per timestamp and the result is
    drawn as a filled line.

    Parameters
    ----------
    tracking_data : dict or None
        Tracking data; the ``"worker_active_tasks"`` entry is used and is
        iterated as a mapping of worker id to a sequence of
        ``(timestamp, task_count)`` pairs.
    output_path : Path, optional
        Save path, forwarded to ``finalize_timeline_plot``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    title : str
        Plot title.

    Returns
    -------
    fig, ax : Figure and Axes
        Matplotlib figure and axes.

    Raises
    ------
    ValueError
        If tracking_data is None or lacks active-task data (expected to be
        raised by ``validate_tracking_data``).
    """
    per_worker = validate_tracking_data(
        tracking_data, "worker_active_tasks", "No worker active tasks data available"
    )

    # Sum the counts reported by all workers for each distinct timestamp.
    totals_by_stamp: dict = {}
    for samples in per_worker.values():
        for stamp, count in samples:
            totals_by_stamp[stamp] = totals_by_stamp.get(stamp, 0) + count

    # Plot in chronological order.
    ordered_stamps = sorted(totals_by_stamp)
    ordered_totals = [totals_by_stamp[stamp] for stamp in ordered_stamps]

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(ordered_stamps, ordered_totals, linewidth=2, color="steelblue")
    ax.fill_between(ordered_stamps, ordered_totals, alpha=0.3, color="steelblue")

    setup_timeline_axes(ax, ylabel="Total Active Tasks", title=title)
    finalize_timeline_plot(fig, ax, output_path)

    return fig, ax
def _instantaneous_rates_gbps(
    starts: np.ndarray,
    ends: np.ndarray,
    bytes_read: np.ndarray,
    sample_times: np.ndarray,
) -> list[float]:
    """Return the summed rate (Gbps) of the chunks active at each sample time."""
    runtimes = ends - starts
    # A chunk with a non-positive runtime has no defined rate; drop it up
    # front so the per-sample loop never divides by zero.
    measurable = runtimes > 0
    starts = starts[measurable]
    ends = ends[measurable]
    # Each chunk's rate is constant over its lifetime, so compute it once
    # instead of once per sample: bytes -> bits -> gigabits, per second.
    chunk_rates = bytes_read[measurable] * 8 / 1e9 / runtimes[measurable]
    return [
        float(chunk_rates[(starts <= t) & (t <= ends)].sum()) for t in sample_times
    ]


def plot_throughput_timeline(
    chunk_info: dict[tuple[str, int, int], tuple[float, float, int]],
    tracking_data: dict[str, Any] | None = None,
    output_path: Path | None = None,
    figsize: tuple[int, int] = (12, 6),
    title: str = "Data Throughput Over Time",
) -> tuple[plt.Figure, plt.Axes]:
    """Plot instantaneous data throughput (Gbps) over time.

    The run is sampled at 100 evenly spaced instants between the earliest
    chunk start and the latest chunk end. At each instant the rate is the sum
    of ``bytes / runtime`` over every chunk whose ``[t0, t1]`` interval
    contains that instant. Chunks with a non-positive runtime are ignored.

    If ``tracking_data`` contains a non-empty ``"worker_counts"`` mapping, the
    worker count is overlaid on a secondary y-axis.

    Parameters
    ----------
    chunk_info : dict
        Per-chunk timing data from metrics.
        Format: ``{(filename, start, stop): (t0, t1, bytesread)}`` where
        ``t0``/``t1`` are epoch timestamps in seconds.
    tracking_data : dict, optional
        Worker tracking data; ``tracking_data["worker_counts"]`` is read as a
        mapping of timestamp to worker count.
    output_path : Path, optional
        Path to save the figure to (150 dpi, tight bounding box).
    figsize : tuple
        Figure size (width, height).
    title : str
        Plot title.

    Returns
    -------
    fig, ax : Figure and Axes
        Matplotlib figure and the primary (throughput) axes.

    Raises
    ------
    ValueError
        If chunk_info is empty.
    """
    if not chunk_info:
        msg = "No chunk_info provided. Pass chunk_info parameter from metrics."
        raise ValueError(msg)

    # Split the (t0, t1, bytesread) records into parallel arrays.
    t0s, t1s, sizes = zip(*chunk_info.values())
    starts = np.array(t0s)
    ends = np.array(t1s)
    bytes_read = np.array(sizes)

    # Sample timestamps (100 points across the run).
    sample_times_epoch = np.linspace(starts.min(), ends.max(), num=100)
    sample_times_dt = [datetime.datetime.fromtimestamp(t) for t in sample_times_epoch]

    instantaneous_rates = _instantaneous_rates_gbps(
        starts, ends, bytes_read, sample_times_epoch
    )

    fig, ax1 = plt.subplots(figsize=figsize)

    ax1.fill_between(
        np.array(sample_times_dt),
        instantaneous_rates,
        alpha=0.5,
        color="C1",
        edgecolor="C1",
        linewidth=0.5,
    )
    ax1.set_xlabel("Time")
    ax1.set_ylabel("Data Rate (Gbps)", color="C1")
    ax1.tick_params(axis="y", labelcolor="C1")
    # Fall back to a unit range when every rate is zero; identical lower and
    # upper limits would otherwise be passed to set_ylim.
    peak_rate = max(instantaneous_rates)
    ax1.set_ylim((0, peak_rate * 1.1 if peak_rate > 0 else 1))
    ax1.grid(True, alpha=0.3)

    # Overlay worker count if available.
    if tracking_data and "worker_counts" in tracking_data:
        worker_counts = tracking_data["worker_counts"]
        if worker_counts:
            timestamps = list(worker_counts)
            counts = list(worker_counts.values())

            ax2 = ax1.twinx()
            ax2.plot(timestamps, counts, linewidth=2, color="C0", label="Workers")
            ax2.set_ylabel("Number of Workers", color="C0")
            ax2.tick_params(axis="y", labelcolor="C0")
            peak_workers = max(counts)
            ax2.set_ylim((0, peak_workers * 1.1 if peak_workers > 0 else 1))

    ax1.set_title(title)

    # Format the x-axis on ax1 explicitly. twinx() makes the twin the
    # "current" axes and hides its x-axis, so pyplot-level calls such as
    # plt.xticks() would rotate labels nobody can see.
    ax1.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
    ax1.tick_params(axis="x", labelrotation=45)

    # Lay out this figure specifically rather than whichever is current.
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")

    return fig, ax1