# Source code for tick.prox.prox_multi
# License: BSD 3 clause
# -*- coding: utf8 -*-
import numpy as np
from .base import Prox
from .build.prox import ProxMultiDouble as _ProxMultiDouble
from .build.prox import ProxMultiFloat as _ProxMultiFloat
from tick.prox import ProxZero
__author__ = 'Stephane Gaiffas'

# Lookup table from the numpy dtype of the coefficients to the matching typed
# ProxMulti backend class imported from ``.build.prox`` (float64 -> Double,
# float32 -> Float). ProxMulti._build_cpp_prox selects its class from here.
dtype_map = {
    np.dtype("float64"): _ProxMultiDouble,
    np.dtype("float32"): _ProxMultiFloat,
}
class ProxMulti(Prox):
    """Multiple proximal operator. This allows to apply sequentially a list
    of proximal operators. This is convenient when one wants to apply
    different proximal operators on different parts of a vector.

    Parameters
    ----------
    proxs : `tuple` of `Prox`
        A tuple of prox operators to be applied successively. Any iterable
        of `Prox` is accepted; it is materialized into a list once. If it is
        empty (or `None`), a single `ProxZero` is used instead.

    Attributes
    ----------
    dtype : `{'float64', 'float32'}`
        Type of the arrays used. It is taken from the given proxs, which must
        all share the same dtype.

    Raises
    ------
    ValueError
        If an element of ``proxs`` is not a `Prox`, has no ``_prox`` backend
        object, or if the proxs do not all share the same dtype.
    """

    # ``proxs`` is declared non-writable in tick's attribute metadata.
    _attrinfos = {"proxs": {"writable": False}}

    def __init__(self, proxs: tuple):
        # NOTE(review): the ``None`` is presumably the ``range`` argument of
        # ``Prox.__init__`` (ProxMulti itself acts on the whole vector);
        # confirm against the base class.
        Prox.__init__(self, None)

        # Materialize the input once so that any iterable (including a
        # generator) can be both iterated over and indexed below.
        proxs = [] if proxs is None else list(proxs)
        if not proxs:
            proxs = [ProxZero()]

        # Validate every element *before* reading ``.dtype``: reading
        # ``proxs[0].dtype`` first would raise an AttributeError instead of
        # the intended ValueError when the first element is not a Prox.
        for prox in proxs:
            if not isinstance(prox, Prox):
                raise ValueError('%s is not a Prox' % prox.__class__.__name__)
            # ProxMulti is built from each prox's ``_prox`` backend object,
            # so a prox without one (missing or None) cannot be combined.
            if getattr(prox, '_prox', None) is None:
                raise ValueError('%s cannot be used in ProxMulti' % prox.name)

        dtype = proxs[0].dtype
        for prox in proxs:
            if dtype != prox.dtype:
                raise ValueError(
                    'ProxMulti can only handle proxes with same dtype')
        self.dtype = dtype

        # Keep the backend objects (not the Python wrappers): they are what
        # the typed ProxMulti backend class is constructed from.
        self.proxs = [prox._prox for prox in proxs]
        self._prox = self._build_cpp_prox(dtype)

    def _call(self, coeffs: np.ndarray, step: object, out: np.ndarray):
        """Delegate to the backend's ``call`` with ``coeffs``, ``step`` and
        the output array ``out``."""
        self._prox.call(coeffs, step, out)

    def value(self, coeffs: np.ndarray):
        """
        Returns the value of the penalization at ``coeffs``.
        This returns the sum of the values of each prox called on the same
        coeffs.

        Parameters
        ----------
        coeffs : `numpy.ndarray`, shape=(n_coeffs,)
            The value of the penalization is computed at this point

        Returns
        -------
        output : `float`
            Value of the penalization at ``coeffs``
        """
        return self._prox.value(coeffs)

    def _build_cpp_prox(self, dtype_or_object_with_dtype):
        """Set ``self.dtype`` from the argument and instantiate the typed
        backend class chosen from ``dtype_map``, built on ``self.proxs``."""
        self.dtype = self._extract_dtype(dtype_or_object_with_dtype)
        prox_class = self._get_typed_class(dtype_or_object_with_dtype,
                                           dtype_map)
        return prox_class(self.proxs)

    def astype(self, dtype_or_object_with_dtype):
        """Not supported: convert each prox with its own ``astype`` and build
        a new ProxMulti from the converted proxs instead.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "This type requires each Prox to their 'astype' called (for now)")