# Source code for tick.prox.prox_l1w
# License: BSD 3 clause
# -*- coding: utf8 -*-
import numpy as np
from .base import Prox
from .build.prox import ProxL1wDouble as _ProxL1wDouble
from .build.prox import ProxL1wFloat as _ProxL1wFloat
__author__ = 'Stephane Gaiffas'

# Lookup table from a numpy dtype to the compiled prox class handling that
# precision; passed to ``_get_typed_class`` when the C++ prox is built.
dtype_map = {
    np.dtype("float64"): _ProxL1wDouble,
    np.dtype("float32"): _ProxL1wFloat,
}
# TODO: what should happen when ``weights`` is set to a vector whose
#  length != end - start ?
class ProxL1w(Prox):
    """Proximal operator of the weighted L1 norm (weighted
    soft-thresholding)

    Parameters
    ----------
    strength : `float`
        Level of L1 penalization

    weights : `numpy.ndarray`, shape=(n_coeffs,)
        The weights to be used in the L1 penalization. When ``range`` is
        given, the first dimension of ``weights`` must be equal to
        ``end - start``, otherwise a `ValueError` is raised when the
        underlying prox object is built

    range : `tuple` of two `int`, default=`None`
        Range on which the prox is applied. If `None` then the prox is
        applied on the whole vector

    positive : `bool`, default=`False`
        If True, apply L1 penalization together with a projection
        onto the set of vectors with non-negative entries

    Attributes
    ----------
    dtype : `{'float64', 'float32'}`
        Type of the arrays used.
    """

    _attrinfos = {
        "strength": {
            "writable": True,
            "cpp_setter": "set_strength"
        },
        "weights": {
            "writable": True,
            "cpp_setter": "set_weights"
        },
        "positive": {
            "writable": True,
            "cpp_setter": "set_positive"
        }
    }

    def __init__(self, strength: float, weights: np.ndarray,
                 range: tuple = None, positive: bool = False):
        Prox.__init__(self, range)
        self.positive = positive
        self.strength = strength
        self.weights = weights
        # The underlying prox object is first built in double precision.
        self._prox = self._build_cpp_prox("float64")

    def _call(self, coeffs: np.ndarray, step: object, out: np.ndarray):
        """Forward the prox call to the underlying prox object

        Parameters
        ----------
        coeffs : `numpy.ndarray`
            Passed as first argument to ``self._prox.call``

        step : `object`
            Passed as second argument to ``self._prox.call``

        out : `numpy.ndarray`
            Passed as third argument to ``self._prox.call``
        """
        self._prox.call(coeffs, step, out)

    def value(self, coeffs: np.ndarray):
        """
        Returns the value of the penalization at ``coeffs``

        Parameters
        ----------
        coeffs : `numpy.ndarray`, shape=(n_coeffs,)
            The value of the penalization is computed at this point

        Returns
        -------
        output : `float`
            Value of the penalization at ``coeffs``
        """
        return self._prox.value(coeffs)

    def _as_dict(self):
        """Return the dictionary built by ``Prox._as_dict`` without its
        ``"weights"`` entry

        Returns
        -------
        output : `dict`
            The parent dictionary, with the ``"weights"`` key removed
        """
        dd = Prox._as_dict(self)
        # ``pop`` with a default: nothing to do (and no KeyError) if the
        # parent dictionary has no "weights" entry.
        dd.pop("weights", None)
        return dd

    def _build_cpp_prox(self, dtype_or_object_with_dtype):
        """Build the underlying prox object for the requested dtype

        Updates ``self.dtype``, selects the matching class in ``dtype_map``
        and instantiates it with a copy of ``self.weights`` cast to that
        dtype.

        Parameters
        ----------
        dtype_or_object_with_dtype : `object`
            Forwarded to ``self._extract_dtype`` and
            ``self._get_typed_class`` to determine the dtype to use

        Returns
        -------
        output : `object`
            The newly built prox object

        Raises
        ------
        ValueError
            If ``range`` is set and ``end - start`` differs from the first
            dimension of ``weights``
        """
        self.dtype = self._extract_dtype(dtype_or_object_with_dtype)
        prox_class = self._get_typed_class(dtype_or_object_with_dtype,
                                           dtype_map)
        weights = self.weights.astype(self.dtype)
        if self.range is None:
            return_prox = prox_class(self.strength, weights, self.positive)
        else:
            start, end = self.range
            # ``astype`` keeps the shape, so checking the cast array is
            # equivalent to checking ``self.weights``.
            if (end - start) != weights.shape[0]:
                raise ValueError("Size of ``weights`` does not match "
                                 "the given ``range``")
            return_prox = prox_class(self.strength, weights, start, end,
                                     self.positive)
        # NOTE(review): presumably stored on the prox object so the cast
        # array stays alive as long as it does -- TODO confirm.
        return_prox.weights = weights
        return return_prox