#!/usr/bin/env python
# Copyright 2014-2022 Sergio Oller <sergioller@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module implements map and starmap functions (see python standard
library to learn about them).
The implementations provided in this module allow providing additional
arguments to the mapped functions. Additionally they will initialize
the pool and close it automatically by default if possible.
The easiest way to learn is by reading the following examples.
===========
Examples
===========
Map example
===========
You want to do:
>>> y1 = [myfunction(x, argument1, argument2) for x in mylist]
In parallel:
>>> y2 = parmap.map(myfunction, mylist, argument1, argument2)
Check both results:
>>> assert y1 == y2
Starmap example
================
You want to do:
>>> z1 = [myfunction(x, y, argument1, argument2) for (x,y) in mylist]
In parallel:
>>> z2 = parmap.starmap(myfunction, mylist, argument1, argument2)
Check both results:
>>> assert z1 == z2
You want to do:
>>> listx = [1, 2, 3, 4, 5, 6]
>>> listy = [2, 3, 4, 5, 6, 7]
>>> a = 3.14
>>> b = 42
>>> listz1 = []
>>> for x in listx:
...     for y in listy:
...         listz1.append(myfunction(x, y, a, b))
In parallel:
>>> from itertools import product
>>> listz2 = parmap.starmap(myfunction, product(listx, listy), a, b)
Check both results:
>>> assert listz1 == listz2
========
Members
========
"""
# The original idea for this implementation was given by J.F. Sebastian
# at http://stackoverflow.com/a/5443941/446149
import multiprocessing
import typing as T
import warnings
from functools import partial
from itertools import repeat
from multiprocessing.pool import AsyncResult
try:
import tqdm.auto as tqdm # type: ignore
HAVE_TQDM = True
except ImportError:
HAVE_TQDM = False
def _func_star_single(func_item_args):
"""Equivalent to:
func = func_item_args[0]
item = func_item_args[1]
args = func_item_args[2]
kwargs = func_item_args[3]
return func(item,args[0],args[1],..., **kwargs)
"""
return func_item_args[0](
*[func_item_args[1]] + func_item_args[2], **func_item_args[3]
)
def _func_star_many(func_items_args):
"""Equivalent to:
func = func_item_args[0]
items = func_item_args[1]
args = func_item_args[2:]
kwargs = func_item_args[3]
return func(items[0], items[1], ..., args[0], args[1], ..., **kwargs)
"""
return func_items_args[0](
*list(func_items_args[1]) + func_items_args[2], **func_items_args[3]
)
def _create_pool(kwargs):
parallel: bool = kwargs.pop("pm_parallel", True)
pool: T.Optional[multiprocessing.Pool] = kwargs.pop("pm_pool", None)
close_pool = False
processes: T.Optional[int] = kwargs.pop("pm_processes", None)
# Initialize pool if parallel:
if parallel and pool is None:
try:
pool = multiprocessing.Pool(processes=processes)
close_pool = True
except Exception as exc: # Disable parallel on error:
warnings.warn(str(exc))
parallel = False
return parallel, pool, close_pool
def _do_pbar(async_result, num_tasks, chunksize, refresh_time, pbar_wrapper):
remaining = num_tasks
# tqdm provides a progress bar.
# the pbar needs to be updated with the increment on each
# iteration.
with pbar_wrapper(total=num_tasks) as pbar:
while True:
if async_result.ready():
pbar.update(remaining)
break
try:
remaining_now = async_result._number_left * chunksize
done_now = remaining - remaining_now
remaining = remaining_now
except:
break
if done_now > 0:
pbar.update(done_now)
async_result.wait(refresh_time) # update every two seconds
def _get_default_chunksize(chunksize, pool, num_tasks):
# default from multiprocessing
# https://github.com/python/cpython/blob/master/Lib/multiprocessing/pool.py
if chunksize is None:
chunksize, extra = divmod(num_tasks, len(pool._pool) * 4)
if extra:
chunksize += 1
return chunksize
def _prepare_pbar_wrapper(progress):
has_pbar = False
wrapper = None
if progress is True and HAVE_TQDM:
has_pbar = True
wrapper = tqdm.tqdm
elif isinstance(progress, dict) and HAVE_TQDM:
has_pbar = True
wrapper = partial(tqdm.tqdm, **progress)
elif isinstance(progress, T.Callable):
has_pbar = True
wrapper = progress
return (has_pbar, wrapper)
def _serial_map_or_starmap(
function, iterable, args, kwargs, pbar_wrapper, map_or_starmap
):
if pbar_wrapper is not None:
iterable = pbar_wrapper(iterable)
if map_or_starmap == "map":
output = [function(*([item] + list(args)), **kwargs) for item in iterable]
elif map_or_starmap == "starmap":
output = [function(*(list(item) + list(args)), **kwargs) for item in iterable]
else:
raise AssertionError(
"Internal parmap error: Invalid map_or_starmap." + " This should not happen"
)
return output
def _get_helper_func(map_or_starmap):
    """Pick the tuple-unpacking helper matching *map_or_starmap*.

    :returns: ``_func_star_single`` for "map", ``_func_star_many`` for
        "starmap".
    :raises AssertionError: for any other value (internal error).
    """
    helpers = {"map": _func_star_single, "starmap": _func_star_many}
    if map_or_starmap not in helpers:
        raise AssertionError(
            "Internal parmap error: Invalid map_or_starmap. This should not happen"
        )
    return helpers[map_or_starmap]
def _deprecated_kwargs(kwargs, arg_newarg):
"""arg_newarg is a list of tuples, where each tuple has a pair of strings.
('old_arg', 'new_arg')
A DeprecationWarning is raised for the arguments that need to be
replaced.
"""
warn_for = []
for (arg, new_kw) in arg_newarg:
if arg in kwargs.keys():
val = kwargs.pop(arg)
kwargs[new_kw] = val
warn_for.append((arg, new_kw))
if len(warn_for) > 0:
if len(warn_for) == 1:
warnings.warn(
"Argument '{}' is deprecated. Use {} instead".format(
warn_for[0][0], warn_for[0][1]
),
DeprecationWarning,
stacklevel=4,
)
else:
args = ", ".join([x[0] for x in warn_for])
repl = ", ".join([x[1] for x in warn_for])
warnings.warn(
"Arguments '{}' are deprecated. Use '{}' instead respectively".format(
args, repl
),
DeprecationWarning,
stacklevel=4,
)
return kwargs
def _map_or_starmap(function, iterable, args, kwargs, map_or_starmap):
    """
    Shared function between parmap.map and parmap.starmap.
    Refer to those functions for details.

    :param function: callable applied to each item of *iterable*.
    :param iterable: items to map over.
    :param args: extra positional arguments appended to every call.
    :param kwargs: keyword arguments; the ``pm_*`` options are popped here
        and the rest are forwarded to *function*.
    :param map_or_starmap: either "map" or "starmap".
    """
    # Accept the pre-1.0 option names, emitting DeprecationWarnings:
    arg_newarg = (
        ("parallel", "pm_parallel"),
        ("chunksize", "pm_chunksize"),
        ("pool", "pm_pool"),
        ("processes", "pm_processes"),
        ("parmap_progress", "pm_pbar"),
    )
    kwargs = _deprecated_kwargs(kwargs, arg_newarg)
    chunksize = kwargs.pop("pm_chunksize", None)
    progress = kwargs.pop("pm_pbar", False)
    (has_pbar, pbar_wrapper) = _prepare_pbar_wrapper(progress)
    parallel, pool, close_pool = _create_pool(kwargs)
    # Handle case: Execute sequentially:
    if not parallel:
        return _serial_map_or_starmap(
            function, iterable, args, kwargs, pbar_wrapper, map_or_starmap
        )
    func_star = _get_helper_func(map_or_starmap)
    # Handle case: Without showing progress bar
    if not has_pbar:
        try:
            # Each task is a picklable (function, item, args, kwargs) tuple:
            result = pool.map_async(
                func_star,
                zip(repeat(function), iterable, repeat(list(args)), repeat(kwargs)),
                chunksize,
            )
            output = result.get()
        except:
            # Bare except on purpose: even on KeyboardInterrupt we must
            # terminate a pool we created before re-raising.
            if close_pool:
                pool.terminate()
            raise
        else:
            if close_pool:
                pool.close()
                pool.join()
        return output
    # Handle case: Show progress bar:
    try:
        # The progress bar needs a total, so the iterable must be sized here:
        num_tasks = len(iterable)
        # get a chunksize (as multiprocessing does):
        chunksize = _get_default_chunksize(chunksize, pool, num_tasks)
        # use map_async to get progress information
        result = pool.map_async(
            func_star,
            zip(repeat(function), iterable, repeat(list(args)), repeat(kwargs)),
            chunksize,
        )
    except:
        # Bare except on purpose: clean up an owned pool before re-raising.
        if close_pool:
            pool.terminate()
        raise
    else:
        if close_pool:
            pool.close()
    # Progress bar:
    try:
        _do_pbar(
            result, num_tasks, chunksize, refresh_time=2, pbar_wrapper=pbar_wrapper
        )
    finally:
        # Always collect the results (and propagate worker exceptions),
        # even if the progress bar itself failed.
        output = result.get()
    if close_pool:
        pool.join()
    return output
def map(function, iterable, *args, **kwargs):
    """This function is equivalent to:

    >>> [function(x, args[0], args[1],...) for x in iterable]

    :param pm_parallel: Force parallelization on/off
    :type pm_parallel: bool
    :param pm_chunksize: see :py:class:`multiprocessing.pool.Pool`
    :type pm_chunksize: int
    :param pm_pool: Pass an existing pool
    :type pm_pool: multiprocessing.pool.Pool
    :param pm_processes: Number of processes to use in the pool. See
        :py:class:`multiprocessing.pool.Pool`
    :type pm_processes: int
    :param pm_pbar: Show progress bar with optional information.

        * If it is a `boolean`, whether to show or not the progress bar.
        * If it is a `dictionary`, these are options passed to `tqdm.tqdm()`.
        * If it is a `callable`, the callable is a function compatible with
          `tqdm.tqdm()`. If you want to pass additional options to your
          callable, consider using :py:func:`functools.partial`::

              from functools import partial
              from tqdm_loggable.auto import tqdm
              parmap.map(print, range(10), pm_pbar = partial(tqdm, desc = "example"))

    :type pm_pbar: bool, dict or callable
    """
    return _map_or_starmap(function, iterable, args, kwargs, "map")
def starmap(function, iterables, *args, **kwargs):
    """Equivalent to:

    >>> return ([function(x1,x2,x3,..., args[0], args[1],...) for
    ...          (x1,x2,x3...) in iterable])

    :param pm_parallel: Force parallelization on/off
    :type pm_parallel: bool
    :param pm_chunksize: see :py:class:`multiprocessing.pool.Pool`
    :type pm_chunksize: int
    :param pm_pool: Pass an existing pool
    :type pm_pool: multiprocessing.pool.Pool
    :param pm_processes: Number of processes to use in the pool. See
        :py:class:`multiprocessing.pool.Pool`
    :type pm_processes: int
    :param pm_pbar: Show progress bar with optional information.

        * If it is a `boolean`, whether to show or not the progress bar.
        * If it is a `dictionary`, these are options passed to `tqdm.tqdm()`.
        * If it is a `callable`, the callable is a function compatible with
          `tqdm.tqdm()`. If you want to pass additional options to your
          callable, consider using :py:func:`functools.partial`::

              from functools import partial
              from tqdm_loggable.auto import tqdm
              parmap.map(print, range(10), pm_pbar = partial(tqdm, desc = "example"))

    :type pm_pbar: bool, dict or callable
    """
    return _map_or_starmap(function, iterables, args, kwargs, "starmap")
class _DummyAsyncResult(AsyncResult):
"""AsyncResult compatible class, for when parallelization is disabled
It is a dummy class.
"""
def __init__(self, values):
self._values = values
@property
def _number_left(self):
return 0
def get(self, timeout=None):
return self._values
def wait(self, timeout=None):
pass
def ready(self):
return True
def successful(self):
# The exception would have been raised in the computation of result
return True
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
pass
class _ParallelAsyncResult(AsyncResult):
"""Like the AsyncResult, but it will close the pool when we leave the
``with`` block or when we check if it is ready.
"""
def __init__(self, result, pool=None):
self._result = result
self._pool = pool
@property
def _number_left(self):
return self._result._number_left
def get(self, timeout=None):
values = self._result.get(timeout)
if self.ready():
self.join()
return values
def wait(self, timeout=None):
return self._result.wait(timeout)
def ready(self):
is_ready = self._result.ready()
if is_ready:
self.join()
return is_ready
def successful(self):
return self._result.successful()
def __enter__(self):
return self
def join(self):
if self._pool is not None:
self._pool.join()
self._pool = None
def terminate(self):
if self._pool is not None:
self._pool.terminate()
self._pool = None
def __exit__(self, type, value, traceback):
self.terminate()
def _map_or_starmap_async(function, iterable, args, kwargs, map_or_starmap):
    """
    Shared function between parmap.map_async and parmap.starmap_async.
    Refer to those functions for details.

    :param function: callable applied to each item of *iterable*.
    :param iterable: items to map over.
    :param args: extra positional arguments appended to every call.
    :param kwargs: keyword arguments; ``pm_*`` options are popped here and
        the rest are forwarded to *function*.
    :param map_or_starmap: either "map" or "starmap".
    :returns: an AsyncResult-compatible object (a _ParallelAsyncResult when
        running in parallel, a _DummyAsyncResult when serial).
    """
    # Accept the pre-1.0 option names, emitting DeprecationWarnings:
    arg_newarg = (
        ("parallel", "pm_parallel"),
        ("chunksize", "pm_chunksize"),
        ("pool", "pm_pool"),
        ("processes", "pm_processes"),
        ("callback", "pm_callback"),
        ("error_callback", "pm_error_callback"),
    )
    kwargs = _deprecated_kwargs(kwargs, arg_newarg)
    chunksize = kwargs.pop("pm_chunksize", None)
    callback = kwargs.pop("pm_callback", None)
    error_callback = kwargs.pop("pm_error_callback", None)
    parallel, pool, close_pool = _create_pool(kwargs)
    # Map:
    if parallel:
        func_star = _get_helper_func(map_or_starmap)
        try:
            # Each task is a picklable (function, item, args, kwargs) tuple:
            result = pool.map_async(
                func_star,
                zip(repeat(function), iterable, repeat(list(args)), repeat(kwargs)),
                chunksize=chunksize,
                callback=callback,
                error_callback=error_callback,
            )
        except:
            # Bare except on purpose: even on KeyboardInterrupt we must
            # terminate a pool we created before re-raising.
            if close_pool:
                pool.terminate()
            raise
        else:
            if close_pool:
                pool.close()
                # Hand pool ownership to the wrapper so it is joined once
                # the result is ready (or terminated on context exit):
                result = _ParallelAsyncResult(result, pool)
            else:
                # Caller owns the pool; wrap without taking ownership.
                result = _ParallelAsyncResult(result)
    else:
        # Serial fallback: compute eagerly and wrap in an always-ready
        # AsyncResult look-alike.
        values = _serial_map_or_starmap(
            function, iterable, args, kwargs, None, map_or_starmap
        )
        result = _DummyAsyncResult(values)
    return result
def map_async(function, iterable, *args, **kwargs):
    """This function is the multiprocessing.Pool.map_async version that
    supports multiple arguments.

    >>> [function(x, args[0], args[1],...) for x in iterable]

    :param pm_parallel: Force parallelization on/off. If False, the
        function won't be asynchronous.
    :type pm_parallel: bool
    :param pm_chunksize: see :py:class:`multiprocessing.pool.Pool`
    :type pm_chunksize: int
    :param pm_callback: see :py:class:`multiprocessing.pool.Pool`
    :type pm_callback: function
    :param pm_error_callback: (not on python 2) see
        :py:class:`multiprocessing.pool.Pool`
    :type pm_error_callback: function
    :param pm_pool: Pass an existing pool.
    :type pm_pool: multiprocessing.pool.Pool
    :param pm_processes: Number of processes to use in the pool. See
        :py:class:`multiprocessing.pool.Pool`
    :type pm_processes: int
    """
    return _map_or_starmap_async(function, iterable, args, kwargs, "map")
def starmap_async(function, iterables, *args, **kwargs):
    """This function is the multiprocessing.Pool.starmap_async version that
    supports multiple arguments.

    >>> return ([function(x1,x2,x3,..., args[0], args[1],...) for
    ...          (x1,x2,x3...) in iterable])

    :param pm_parallel: Force parallelization on/off. If False, the
        function won't be asynchronous.
    :type pm_parallel: bool
    :param pm_chunksize: see :py:class:`multiprocessing.pool.Pool`
    :type pm_chunksize: int
    :param pm_callback: see :py:class:`multiprocessing.pool.Pool`
    :type pm_callback: function
    :param pm_error_callback: see :py:class:`multiprocessing.pool.Pool`
    :type pm_error_callback: function
    :param pm_pool: Pass an existing pool.
    :type pm_pool: multiprocessing.pool.Pool
    :param pm_processes: Number of processes to use in the pool. See
        :py:class:`multiprocessing.pool.Pool`
    :type pm_processes: int
    """
    return _map_or_starmap_async(function, iterables, args, kwargs, "starmap")
# Needs testing, but it might work as it is:
# def _serial_imap_or_istarmap(function, iterable, args,
# kwargs, map_or_starmap):
# if map_or_starmap == "map":
# output = (function(*([item] + list(args)), **kwargs)
# for item in iterable)
# elif map_or_starmap == "starmap":
# output = (function(*(list(item) + list(args)), **kwargs)
# for item in iterable)
# else:
# raise AssertionError("Internal parmap error: " +
# "Invalid map_or_starmap. This should not happen")
# return output
# def imap(function, iterable, *args, **kwargs):
# chunksize = kwargs.pop("pm_chunksize", 1)
# parallel, pool, close_pool = _create_pool(kwargs)
# # Map:
# if parallel:
# func_star = _get_helper_func("map")
# try:
# output = pool.imap(func_star,
# zip(repeat(function), iterable,
# repeat(list(args)), repeat(kwargs)),
# chunksize)
# finally:
# if close_pool:
# pool.close()
# pool.join()
# else:
# output = _serial_imap_or_istarmap(function, iterable, args, kwargs,
# map_or_starmap)
# return output