mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
156 lines
5.2 KiB
Python
156 lines
5.2 KiB
Python
"""Thing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
|
from dataclasses import dataclass
|
|
from multiprocessing import cpu_count
|
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Callable, Mapping, Sequence
|
|
|
|
# Module-level logger named after this module; used below to report task failures and progress.
logger = logging.getLogger(__name__)
|
|
|
|
R = TypeVar("R")
|
|
|
|
# Accepted execution modes: "normal" records every exception and keeps going,
# "early_error" re-raises the first exception that is found.
modes = Literal["normal", "early_error"]
|
|
|
|
|
|
@dataclass
|
|
class ExecutorResults[R]:
|
|
"""Dataclass to store the results and exceptions of the parallel execution."""
|
|
|
|
results: list[R]
|
|
exceptions: list[BaseException]
|
|
|
|
def __repr__(self) -> str:
|
|
"""Return a string representation of the object."""
|
|
return f"results={self.results} exceptions={self.exceptions}"
|
|
|
|
|
|
def _parallelize_base[R](
|
|
executor_type: type[ThreadPoolExecutor | ProcessPoolExecutor],
|
|
func: Callable[..., R],
|
|
kwargs_list: Sequence[Mapping[str, Any]],
|
|
max_workers: int | None,
|
|
progress_tracker: int | None,
|
|
mode: modes,
|
|
) -> ExecutorResults:
|
|
total_work = len(kwargs_list)
|
|
|
|
with executor_type(max_workers=max_workers) as executor:
|
|
futures = [executor.submit(func, **kwarg) for kwarg in kwargs_list]
|
|
|
|
results = []
|
|
exceptions = []
|
|
for index, future in enumerate(futures, 1):
|
|
if exception := future.exception():
|
|
logger.error(f"{future} raised {exception.__class__.__name__}")
|
|
exceptions.append(exception)
|
|
if mode == "early_error":
|
|
executor.shutdown(wait=False)
|
|
raise exception
|
|
continue
|
|
|
|
results.append(future.result())
|
|
|
|
if progress_tracker and index % progress_tracker == 0:
|
|
logger.info(f"Progress: {index}/{total_work}")
|
|
|
|
return ExecutorResults(results, exceptions)
|
|
|
|
|
|
def parallelize_thread[R](
|
|
func: Callable[..., R],
|
|
kwargs_list: Sequence[Mapping[str, Any]],
|
|
max_workers: int | None = None,
|
|
progress_tracker: int | None = None,
|
|
mode: modes = "normal",
|
|
) -> ExecutorResults:
|
|
"""Generic function to run a function with multiple arguments in threads.
|
|
|
|
Args:
|
|
func (Callable[..., R]): Function to run in threads.
|
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
|
|
|
Returns:
|
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
|
"""
|
|
return _parallelize_base(
|
|
executor_type=ThreadPoolExecutor,
|
|
func=func,
|
|
kwargs_list=kwargs_list,
|
|
max_workers=max_workers,
|
|
progress_tracker=progress_tracker,
|
|
mode=mode,
|
|
)
|
|
|
|
|
|
def parallelize_process[R](
|
|
func: Callable[..., R],
|
|
kwargs_list: Sequence[Mapping[str, Any]],
|
|
max_workers: int | None = None,
|
|
progress_tracker: int | None = None,
|
|
mode: modes = "normal",
|
|
) -> ExecutorResults:
|
|
"""Generic function to run a function with multiple arguments in process.
|
|
|
|
Args:
|
|
func (Callable[..., R]): Function to run in process.
|
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
|
max_workers (int, optional): Number of workers to use. Defaults to 4.
|
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
|
|
|
Returns:
|
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
|
"""
|
|
if max_workers and max_workers > cpu_count():
|
|
error = f"max_workers must be less than or equal to {cpu_count()}"
|
|
raise RuntimeError(error)
|
|
|
|
return process_executor_unchecked(
|
|
func=func,
|
|
kwargs_list=kwargs_list,
|
|
max_workers=max_workers,
|
|
progress_tracker=progress_tracker,
|
|
mode=mode,
|
|
)
|
|
|
|
|
|
def process_executor_unchecked[R](
|
|
func: Callable[..., R],
|
|
kwargs_list: Sequence[Mapping[str, Any]],
|
|
max_workers: int | None,
|
|
progress_tracker: int | None,
|
|
mode: modes = "normal",
|
|
) -> ExecutorResults:
|
|
"""Generic function to run a function with multiple arguments in parallel.
|
|
|
|
Note: this function does not check if the number of workers is greater than the number of CPUs.
|
|
This can cause the system to become unresponsive.
|
|
|
|
Args:
|
|
func (Callable[..., R]): Function to run in parallel.
|
|
kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the arguments for the function.
|
|
max_workers (int, optional): Number of workers to use. Defaults to 8.
|
|
progress_tracker (int, optional): Number of tasks to complete before logging progress.
|
|
mode (modes, optional): Mode to use. Defaults to "normal".
|
|
|
|
Returns:
|
|
tuple[list[R], list[Exception]]: List with the results and a list with the exceptions.
|
|
"""
|
|
return _parallelize_base(
|
|
executor_type=ProcessPoolExecutor,
|
|
func=func,
|
|
kwargs_list=kwargs_list,
|
|
max_workers=max_workers,
|
|
progress_tracker=progress_tracker,
|
|
mode=mode,
|
|
)
|