# Source code for tick.prox.prox_equality
# License: BSD 3 clause
# -*- coding: utf8 -*-
import numpy as np
import sys
from .base import Prox
from .build.prox import ProxEqualityDouble as _ProxEqualityDouble
from .build.prox import ProxEqualityFloat as _ProxEqualityFloat
__author__ = 'Stephane Gaiffas'

# Lookup table from a numpy dtype to the compiled prox class (imported from
# ``.build.prox``) that handles arrays of that dtype.
dtype_map = {
    np.dtype(type_name): cpp_class
    for type_name, cpp_class in (
        ("float64", _ProxEqualityDouble),
        ("float32", _ProxEqualityFloat),
    )
}
class ProxEquality(Prox):
    """Projection operator onto the set of vectors with all coordinates equal
    (or equal within the given range, if one is given).

    Namely, this simply replaces all coordinates by their average.

    Parameters
    ----------
    strength : `float`, default=0.
        Not used in this prox, but kept for compatibility issues

    range : `tuple` of two `int`, default=`None`
        Range on which the prox is applied. If `None` then the prox is
        applied on the whole vector

    positive : `bool`, default=`False`
        If True, ensures that the output of the prox has only non-negative
        entries (in the given range)

    Attributes
    ----------
    dtype : `{'float64', 'float32'}`
        Type of the arrays used.
    """

    _attrinfos = {
        "positive": {
            "writable": True,
            "cpp_setter": "set_positive"
        }
    }

    def __init__(self, strength: float = 0, range: tuple = None,
                 positive: bool = False):
        # `strength` is accepted only to keep the signature compatible with
        # the other prox classes; it is deliberately ignored here.
        Prox.__init__(self, range)
        self.positive = positive
        # The backend is always created for float64 at construction time.
        self._prox = self._build_cpp_prox("float64")

    def _call(self, coeffs: np.ndarray, step: object, out: np.ndarray):
        """Forward ``coeffs``, ``step`` and ``out`` to the backend prox's
        ``call`` method.
        """
        self._prox.call(coeffs, step, out)

    def value(self, coeffs: np.ndarray):
        """
        Simply returns 0 if all coeffs in range are equal. Otherwise returns
        infinity. This is not a penalization but a projection.

        Parameters
        ----------
        coeffs : `numpy.ndarray`, shape=(n_coeffs,)
            Vector to be projected

        Returns
        -------
        output : `float`
            Returns 0 or np.inf
        """
        raw_value = self._prox.value(coeffs)

        # The backend signals "infinite" with a huge finite number. A float32
        # backend cannot represent the float64 maximum, so the threshold must
        # be the largest finite value of the prox's own dtype.
        # NOTE(review): presumably the float32 backend returns its own
        # maximum -- confirm against the compiled implementation.
        dtype = getattr(self, "dtype", None)
        if dtype is None:
            largest_finite = sys.float_info.max
        else:
            largest_finite = float(np.finfo(dtype).max)

        # `>=` also catches a backend that returns an actual infinity.
        if raw_value >= largest_finite:
            return np.inf
        return 0

    @property
    def strength(self):
        # This prox has no strength; the property exists for compatibility.
        return None

    @strength.setter
    def strength(self, val):
        # Strength is not settable in this prox
        pass

    def _build_cpp_prox(self, dtype_or_object_with_dtype):
        """Record the dtype on ``self.dtype`` and build the matching backend
        prox.

        Parameters
        ----------
        dtype_or_object_with_dtype : `object`
            Value handed to ``_extract_dtype`` and ``_get_typed_class`` to
            select the backend class from ``dtype_map``

        Returns
        -------
        output : `object`
            A new instance of the backend class selected from ``dtype_map``
        """
        self.dtype = self._extract_dtype(dtype_or_object_with_dtype)
        prox_class = self._get_typed_class(dtype_or_object_with_dtype,
                                           dtype_map)
        # The leading 0. is presumably the (unused) strength expected by the
        # backend constructor -- TODO confirm.
        if self.range is None:
            return prox_class(0., self.positive)
        return prox_class(0., self.range[0], self.range[1], self.positive)