Skip to content

Commit

Permalink
WIP: make functionals and prox ops work for product spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Mar 13, 2019
1 parent e32d457 commit a3c4961
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 56 deletions.
28 changes: 20 additions & 8 deletions odl/solvers/functional/default_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def __init__(self):

def _call(self, x):
"""Apply the gradient operator to the given point."""
return np.sign(x)
if isinstance(self.domain, ProductSpace):
return self.domain.apply(np.sign, x)
else:
return np.sign(x)

def derivative(self, x):
"""Derivative is a.e. zero."""
Expand Down Expand Up @@ -1127,14 +1130,23 @@ def _call(self, x):
import scipy.special

if self.prior is None:
tmp = self.domain.inner(self.domain.one(), x - 1 - np.log(x))
if isinstance(self.domain, ProductSpace):
log_x = self.domain.apply(np.log, x)
else:
log_x = np.log(x)
tmp = self.domain.inner(self.domain.one(), x - 1 - log_x)

else:
tmp = self.domain.inner(
self.domain.one(),
x - self.prior + scipy.special.xlogy(
self.prior, self.prior / x
),
)
g = self.prior
if isinstance(self.domain, ProductSpace):
xlogy = self.domain.apply2(
lambda v, i: scipy.special.xlogy(g[i], g[i] / v), x
)
else:
xlogy = scipy.special.xlogy(g, g / x)

tmp = self.domain.inner(self.domain.one(), x - g + xlogy)

if np.isnan(tmp):
# In this case, some element was less than or equal to zero
return np.inf
Expand Down
65 changes: 23 additions & 42 deletions odl/solvers/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

from __future__ import print_function, division, absolute_import
from __future__ import absolute_import, division, print_function

import numpy as np

from odl.operator.default_ops import (
ConstantOperator, IdentityOperator, InnerProductOperator)
from odl.operator.operator import (
Operator, OperatorComp, OperatorLeftScalarMult, OperatorRightScalarMult,
OperatorRightVectorMult, OperatorSum, OperatorPointwiseProduct)
from odl.operator.default_ops import (IdentityOperator, ConstantOperator)
from odl.solvers.nonsmooth import (proximal_arg_scaling, proximal_translation,
proximal_quadratic_perturbation,
proximal_const_func, proximal_convex_conj)
from odl.util import signature_string, indent

Operator, OperatorComp, OperatorLeftScalarMult, OperatorPointwiseProduct,
OperatorRightScalarMult, OperatorRightVectorMult, OperatorSum)
from odl.solvers.nonsmooth import (
proximal_arg_scaling, proximal_const_func, proximal_convex_conj,
proximal_quadratic_perturbation, proximal_translation)
from odl.util import indent, signature_string

__all__ = ('Functional', 'FunctionalLeftScalarMult',
'FunctionalRightScalarMult', 'FunctionalComp',
Expand Down Expand Up @@ -204,7 +205,7 @@ def derivative(self, point):
-------
derivative : `Operator`
"""
return self.gradient(point).T
return InnerProductOperator(self.domain, self.gradient(point))

def translated(self, shift):
"""Return a translation of the functional.
Expand Down Expand Up @@ -1399,33 +1400,18 @@ def __init__(self, functional, point, subgrad):
raise TypeError('`functional` {} not an instance of ``Functional``'
''.format(functional))
self.__functional = functional

if point not in functional.domain:
raise ValueError('`point` {} is not in `functional.domain` {}'
''.format(point, functional.domain))
self.__point = point

if subgrad not in functional.domain:
raise TypeError(
'`subgrad` must be an element in `functional.domain`, got '
'{}'.format(subgrad))
self.__subgrad = subgrad

self.__constant = (
-functional(point)
+ functional.domain.inner(subgrad, point)
)

space = functional.domain
self.__point = space.element(point)
self.__subgrad = space.element(subgrad)
self.__constant = -functional(point) + space.inner(subgrad, point)
self.__bregman_dist = FunctionalQuadraticPerturb(
functional, linear_term=-subgrad, constant=self.__constant)

grad_lipschitz = (
functional.grad_lipschitz + functional.domain.norm(subgrad)
functional, linear_term=-subgrad, constant=self.__constant
)
grad_lipschitz = functional.grad_lipschitz + space.norm(subgrad)

super(BregmanDistance, self).__init__(
space=functional.domain, linear=False,
grad_lipschitz=grad_lipschitz)
space, linear=False, grad_lipschitz=grad_lipschitz
)

@property
def functional(self):
Expand Down Expand Up @@ -1459,15 +1445,10 @@ def proximal(self):
@property
def gradient(self):
"""Gradient operator of the functional."""
try:
op_to_return = self.functional.gradient
except NotImplementedError:
raise NotImplementedError(
'`self.functional.gradient` is not implemented for '
'`self.functional` {}'.format(self.functional))

op_to_return = op_to_return - ConstantOperator(self.subgrad)
return op_to_return
return (
self.functional.gradient
- ConstantOperator(self.domain, self.subgrad)
)

def __repr__(self):
'''Return ``repr(self)``.'''
Expand Down
12 changes: 6 additions & 6 deletions odl/solvers/nonsmooth/proximal_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
Foundations and Trends in Optimization, 1 (2014), pp 127-239.
"""

from __future__ import print_function, division, absolute_import
from __future__ import absolute_import, division, print_function

import numpy as np

from odl.operator import (
Operator, IdentityOperator, ConstantOperator, DiagonalOperator,
PointwiseNorm, MultiplyOperator)
ConstantOperator, DiagonalOperator, IdentityOperator, MultiplyOperator,
Operator, PointwiseNorm)
from odl.space import ProductSpace


__all__ = ('combine_proximals', 'proximal_convex_conj', 'proximal_translation',
'proximal_arg_scaling', 'proximal_quadratic_perturbation',
'proximal_composition', 'proximal_const_func',
Expand Down Expand Up @@ -799,7 +799,7 @@ def _call(self, x, out):
if step < 1.0:
self.range.lincomb(1 - step, x, out=out)
else:
out[:] = 0
self.range.lincomb(0, out, out=out)

else:
x_norm = self.domain.norm(x - g) * (1 + eps)
Expand All @@ -811,7 +811,7 @@ def _call(self, x, out):
if step < 1.0:
self.range.lincomb(1 - step, x, step, g, out=out)
else:
out[:] = g
self.range.lincomb(1, g, out=out)

return ProximalL2

Expand Down

0 comments on commit a3c4961

Please sign in to comment.