From 5d725b32cf9221869ea25d7806a72ae58f0c98b2 Mon Sep 17 00:00:00 2001
From: Samuel Farrens
Date: Thu, 16 Mar 2023 18:23:09 +0100
Subject: [PATCH 1/3] refactored code and ci
---
.github/workflows/cd-build.yml | 85 --------
.github/workflows/cd.yml | 75 +++++++
.github/workflows/ci-build.yml | 118 -----------
.github/workflows/ci.yml | 51 +++++
.github/workflows/test-docs.yml | 37 ++++
.github/workflows/test-python-versions.yml | 33 +++
.pylintrc | 2 -
.pyup.yml | 14 --
MANIFEST.in | 4 +-
develop.txt | 12 --
docs/requirements.txt | 9 -
docs/source/conf.py | 190 +++++-------------
docs/source/notebooks.rst | 8 -
docs/source/toc.rst | 1 -
modopt/__init__.py | 17 +-
modopt/base/__init__.py | 2 +-
modopt/base/backend.py | 58 +++---
modopt/base/np_adjust.py | 6 +-
modopt/base/observable.py | 12 +-
modopt/base/transform.py | 43 ++--
modopt/base/types.py | 27 ++-
modopt/base/wrappers.py | 9 +-
modopt/examples/conftest.py | 4 +-
.../example_lasso_forward_backward.py | 12 +-
modopt/interface/__init__.py | 2 +-
modopt/interface/errors.py | 28 ++-
modopt/interface/log.py | 16 +-
modopt/math/__init__.py | 2 +-
modopt/math/convolve.py | 32 +--
modopt/math/matrix.py | 33 ++-
modopt/math/metrics.py | 14 +-
modopt/opt/__init__.py | 2 +-
modopt/opt/algorithms/__init__.py | 25 ++-
modopt/opt/algorithms/base.py | 57 +++---
modopt/opt/algorithms/forward_backward.py | 153 +++++++-------
modopt/opt/algorithms/gradient_descent.py | 23 +--
modopt/opt/algorithms/primal_dual.py | 45 ++---
modopt/opt/cost.py | 29 ++-
modopt/opt/gradient.py | 12 +-
modopt/opt/linear.py | 23 +--
modopt/opt/proximity.py | 155 +++++++-------
modopt/opt/reweight.py | 5 +-
modopt/plot/__init__.py | 2 +-
modopt/plot/cost_plot.py | 16 +-
modopt/signal/__init__.py | 2 +-
modopt/signal/filter.py | 2 +-
modopt/signal/noise.py | 17 +-
modopt/signal/positivity.py | 8 +-
modopt/signal/svd.py | 60 +++---
modopt/signal/validation.py | 4 +-
modopt/signal/wavelet.py | 53 +++--
modopt/tests/test_signal.py | 7 +-
notebooks/.gitkeep | 0
pyproject.toml | 96 +++++++++
requirements.txt | 1 -
setup.cfg | 95 ---------
setup.py | 73 -------
57 files changed, 812 insertions(+), 1109 deletions(-)
delete mode 100644 .github/workflows/cd-build.yml
create mode 100644 .github/workflows/cd.yml
delete mode 100644 .github/workflows/ci-build.yml
create mode 100644 .github/workflows/ci.yml
create mode 100644 .github/workflows/test-docs.yml
create mode 100644 .github/workflows/test-python-versions.yml
delete mode 100644 .pylintrc
delete mode 100644 .pyup.yml
delete mode 100644 develop.txt
delete mode 100644 docs/requirements.txt
delete mode 100644 docs/source/notebooks.rst
delete mode 100644 notebooks/.gitkeep
create mode 100644 pyproject.toml
delete mode 100644 setup.cfg
delete mode 100644 setup.py
diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
deleted file mode 100644
index fca9feb1..00000000
--- a/.github/workflows/cd-build.yml
+++ /dev/null
@@ -1,85 +0,0 @@
-name: CD
-
-on:
- push:
- branches:
- - master
- - main
-
-jobs:
-
- coverage:
- name: Deploy Coverage Results
- runs-on: ubuntu-latest
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: 3.8
- auto-activate-base: false
-
- - name: Install dependencies
- shell: bash -l {0}
- run: |
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install twine
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- python setup.py test
-
- - name: Check distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
-
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v1
- with:
- token: ${{ secrets.CODECOV_TOKEN }}
- file: coverage.xml
- flags: unittests
-
- api:
- name: Deploy API Documentation
- needs: coverage
- runs-on: ubuntu-latest
- if: success()
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
- with:
- python-version: "3.8"
-
- - name: Install dependencies
- shell: bash -l {0}
- run: |
- conda install -c conda-forge pandoc
- python -m pip install --upgrade pip
- python -m pip install -r docs/requirements.txt
- python -m pip install .
-
- - name: Build API documentation
- shell: bash -l {0}
- run: |
- sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
- sphinx-build -E docs/source docs/_build
-
- - name: Deploy API documentation
- uses: peaceiris/actions-gh-pages@v3.5.9
- with:
- github_token: ${{ secrets.GITHUB_TOKEN }}
- publish_dir: docs/_build
diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml
new file mode 100644
index 00000000..467b410f
--- /dev/null
+++ b/.github/workflows/cd.yml
@@ -0,0 +1,75 @@
+name: CD
+
+on:
+ push:
+ branches:
+ - master
+ - main
+
+jobs:
+
+ coverage:
+ name: Deploy Coverage Results
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+
+ - name: Check Python version
+ run: python --version
+
+ - name: Install package
+ run: python -m pip install ".[release,test]"
+
+ - name: Run tests
+ run: pytest -n auto
+
+ - name: Check distribution
+ run: |
+ python -m build
+ python -m twine check dist/*
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v1
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ file: coverage.xml
+ flags: unittests
+
+ docs:
+ name: Deploy API Documentation
+ needs: coverage
+ runs-on: ubuntu-latest
+ if: success()
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+
+ - name: Check Python version
+ run: python --version
+
+ - name: Install package
+ run: python -m pip install ".[docs]"
+
+ - name: Build API documentation
+ run: |
+ sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
+ sphinx-build -E docs/source docs/_build
+
+ - name: Deploy API documentation
+ uses: peaceiris/actions-gh-pages@v3.5.9
+ with:
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ publish_dir: docs/_build
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
deleted file mode 100644
index c4ba28a0..00000000
--- a/.github/workflows/ci-build.yml
+++ /dev/null
@@ -1,118 +0,0 @@
-name: CI
-
-on:
- pull_request:
- branches:
- - develop
- - master
- - main
-
-jobs:
- test-full:
- name: Full Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: ["3.10"]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Check Conda
- shell: bash -l {0}
- run: |
- conda info
- conda list
- python --version
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install -r docs/requirements.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install tensorflow>=2.4.1
- python -m pip install twine
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- pytest -n 2
-
- - name: Save Test Results
- if: always()
- uses: actions/upload-artifact@v2
- with:
- name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
- path: pytest.xml
-
- - name: Check Distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
-
- - name: Check API Documentation build
- shell: bash -l {0}
- run: |
- conda install -c conda-forge pandoc
- sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
- sphinx-build -b doctest -E docs/source docs/_build
-
- - name: Upload Coverage to Codecov
- uses: codecov/codecov-action@v1
- with:
- token: ${{ secrets.CODECOV_TOKEN }}
- file: coverage.xml
- flags: unittests
-
- test-basic:
- name: Basic Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: ["3.7", "3.8", "3.9"]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- pytest -n 2
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 00000000..3794f898
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,51 @@
+name: CI
+
+on:
+ pull_request:
+ branches:
+ - develop
+ - master
+ - main
+
+jobs:
+ test-full:
+ name: Run CI Tests
+ runs-on: ${{ matrix.os }}
+
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest]
+ python-version: ["3.11"]
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Check Python version
+ run: python --version
+
+ - name: Install package
+ run: python -m pip install ".[extra,test]"
+
+ - name: Run tests
+ run: python -m pytest -n auto
+
+ - name: Save test results
+ if: always()
+ uses: actions/upload-artifact@v2
+ with:
+ name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
+ path: pytest.xml
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v1
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ file: coverage.xml
+ flags: unittests
diff --git a/.github/workflows/test-docs.yml b/.github/workflows/test-docs.yml
new file mode 100644
index 00000000..49d1d9fb
--- /dev/null
+++ b/.github/workflows/test-docs.yml
@@ -0,0 +1,37 @@
+name: Test Docs
+
+on: [workflow_dispatch]
+
+jobs:
+
+ build-docs:
+ name: Test Docs Build
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+
+ - name: Check Python version
+ run: python --version
+
+ - name: Install package
+ run: python -m pip install ".[docs]"
+
+ - name: Test API documentation build
+ run: |
+ sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
+ sphinx-build -b doctest -E docs/source docs/_build
+
+ - name: Archive API documentation build
+ uses: actions/upload-artifact@v2
+ with:
+ name: api-docs
+ retention-days: 14
+ path: |
+ docs/_build
\ No newline at end of file
diff --git a/.github/workflows/test-python-versions.yml b/.github/workflows/test-python-versions.yml
new file mode 100644
index 00000000..13d892a4
--- /dev/null
+++ b/.github/workflows/test-python-versions.yml
@@ -0,0 +1,33 @@
+name: Test Python Versions
+
+on: [workflow_dispatch]
+
+jobs:
+
+ test-basic:
+ name: Python Test Suite
+ runs-on: ${{ matrix.os }}
+
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Check Python version
+ run: python --version
+
+ - name: Install package
+ run: python -m pip install ".[test]"
+
+ - name: Run tests
+ run: pytest -n auto
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 3ac9aef9..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,2 +0,0 @@
-[MASTER]
-ignore-patterns=**/docs/**/*.py
diff --git a/.pyup.yml b/.pyup.yml
deleted file mode 100644
index 8fdac7ff..00000000
--- a/.pyup.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-# autogenerated pyup.io config file
-# see https://pyup.io/docs/configuration/ for all available options
-
-schedule: ''
-update: all
-label_prs: update
-assignees: sfarrens
-requirements:
- - requirements.txt:
- pin: False
- - develop.txt:
- pin: False
- - docs/requirements.txt:
- pin: True
diff --git a/MANIFEST.in b/MANIFEST.in
index 9a2f374e..c0d0aaf4 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
+include CODE_OF_CONDUCT.md
+include CONTRIBUTING.md
include requirements.txt
-include develop.txt
-include docs/requirements.txt
include README.rst
include LICENSE.txt
diff --git a/develop.txt b/develop.txt
deleted file mode 100644
index 6ff665eb..00000000
--- a/develop.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-coverage>=5.5
-pytest>=6.2.2
-pytest-raises>=0.10
-pytest-cases>= 3.6
-pytest-xdist>= 3.0.1
-pytest-cov>=2.11.1
-pytest-emoji>=0.2.0
-pydocstyle==6.1.1
-pytest-pydocstyle>=2.2.0
-black
-isort
-pytest-black
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index c9e29c88..00000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-jupyter==1.0.0
-myst-parser==0.16.1
-nbsphinx==0.8.7
-nbsphinx-link==1.3.0
-numpydoc==1.1.0
-sphinx==4.3.1
-sphinxcontrib-bibtex==2.4.1
-sphinxawesome-theme==3.2.1
-sphinx-gallery==0.11.1
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 46564b9f..b3d35f0e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Python Template sphinx config
# Import relevant modules
@@ -9,56 +8,54 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
# -- General configuration ------------------------------------------------
# General information about the project.
-project = 'modopt'
+project = "modopt"
mdata = metadata(project)
-author = mdata['Author']
-version = mdata['Version']
-copyright = '2020, {}'.format(author)
-gh_user = 'sfarrens'
+author = mdata["Author-email"]
+version = mdata["Version"]
+copyright = "2020, {}".format(author)
+gh_user = "sfarrens"
# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '3.3'
+needs_sphinx = "3.3"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.coverage',
- 'sphinx.ext.doctest',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode',
- 'sphinxawesome_theme',
- 'sphinxcontrib.bibtex',
- 'myst_parser',
- 'nbsphinx',
- 'nbsphinx_link',
- 'numpydoc',
- "sphinx_gallery.gen_gallery"
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.coverage",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinxawesome_theme",
+ "sphinxcontrib.bibtex",
+ "sphinx_gallery.gen_gallery",
+ "myst_parser",
+ "numpydoc",
]
# Include module names for objects
add_module_names = False
# Set class documentation standard.
-autoclass_content = 'class'
+autoclass_content = "class"
# Audodoc options
autodoc_default_options = {
- 'member-order': 'bysource',
- 'private-members': True,
- 'show-inheritance': True
+ "member-order": "bysource",
+ "private-members": True,
+ "show-inheritance": True,
}
# Generate summaries
@@ -69,17 +66,17 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'default'
+pygments_style = "default"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -88,7 +85,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'sphinxawesome_theme'
+html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'
# Theme options are theme-specific and customize the look and feel of a theme
@@ -101,11 +98,10 @@
"breadcrumbs_separator": "/",
"show_prev_next": True,
"show_scrolltop": True,
-
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = 'modopt_logo.jpg'
+html_logo = "modopt_logo.png"
html_permalinks_icon = (
'
'''
-)
-
-nbsphinx_prolog = nb_header_pt1 + nb_header_pt2
-
# -- Intersphinx Mapping ----------------------------------------------
# Refer to the package libraries for type definitions
intersphinx_mapping = {
- 'python': ('http://docs.python.org/3', None),
- 'numpy': ('https://numpy.org/doc/stable/', None),
- 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
- 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None),
- 'matplotlib': ('https://matplotlib.org', None),
- 'astropy': ('http://docs.astropy.org/en/latest/', None),
- 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None),
- 'torch': ('https://pytorch.org/docs/stable/', None),
- 'sklearn': (
- 'http://scikit-learn.org/stable',
- (None, './_intersphinx/sklearn-objects.inv')
+ "python": ("http://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
+ "matplotlib": ("https://matplotlib.org", None),
+ "astropy": ("http://docs.astropy.org/en/latest/", None),
+ "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+ "sklearn": (
+ "http://scikit-learn.org/stable",
+ (None, "./_intersphinx/sklearn-objects.inv"),
),
- 'tensorflow': (
- 'https://www.tensorflow.org/api_docs/python',
+ "tensorflow": (
+ "https://www.tensorflow.org/api_docs/python",
(
- 'https://github.com/GPflow/tensorflow-intersphinx/'
- + 'raw/master/tf2_py_objects.inv')
- )
-
+ "https://github.com/GPflow/tensorflow-intersphinx/"
+ + "raw/master/tf2_py_objects.inv"
+ ),
+ ),
}
# -- BibTeX Setting ----------------------------------------------
-bibtex_bibfiles = ['refs.bib', 'my_ref.bib']
-bibtex_default_style = 'alpha'
+bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
+bibtex_default_style = "alpha"
diff --git a/docs/source/notebooks.rst b/docs/source/notebooks.rst
deleted file mode 100644
index 86c3a4b0..00000000
--- a/docs/source/notebooks.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-Notebooks
-=========
-
-List of Notebooks
------------------
-
-.. toctree::
- :maxdepth: 4
diff --git a/docs/source/toc.rst b/docs/source/toc.rst
index ef5753f5..c004044e 100644
--- a/docs/source/toc.rst
+++ b/docs/source/toc.rst
@@ -24,7 +24,6 @@
:caption: Examples
plugin_example
- notebooks
auto_examples/index
.. toctree::
diff --git a/modopt/__init__.py b/modopt/__init__.py
index 2c06c1db..02606b31 100644
--- a/modopt/__init__.py
+++ b/modopt/__init__.py
@@ -1,24 +1,9 @@
-# -*- coding: utf-8 -*-
-
"""MODOPT PACKAGE.
ModOpt is a series of Modular Optimisation tools for solving inverse problems.
"""
-from warnings import warn
-
-from importlib_metadata import version
-
from modopt.base import *
-try:
- _version = version('modopt')
-except Exception: # pragma: no cover
- _version = 'Unkown'
- warn(
- 'Could not extract package metadata. Make sure the package is '
- + 'correctly installed.',
- )
-
-__version__ = _version
+__version__ = "1.6.1"
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index 1c0c8b2c..3946e04c 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -9,4 +9,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'wrappers', 'observable']
+__all__ = ["np_adjust", "transform", "types", "wrappers", "observable"]
diff --git a/modopt/base/backend.py b/modopt/base/backend.py
index 1f4e9a72..fd933ebb 100644
--- a/modopt/base/backend.py
+++ b/modopt/base/backend.py
@@ -26,22 +26,24 @@
# Handle the compatibility with variable
LIBRARIES = {
- 'cupy': None,
- 'tensorflow': None,
- 'numpy': np,
+ "cupy": None,
+ "tensorflow": None,
+ "numpy": np,
}
-if util.find_spec('cupy') is not None:
+if util.find_spec("cupy") is not None:
try:
import cupy as cp
- LIBRARIES['cupy'] = cp
+
+ LIBRARIES["cupy"] = cp
except ImportError:
pass
-if util.find_spec('tensorflow') is not None:
+if util.find_spec("tensorflow") is not None:
try:
from tensorflow.experimental import numpy as tnp
- LIBRARIES['tensorflow'] = tnp
+
+ LIBRARIES["tensorflow"] = tnp
except ImportError:
pass
@@ -66,12 +68,12 @@ def get_backend(backend):
"""
if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None:
msg = (
- '{0} backend not possible, please ensure that '
- + 'the optional libraries are installed.\n'
- + 'Reverting to numpy.'
+ "{0} backend not possible, please ensure that "
+ + "the optional libraries are installed.\n"
+ + "Reverting to numpy."
)
warn(msg.format(backend))
- backend = 'numpy'
+ backend = "numpy"
return LIBRARIES[backend], backend
@@ -92,16 +94,16 @@ def get_array_module(input_data):
The numpy or cupy module
"""
- if LIBRARIES['tensorflow'] is not None:
- if isinstance(input_data, LIBRARIES['tensorflow'].ndarray):
- return LIBRARIES['tensorflow']
- if LIBRARIES['cupy'] is not None:
- if isinstance(input_data, LIBRARIES['cupy'].ndarray):
- return LIBRARIES['cupy']
+ if LIBRARIES["tensorflow"] is not None:
+ if isinstance(input_data, LIBRARIES["tensorflow"].ndarray):
+ return LIBRARIES["tensorflow"]
+ if LIBRARIES["cupy"] is not None:
+ if isinstance(input_data, LIBRARIES["cupy"].ndarray):
+ return LIBRARIES["cupy"]
return np
-def change_backend(input_data, backend='cupy'):
+def change_backend(input_data, backend="cupy"):
"""Move data to device.
This method changes the backend of an array. This can be used to copy data
@@ -151,13 +153,13 @@ def move_to_cpu(input_data):
"""
xp = get_array_module(input_data)
- if xp == LIBRARIES['numpy']:
+ if xp == LIBRARIES["numpy"]:
return input_data
- elif xp == LIBRARIES['cupy']:
+ elif xp == LIBRARIES["cupy"]:
return input_data.get()
- elif xp == LIBRARIES['tensorflow']:
+ elif xp == LIBRARIES["tensorflow"]:
return input_data.data.numpy()
- raise ValueError('Cannot identify the array type.')
+ raise ValueError("Cannot identify the array type.")
def convert_to_tensor(input_data):
@@ -184,9 +186,9 @@ def convert_to_tensor(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
xp = get_array_module(input_data)
@@ -220,9 +222,9 @@ def convert_to_cupy_array(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
if input_data.is_cuda:
diff --git a/modopt/base/np_adjust.py b/modopt/base/np_adjust.py
index 6d290e43..31a785f5 100644
--- a/modopt/base/np_adjust.py
+++ b/modopt/base/np_adjust.py
@@ -154,8 +154,8 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- 'Padding must be an integer or a tuple (or list, np.ndarray) '
- + 'of itegers',
+ "Padding must be an integer or a tuple (or list, np.ndarray) "
+ "of integers",
)
if padding.size == 1:
@@ -164,7 +164,7 @@ def pad2d(input_data, padding):
pad_x = (padding[0], padding[0])
pad_y = (padding[1], padding[1])
- return np.pad(input_data, (pad_x, pad_y), 'constant')
+ return np.pad(input_data, (pad_x, pad_y), "constant")
def ftr(input_data):
diff --git a/modopt/base/observable.py b/modopt/base/observable.py
index 6471ba58..4444af41 100644
--- a/modopt/base/observable.py
+++ b/modopt/base/observable.py
@@ -33,7 +33,6 @@ class Observable(object):
"""
def __init__(self, signals):
-
# Define class parameters
self._allowed_signals = []
self._observers = {}
@@ -215,7 +214,6 @@ def __init__(
wind=6,
eps=1.0e-3,
):
-
self.name = name
self.metric = metric
self.mapping = mapping
@@ -264,9 +262,7 @@ def is_converge(self):
mid_idx = -(self.wind // 2)
old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
- normalize_residual_metrics = (
- np.abs(old_mean - current_mean) / np.abs(old_mean)
- )
+ normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
self.converge_flag = normalize_residual_metrics < self.eps
def retrieve_metrics(self):
@@ -287,7 +283,7 @@ def retrieve_metrics(self):
time_val -= time_val[0]
return {
- 'time': time_val,
- 'index': self.list_iters,
- 'values': self.list_cv_values,
+ "time": time_val,
+ "index": self.list_iters,
+ "values": self.list_cv_values,
}
diff --git a/modopt/base/transform.py b/modopt/base/transform.py
index 07ce846f..fedd5efb 100644
--- a/modopt/base/transform.py
+++ b/modopt/base/transform.py
@@ -53,18 +53,17 @@ def cube2map(data_cube, layout):
"""
if data_cube.ndim != 3:
- raise ValueError('The input data must have 3 dimensions.')
+ raise ValueError("The input data must have 3 dimensions.")
if data_cube.shape[0] != np.prod(layout):
raise ValueError(
- 'The desired layout must match the number of input '
- + 'data layers.',
+ "The desired layout must match the number of input " + "data layers.",
)
- res = ([
+ res = [
np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
for elem in range(layout[0])
- ])
+ ]
return np.vstack(res)
@@ -118,20 +117,24 @@ def map2cube(data_map, layout):
"""
if np.all(np.array(data_map.shape) % np.array(layout)):
raise ValueError(
- 'The desired layout must be a multiple of the number '
- + 'pixels in the data map.',
+ "The desired layout must be a multiple of the number "
+ + "pixels in the data map.",
)
d_shape = np.array(data_map.shape) // np.array(layout)
- return np.array([
- data_map[(
- slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
- slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
- )]
- for i_elem in range(layout[0])
- for j_elem in range(layout[1])
- ])
+ return np.array(
+ [
+ data_map[
+ (
+ slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
+ slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
+ )
+ ]
+ for i_elem in range(layout[0])
+ for j_elem in range(layout[1])
+ ]
+ )
def map2matrix(data_map, layout):
@@ -186,9 +189,9 @@ def map2matrix(data_map, layout):
image_shape * (i_elem % layout[1] + 1),
)
data_matrix.append(
- (
- data_map[lower[0]:upper[0], lower[1]:upper[1]]
- ).reshape(image_shape ** 2),
+ (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape(
+ image_shape**2
+ ),
)
return np.array(data_matrix).T
@@ -232,7 +235,7 @@ def matrix2map(data_matrix, map_shape):
# Get the shape and layout of the images
image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
- layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int')
+ layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")
# Map objects from matrix
data_map = np.zeros(map_shape)
@@ -248,7 +251,7 @@ def matrix2map(data_matrix, map_shape):
image_shape * (i_elem // layout[1] + 1),
image_shape * (i_elem % layout[1] + 1),
)
- data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem]
+ data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem]
return data_map.astype(int)
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 88051675..af2184f5 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -44,7 +44,7 @@ def check_callable(input_obj, add_agrs=True):
"""
if not callable(input_obj):
- raise TypeError('The input object must be a callable function.')
+ raise TypeError("The input object must be a callable function.")
if add_agrs:
input_obj = add_args_kwargs(input_obj)
@@ -89,14 +89,13 @@ def check_float(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, int):
input_obj = float(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=float)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.floating))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.floating)
):
input_obj = input_obj.astype(float)
@@ -139,14 +138,13 @@ def check_int(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, float):
input_obj = int(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=int)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.integer))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.integer)
):
input_obj = input_obj.astype(int)
@@ -178,19 +176,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
"""
if not isinstance(input_obj, np.ndarray):
- raise TypeError('Input is not a numpy array.')
+ raise TypeError("Input is not a numpy array.")
- if (
- (not isinstance(dtype, type(None)))
- and (not np.issubdtype(input_obj.dtype, dtype))
+ if (not isinstance(dtype, type(None))) and (
+ not np.issubdtype(input_obj.dtype, dtype)
):
raise (
TypeError(
- 'The numpy array elements are not of type: {0}'.format(dtype),
+ "The numpy array elements are not of type: {0}".format(dtype),
),
)
if not writeable and verbose and input_obj.flags.writeable:
- warn('Making input data immutable.')
+ warn("Making input data immutable.")
input_obj.flags.writeable = writeable
diff --git a/modopt/base/wrappers.py b/modopt/base/wrappers.py
index baedb891..34d28553 100644
--- a/modopt/base/wrappers.py
+++ b/modopt/base/wrappers.py
@@ -29,18 +29,17 @@ def add_args_kwargs(func):
wrapper
"""
+
@wraps(func)
def wrapper(*args, **kwargs):
-
props = argspec(func)
# if 'args' not in props:
if isinstance(props[1], type(None)):
- args = args[:len(props[0])]
+ args = args[: len(props[0])]
- if (
- (not isinstance(props[2], type(None)))
- or (not isinstance(props[3], type(None)))
+ if (not isinstance(props[2], type(None))) or (
+ not isinstance(props[3], type(None))
):
return func(*args, **kwargs)
diff --git a/modopt/examples/conftest.py b/modopt/examples/conftest.py
index 73358679..9e845813 100644
--- a/modopt/examples/conftest.py
+++ b/modopt/examples/conftest.py
@@ -15,6 +15,7 @@
import runpy
import pytest
+
def pytest_collect_file(path, parent):
"""Pytest hook.
@@ -22,7 +23,7 @@ def pytest_collect_file(path, parent):
The new node needs to have the specified parent as parent.
"""
p = Path(path)
- if p.suffix == '.py' and 'example' in p.name:
+ if p.suffix == ".py" and "example" in p.name:
return Script.from_parent(parent, path=p, name=p.name)
@@ -33,6 +34,7 @@ def collect(self):
"""Collect the script as its own item."""
yield ScriptItem.from_parent(self, name=self.name)
+
class ScriptItem(pytest.Item):
"""Item script collected by pytest."""
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
index 7f820000..3f7bd922 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -7,9 +7,11 @@
using the Forward-Backward Algorithm.
In this example we are going to use:
- - Modopt Operators (Linear, Gradient, Proximal)
- - Modopt implementation of solvers
- - Modopt Metric API.
+
+* Modopt Operators (Linear, Gradient, Proximal)
+* Modopt implementation of solvers
+* Modopt Metric API.
+
TODO: add reference to LASSO paper.
"""
@@ -76,7 +78,7 @@
prox=prox_op,
cost=cost_op_fista,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+    auto_iterate=False,  # Just to give us the pleasure of doing things by ourselves.
)
fb_fista.iterate()
@@ -115,7 +117,7 @@
prox=prox_op,
cost=cost_op_pogm,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+    auto_iterate=False,  # Just to give us the pleasure of doing things by ourselves.
)
fb_pogm.iterate()
diff --git a/modopt/interface/__init__.py b/modopt/interface/__init__.py
index f9439747..55904ca1 100644
--- a/modopt/interface/__init__.py
+++ b/modopt/interface/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['errors', 'log']
+__all__ = ["errors", "log"]
diff --git a/modopt/interface/errors.py b/modopt/interface/errors.py
index 0fbe7e71..eb4aa4ca 100644
--- a/modopt/interface/errors.py
+++ b/modopt/interface/errors.py
@@ -34,12 +34,12 @@ def warn(warn_string, log=None):
"""
if import_fail:
- warn_txt = 'WARNING'
+ warn_txt = "WARNING"
else:
- warn_txt = colored('WARNING', 'yellow')
+ warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string))
+ sys.stderr.write("{0}: {1}\n".format(warn_txt, warn_string))
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
@@ -61,17 +61,17 @@ def catch_error(exception, log=None):
"""
if import_fail:
- err_txt = 'ERROR'
+ err_txt = "ERROR"
else:
- err_txt = colored('ERROR', 'red')
+ err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = '{0}: {1}\n'.format(err_txt, exception)
+ stream_txt = "{0}: {1}\n".format(err_txt, exception)
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = 'ERROR: {0}\n'.format(exception)
+ log_txt = "ERROR: {0}\n".format(exception)
log.exception(log_txt)
@@ -91,11 +91,11 @@ def file_name_error(file_name):
If file name not specified or file not found
"""
- if file_name == '' or file_name[0][0] == '-':
- raise IOError('Input file name not specified.')
+ if file_name == "" or file_name[0][0] == "-":
+ raise IOError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError('Input file name {0} not found!'.format(file_name))
+ raise IOError("Input file name {0} not found!".format(file_name))
def is_exe(fpath):
@@ -136,7 +136,7 @@ def is_executable(exe_name):
"""
if not isinstance(exe_name, str):
- raise TypeError('Executable name must be a string.')
+ raise TypeError("Executable name must be a string.")
fpath, fname = os.path.split(exe_name)
@@ -146,11 +146,9 @@ def is_executable(exe_name):
else:
res = any(
is_exe(os.path.join(path, exe_name))
- for path in os.environ['PATH'].split(os.pathsep)
+ for path in os.environ["PATH"].split(os.pathsep)
)
if not res:
- message = (
- '{0} does not appear to be a valid executable on this system.'
- )
+ message = "{0} does not appear to be a valid executable on this system."
raise IOError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/modopt/interface/log.py
index 3b2fa77a..a02428d9 100644
--- a/modopt/interface/log.py
+++ b/modopt/interface/log.py
@@ -30,22 +30,22 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = '{0}.log'.format(filename)
+ filename = "{0}.log".format(filename)
if verbose:
- print('Preparing log file:', filename)
+ print("Preparing log file:", filename)
# Capture warnings.
logging.captureWarnings(True)
# Set output format.
formatter = logging.Formatter(
- fmt='%(asctime)s %(message)s',
- datefmt='%d/%m/%Y %H:%M:%S',
+ fmt="%(asctime)s %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
)
# Create file handler.
- fh = logging.FileHandler(filename=filename, mode='w')
+ fh = logging.FileHandler(filename=filename, mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
@@ -55,7 +55,7 @@ def set_up_log(filename, verbose=True):
log.addHandler(fh)
# Send opening message.
- log.info('The log file has been set-up.')
+ log.info("The log file has been set-up.")
return log
@@ -74,10 +74,10 @@ def close_log(log, verbose=True):
"""
if verbose:
- print('Closing log file:', log.name)
+ print("Closing log file:", log.name)
# Send closing message.
- log.info('The log file has been closed.')
+ log.info("The log file has been closed.")
# Remove all handlers from log.
for log_handler in log.handlers:
diff --git a/modopt/math/__init__.py b/modopt/math/__init__.py
index a22c0c98..8e92aa50 100644
--- a/modopt/math/__init__.py
+++ b/modopt/math/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['convolve', 'matrix', 'stats', 'metrics']
+__all__ = ["convolve", "matrix", "stats", "metrics"]
diff --git a/modopt/math/convolve.py b/modopt/math/convolve.py
index a4322ff2..528b2338 100644
--- a/modopt/math/convolve.py
+++ b/modopt/math/convolve.py
@@ -18,7 +18,7 @@
from astropy.convolution import convolve_fft
except ImportError: # pragma: no cover
import_astropy = False
- warn('astropy not found, will default to scipy for convolution')
+ warn("astropy not found, will default to scipy for convolution")
else:
import_astropy = True
try:
@@ -30,7 +30,7 @@
warn('Using pyFFTW "monkey patch" for scipy.fftpack')
-def convolve(input_data, kernel, method='scipy'):
+def convolve(input_data, kernel, method="scipy"):
"""Convolve data with kernel.
This method convolves the input data with a given kernel using FFT and
@@ -80,29 +80,29 @@ def convolve(input_data, kernel, method='scipy'):
"""
if input_data.ndim != kernel.ndim:
- raise ValueError('Data and kernel must have the same dimensions.')
+ raise ValueError("Data and kernel must have the same dimensions.")
- if method not in {'astropy', 'scipy'}:
+ if method not in {"astropy", "scipy"}:
raise ValueError('Invalid method. Options are "astropy" or "scipy".')
if not import_astropy: # pragma: no cover
- method = 'scipy'
+ method = "scipy"
- if method == 'astropy':
+ if method == "astropy":
return convolve_fft(
input_data,
kernel,
- boundary='wrap',
+ boundary="wrap",
crop=False,
- nan_treatment='fill',
+ nan_treatment="fill",
normalize_kernel=False,
)
- elif method == 'scipy':
- return scipy.signal.fftconvolve(input_data, kernel, mode='same')
+ elif method == "scipy":
+ return scipy.signal.fftconvolve(input_data, kernel, mode="same")
-def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
+def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
"""Convolve stack of data with stack of kernels.
This method convolves the input data with a given kernel using FFT and
@@ -156,7 +156,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
if rot_kernel:
kernel = rotate_stack(kernel)
- return np.array([
- convolve(data_i, kernel_i, method=method)
- for data_i, kernel_i in zip(input_data, kernel)
- ])
+ return np.array(
+ [
+ convolve(data_i, kernel_i, method=method)
+ for data_i, kernel_i in zip(input_data, kernel)
+ ]
+ )
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 8361531d..8a7c4eae 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module, get_backend
-def gram_schmidt(matrix, return_opt='orthonormal'):
+def gram_schmidt(matrix, return_opt="orthonormal"):
r"""Gram-Schmit.
This method orthonormalizes the row vectors of the input matrix.
@@ -55,7 +55,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
"""
- if return_opt not in {'orthonormal', 'orthogonal', 'both'}:
+ if return_opt not in {"orthonormal", "orthogonal", "both"}:
raise ValueError(
'Invalid return_opt, options are: "orthonormal", "orthogonal" or '
+ '"both"',
@@ -65,7 +65,6 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
e_vec = []
for vector in matrix:
-
if u_vec:
u_now = vector - sum(project(u_i, vector) for u_i in u_vec)
else:
@@ -77,11 +76,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
u_vec = np.array(u_vec)
e_vec = np.array(e_vec)
- if return_opt == 'orthonormal':
+ if return_opt == "orthonormal":
return e_vec
- elif return_opt == 'orthogonal':
+ elif return_opt == "orthogonal":
return u_vec
- elif return_opt == 'both':
+ elif return_opt == "both":
return u_vec, e_vec
@@ -201,7 +200,7 @@ def rot_matrix(angle):
return np.around(
np.array(
[[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
- dtype='float',
+ dtype="float",
),
10,
)
@@ -243,16 +242,15 @@ def rotate(matrix, angle):
shape = np.array(matrix.shape)
if shape[0] != shape[1]:
- raise ValueError('Input matrix must be square.')
+ raise ValueError("Input matrix must be square.")
shift = (shape - 1) // 2
index = (
- np.array(list(product(*np.array([np.arange(sval) for sval in shape]))))
- - shift
+ np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift
)
- new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift
+ new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift
new_index[new_index >= shape[0]] -= shape[0]
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
@@ -301,10 +299,9 @@ def __init__(
data_shape,
data_type=float,
auto_run=True,
- compute_backend='numpy',
+ compute_backend="numpy",
verbose=False,
):
-
self._operator = operator
self._data_shape = data_shape
self._data_type = data_type
@@ -363,18 +360,14 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
x_new /= x_new_norm
- if (xp.abs(x_new_norm - x_old_norm) < tolerance):
- message = (
- ' - Power Method converged after {0} iterations!'
- )
+ if xp.abs(x_new_norm - x_old_norm) < tolerance:
+ message = " - Power Method converged after {0} iterations!"
if self._verbose:
print(message.format(i_elem + 1))
break
elif i_elem == max_iter - 1 and self._verbose:
- message = (
- ' - Power Method did not converge after {0} iterations!'
- )
+ message = " - Power Method did not converge after {0} iterations!"
print(message.format(max_iter))
xp.copyto(x_old, x_new)
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 1b870e23..db1ac4e2 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -71,15 +71,13 @@ def _preprocess_input(test, ref, mask=None):
The SNR
"""
- test = np.abs(np.copy(test)).astype('float64')
- ref = np.abs(np.copy(ref)).astype('float64')
+ test = np.abs(np.copy(test)).astype("float64")
+ ref = np.abs(np.copy(ref)).astype("float64")
test = min_max_normalize(test)
ref = min_max_normalize(ref)
if (not isinstance(mask, np.ndarray)) and (mask is not None):
- message = (
- 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
- )
+ message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
raise ValueError(message.format(mask))
if mask is None:
@@ -119,9 +117,9 @@ def ssim(test, ref, mask=None):
"""
if not import_skimage: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Image package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Image package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
test, ref, mask = _preprocess_input(test, ref, mask)
diff --git a/modopt/opt/__init__.py b/modopt/opt/__init__.py
index 2fd3d747..8b285bee 100644
--- a/modopt/opt/__init__.py
+++ b/modopt/opt/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight']
+__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index e0ac2572..26eef548 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -46,14 +46,19 @@
"""
from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
- ForwardBackward,
- GenForwardBackward)
-from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt)
+from modopt.opt.algorithms.forward_backward import (
+ FISTA,
+ POGM,
+ ForwardBackward,
+ GenForwardBackward,
+)
+from modopt.opt.algorithms.gradient_descent import (
+ AdaGenericGradOpt,
+ ADAMGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
from modopt.opt.algorithms.primal_dual import Condat
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index c5a4b101..1025c2af 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -69,7 +69,7 @@ def __init__(
verbose=False,
progress=True,
step_size=None,
- compute_backend='numpy',
+ compute_backend="numpy",
**dummy_kwargs,
):
self.idx = 0
@@ -79,26 +79,26 @@ def __init__(
self.metrics = metrics
self.step_size = step_size
self._op_parents = (
- 'GradParent',
- 'ProximityParent',
- 'LinearParent',
- 'costObj',
+ "GradParent",
+ "ProximityParent",
+ "LinearParent",
+ "costObj",
)
self.metric_call_period = metric_call_period
# Declaration of observers for metrics
- super().__init__(['cv_metrics'])
+ super().__init__(["cv_metrics"])
for name, dic in self.metrics.items():
observer = MetricObserver(
name,
- dic['metric'],
- dic['mapping'],
- dic['cst_kwargs'],
- dic['early_stopping'],
+ dic["metric"],
+ dic["mapping"],
+ dic["cst_kwargs"],
+ dic["early_stopping"],
)
- self.add_observer('cv_metrics', observer)
+ self.add_observer("cv_metrics", observer)
xp, compute_backend = backend.get_backend(compute_backend)
self.xp = xp
@@ -111,14 +111,13 @@ def metrics(self):
@metrics.setter
def metrics(self, metrics):
-
if isinstance(metrics, type(None)):
self._metrics = {}
elif isinstance(metrics, dict):
self._metrics = metrics
else:
raise TypeError(
- 'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
+ "Metrics must be a dictionary, not {0}.".format(type(metrics)),
)
def any_convergence_flag(self):
@@ -132,9 +131,7 @@ def any_convergence_flag(self):
True if any convergence criteria met
"""
- return any(
- obs.converge_flag for obs in self._observers['cv_metrics']
- )
+ return any(obs.converge_flag for obs in self._observers["cv_metrics"])
def copy_data(self, input_data):
"""Copy Data.
@@ -152,10 +149,12 @@ def copy_data(self, input_data):
Copy of input data
"""
- return self.xp.copy(backend.change_backend(
- input_data,
- self.compute_backend,
- ))
+ return self.xp.copy(
+ backend.change_backend(
+ input_data,
+ self.compute_backend,
+ )
+ )
def _check_input_data(self, input_data):
"""Check input data type.
@@ -175,7 +174,7 @@ def _check_input_data(self, input_data):
"""
if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))):
raise TypeError(
- 'Input data must be a numpy array or backend array',
+ "Input data must be a numpy array or backend array",
)
def _check_param(self, param_val):
@@ -195,7 +194,7 @@ def _check_param(self, param_val):
"""
if not isinstance(param_val, float):
- raise TypeError('Algorithm parameter must be a float value.')
+ raise TypeError("Algorithm parameter must be a float value.")
def _check_param_update(self, param_update):
"""Check algorithm parameter update methods.
@@ -213,14 +212,13 @@ def _check_param_update(self, param_update):
For invalid input type
"""
- param_conditions = (
- not isinstance(param_update, type(None))
- and not callable(param_update)
+ param_conditions = not isinstance(param_update, type(None)) and not callable(
+ param_update
)
if param_conditions:
raise TypeError(
- 'Algorithm parameter update must be a callabale function.',
+ "Algorithm parameter update must be a callabale function.",
)
def _check_operator(self, operator):
@@ -239,7 +237,7 @@ def _check_operator(self, operator):
tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]
if not any(parent in tree for parent in self._op_parents):
- message = '{0} does not inherit an operator parent.'
+ message = "{0} does not inherit an operator parent."
warn(message.format(str(operator.__class__)))
def _compute_metrics(self):
@@ -250,7 +248,7 @@ def _compute_metrics(self):
"""
kwargs = self.get_notify_observers_kwargs()
- self.notify_observers('cv_metrics', **kwargs)
+ self.notify_observers("cv_metrics", **kwargs)
def _iterations(self, max_iter, progbar=None):
"""Iterate method.
@@ -273,7 +271,6 @@ def _iterations(self, max_iter, progbar=None):
# We do not call metrics if metrics is empty or metric call
# period is None
if self.metrics and self.metric_call_period is not None:
-
metric_conditions = (
self.idx % self.metric_call_period == 0
or self.idx == (max_iter - 1)
@@ -285,7 +282,7 @@ def _iterations(self, max_iter, progbar=None):
if self.converge:
if self.verbose:
- print(' - Converged!')
+ print(" - Converged!")
break
if progbar:
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index 83f45e8b..6f66ec26 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -52,12 +52,12 @@ class FISTA(object):
"""
_restarting_strategies = (
- 'adaptive', # option 1 in alg 4
- 'adaptive-i',
- 'adaptive-1',
- 'adaptive-ii', # option 2 in alg 4
- 'adaptive-2',
- 'greedy', # alg 5
+ "adaptive", # option 1 in alg 4
+ "adaptive-i",
+ "adaptive-1",
+ "adaptive-ii", # option 2 in alg 4
+ "adaptive-2",
+ "greedy", # alg 5
None, # no restarting
)
@@ -73,26 +73,28 @@ def __init__(
r_lazy=4,
**kwargs,
):
-
if isinstance(a_cd, type(None)):
- self.mode = 'regular'
+ self.mode = "regular"
self.p_lazy = p_lazy
self.q_lazy = q_lazy
self.r_lazy = r_lazy
elif a_cd > 2:
- self.mode = 'CD'
+ self.mode = "CD"
self.a_cd = a_cd
self._n = 0
else:
raise ValueError(
- 'a_cd must either be None (for regular mode) or a number > 2',
+ "a_cd must either be None (for regular mode) or a number > 2",
)
if restart_strategy in self._restarting_strategies:
self._check_restart_params(
- restart_strategy, min_beta, s_greedy, xi_restart,
+ restart_strategy,
+ min_beta,
+ s_greedy,
+ xi_restart,
)
self.restart_strategy = restart_strategy
self.min_beta = min_beta
@@ -100,10 +102,10 @@ def __init__(
self.xi_restart = xi_restart
else:
- message = 'Restarting strategy must be one of {0}.'
+ message = "Restarting strategy must be one of {0}."
raise ValueError(
message.format(
- ', '.join(self._restarting_strategies),
+ ", ".join(self._restarting_strategies),
),
)
self._t_now = 1.0
@@ -155,22 +157,20 @@ def _check_restart_params(
if restart_strategy is None:
return True
- if self.mode != 'regular':
+ if self.mode != "regular":
raise ValueError(
- 'Restarting strategies can only be used with regular mode.',
+ "Restarting strategies can only be used with regular mode.",
)
- greedy_params_check = (
- min_beta is None or s_greedy is None or s_greedy <= 1
- )
+ greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1
- if restart_strategy == 'greedy' and greedy_params_check:
+ if restart_strategy == "greedy" and greedy_params_check:
raise ValueError(
- 'You need a min_beta and an s_greedy > 1 for greedy restart.',
+ "You need a min_beta and an s_greedy > 1 for greedy restart.",
)
if xi_restart is None or xi_restart >= 1:
- raise ValueError('You need a xi_restart < 1 for restart.')
+ raise ValueError("You need a xi_restart < 1 for restart.")
return True
@@ -210,12 +210,12 @@ def is_restart(self, z_old, x_new, x_old):
criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0
if criterion:
- if 'adaptive' in self.restart_strategy:
+ if "adaptive" in self.restart_strategy:
self.r_lazy *= self.xi_restart
- if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
+ if self.restart_strategy in {"adaptive-ii", "adaptive-2"}:
self._t_now = 1
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
cur_delta = xp.linalg.norm(x_new - x_old)
if self._delta0 is None:
self._delta0 = self.s_greedy * cur_delta
@@ -269,17 +269,17 @@ def update_lambda(self, *args, **kwargs):
Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`.
"""
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
return 2
# Steps 3 and 4 from alg.10.7.
self._t_prev = self._t_now
- if self.mode == 'regular':
- sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy
+ if self.mode == "regular":
+ sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy
self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5
- elif self.mode == 'CD':
+ elif self.mode == "CD":
self._t_now = (self._n + self.a_cd - 1) / self.a_cd
self._n += 1
@@ -344,18 +344,17 @@ def __init__(
x,
grad,
prox,
- cost='auto',
+ cost="auto",
beta_param=1.0,
lambda_param=1.0,
beta_update=None,
- lambda_update='fista',
+ lambda_update="fista",
auto_iterate=True,
metric_call_period=5,
metrics=None,
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -376,7 +375,7 @@ def __init__(
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -384,7 +383,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -400,7 +399,7 @@ def __init__(
# Set the algorithm parameter update methods
self._check_param_update(beta_update)
self._beta_update = beta_update
- if isinstance(lambda_update, str) and lambda_update == 'fista':
+ if isinstance(lambda_update, str) and lambda_update == "fista":
fista = FISTA(**kwargs)
self._lambda_update = fista.update_lambda
self._is_restart = fista.is_restart
@@ -462,9 +461,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -500,9 +498,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z_new,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -513,7 +511,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -577,7 +575,7 @@ def __init__(
x,
grad,
prox_list,
- cost='auto',
+ cost="auto",
gamma_param=1.0,
lambda_param=1.0,
gamma_update=None,
@@ -589,7 +587,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -609,7 +606,7 @@ def __init__(
self._prox_list = self.xp.array(prox_list)
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad] + prox_list)
else:
self._cost_func = cost
@@ -617,7 +614,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -641,9 +638,7 @@ def __init__(
self._set_weights(weights)
# Set initial z
- self._z = self.xp.array([
- self._x_old for i in range(self._prox_list.size)
- ])
+ self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)])
# Automatically run the algorithm
if auto_iterate:
@@ -673,25 +668,25 @@ def _set_weights(self, weights):
self._prox_list.size,
)
elif not isinstance(weights, (list, tuple, np.ndarray)):
- raise TypeError('Weights must be provided as a list.')
+ raise TypeError("Weights must be provided as a list.")
weights = self.xp.array(weights)
if not np.issubdtype(weights.dtype, np.floating):
- raise ValueError('Weights must be list of float values.')
+ raise ValueError("Weights must be list of float values.")
if weights.size != self._prox_list.size:
raise ValueError(
- 'The number of weights must match the number of proximity '
- + 'operators.',
+ "The number of weights must match the number of proximity "
+ + "operators.",
)
expected_weight_sum = 1.0
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
- 'Proximity operator weights must sum to 1.0. Current sum of '
- + 'weights = {0}'.format(self.xp.sum(weights)),
+ "Proximity operator weights must sum to 1.0. Current sum of "
+ + "weights = {0}".format(self.xp.sum(weights)),
)
self._weights = weights
@@ -726,9 +721,7 @@ def _update(self):
# Update z values.
for i in range(self._prox_list.size):
- z_temp = (
- 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
- )
+ z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
z_prox = self._prox_list[i].op(
z_temp,
extra_factor=self._gamma / self._weights[i],
@@ -784,9 +777,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -797,7 +790,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -871,7 +864,7 @@ def __init__(
z,
grad,
prox,
- cost='auto',
+ cost="auto",
linear=None,
beta_param=1.0,
sigma_bar=1.0,
@@ -880,7 +873,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -905,7 +897,7 @@ def __init__(
self._grad = grad
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -918,7 +910,7 @@ def __init__(
for param_val in (beta_param, sigma_bar):
self._check_param(param_val)
if sigma_bar < 0 or sigma_bar > 1:
- raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+ raise ValueError("The sigma bar parameter needs to be in [0, 1]")
self._beta = self.step_size or beta_param
self._sigma_bar = sigma_bar
@@ -947,13 +939,13 @@ def _update(self):
self._u_new = self._x_old - self._beta * self._grad.grad
# Step 5 from alg. 3
- self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
+ self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))
# Step 6 from alg. 3
t_shifted_ratio = (self._t_old - 1) / self._t_new
sigma_t_ratio = self._sigma * self._t_old / self._t_new
beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
- self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+ self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
self._z += self._u_new
self._z += t_shifted_ratio * (self._u_new - self._u_old)
self._z += sigma_t_ratio * (self._u_new - self._x_old)
@@ -972,9 +964,7 @@ def _update(self):
self._y_new = self._x_old - self._beta * self._g_new
# Step 11 from alg. 3
- restart_crit = (
- self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
- )
+ restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
if restart_crit:
self._t_new = 1
self._sigma = 1
@@ -992,9 +982,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -1030,14 +1019,14 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'u_new': self._u_new,
- 'x_new': self._linear.adj_op(self._x_new),
- 'y_new': self._y_new,
- 'z_new': self._z,
- 'xi': self._xi,
- 'sigma': self._sigma,
- 't': self._t_new,
- 'idx': self.idx,
+ "u_new": self._u_new,
+ "x_new": self._linear.adj_op(self._x_new),
+ "y_new": self._y_new,
+ "z_new": self._z,
+ "xi": self._xi,
+ "sigma": self._sigma,
+ "t": self._t_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -1048,6 +1037,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/algorithms/gradient_descent.py b/modopt/opt/algorithms/gradient_descent.py
index f3fe4b10..d3af1686 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/modopt/opt/algorithms/gradient_descent.py
@@ -103,7 +103,7 @@ def __init__(
self._check_operator(operator)
self._grad = grad
self._prox = prox
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -157,9 +157,8 @@ def _update(self):
self._eta = self._eta_update(self._eta, self.idx)
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def _update_grad_dir(self, grad):
@@ -208,10 +207,10 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._x_new,
- 'dir_grad': self._dir_grad,
- 'speed_grad': self._speed_grad,
- 'idx': self.idx,
+ "x_new": self._x_new,
+ "dir_grad": self._dir_grad,
+ "speed_grad": self._speed_grad,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -222,7 +221,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -308,7 +307,7 @@ class RMSpropGradOpt(GenericGradOpt):
def __init__(self, *args, gamma=0.5, **kwargs):
super().__init__(*args, **kwargs)
if gamma < 0 or gamma > 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
self._check_param(gamma)
self._gamma = gamma
@@ -405,9 +404,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs):
self._check_param(gamma)
self._check_param(beta)
if gamma < 0 or gamma >= 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
if beta < 0 or beta >= 1:
- raise ValueError('beta is outside of range [0,1]')
+ raise ValueError("beta is outside of range [0,1]")
self._gamma = gamma
self._beta = beta
self._beta_pow = 1
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index d5bdd431..3b160d7b 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -81,7 +81,7 @@ def __init__(
prox,
prox_dual,
linear=None,
- cost='auto',
+ cost="auto",
reweight=None,
rho=0.5,
sigma=1.0,
@@ -96,7 +96,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -123,12 +122,14 @@ def __init__(
self._linear = Identity()
else:
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([
- self._grad,
- self._prox,
- self._prox_dual,
- ])
+ if cost == "auto":
+ self._cost_func = costObj(
+ [
+ self._grad,
+ self._prox,
+ self._prox_dual,
+ ]
+ )
else:
self._cost_func = cost
@@ -187,22 +188,17 @@ def _update(self):
self._grad.get_grad(self._x_old)
x_prox = self._prox.op(
- self._x_old - self._tau * self._grad.grad - self._tau
- * self._linear.adj_op(self._y_old),
+ self._x_old
+ - self._tau * self._grad.grad
+ - self._tau * self._linear.adj_op(self._y_old),
)
# Step 2 from eq.9.
- y_temp = (
- self._y_old + self._sigma
- * self._linear.op(2 * x_prox - self._x_old)
- )
+ y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)
- y_prox = (
- y_temp - self._sigma
- * self._prox_dual.op(
- y_temp / self._sigma,
- extra_factor=(1.0 / self._sigma),
- )
+ y_prox = y_temp - self._sigma * self._prox_dual.op(
+ y_temp / self._sigma,
+ extra_factor=(1.0 / self._sigma),
)
# Step 3 from eq.9.
@@ -220,9 +216,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new, self._y_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new, self._y_new
)
def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
@@ -267,7 +262,7 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
- return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}
+ return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}
def retrieve_outputs(self):
"""Retrieve outputs.
@@ -277,6 +272,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py
index 3cdfcc50..01c77f2f 100644
--- a/modopt/opt/cost.py
+++ b/modopt/opt/cost.py
@@ -79,7 +79,6 @@ def __init__(
verbose=True,
plot_output=None,
):
-
self._operators = operators
if not isinstance(operators, type(None)):
self._check_operators()
@@ -107,13 +106,11 @@ def _check_operators(self):
"""
if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
+ message = "Input operators must be provided as a list, not {0}"
raise TypeError(message.format(type(self._operators)))
for op in self._operators:
- if not hasattr(op, 'cost'):
+ if not hasattr(op, "cost"):
raise ValueError('Operators must contain "cost" method.')
op.cost = check_callable(op.cost)
@@ -137,20 +134,19 @@ def _check_cost(self):
# Check if enough cost values have been collected
if len(self._test_list) == self._test_range:
-
# The mean of the first half of the test list
t1 = xp.mean(
- xp.array(self._test_list[len(self._test_list) // 2:]),
+ xp.array(self._test_list[len(self._test_list) // 2 :]),
axis=0,
)
# The mean of the second half of the test list
t2 = xp.mean(
- xp.array(self._test_list[:len(self._test_list) // 2]),
+ xp.array(self._test_list[: len(self._test_list) // 2]),
axis=0,
)
# Calculate the change across the test list
if xp.around(t1, decimals=16):
- cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1))
+ cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)
else:
cost_diff = 0
@@ -158,9 +154,9 @@ def _check_cost(self):
self._test_list = []
if self._verbose:
- print(' - CONVERGENCE TEST - ')
- print(' - CHANGE IN COST:', cost_diff)
- print('')
+ print(" - CONVERGENCE TEST - ")
+ print(" - CHANGE IN COST:", cost_diff)
+ print("")
# Check for convergence
return cost_diff <= self._tolerance
@@ -207,8 +203,7 @@ def get_cost(self, *args, **kwargs):
"""
# Check if the cost should be calculated
test_conditions = (
- self._cost_interval is None
- or self._iteration % self._cost_interval
+ self._cost_interval is None or self._iteration % self._cost_interval
)
if test_conditions:
@@ -216,15 +211,15 @@ def get_cost(self, *args, **kwargs):
else:
if self._verbose:
- print(' - ITERATION:', self._iteration)
+ print(" - ITERATION:", self._iteration)
# Calculate the current cost
self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
- print(' - COST:', self.cost)
- print('')
+ print(" - COST:", self.cost)
+ print("")
# Test for convergence
test_result = self._check_cost()
diff --git a/modopt/opt/gradient.py b/modopt/opt/gradient.py
index caa8fa9d..5e0442aa 100644
--- a/modopt/opt/gradient.py
+++ b/modopt/opt/gradient.py
@@ -71,7 +71,6 @@ def __init__(
input_data_writeable=False,
verbose=True,
):
-
self.verbose = verbose
self._input_data_writeable = input_data_writeable
self._grad_data_type = data_type
@@ -100,7 +99,6 @@ def obs_data(self):
@obs_data.setter
def obs_data(self, input_data):
-
if self._grad_data_type in {float, np.floating}:
input_data = check_float(input_data)
check_npndarray(
@@ -128,7 +126,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -147,7 +144,6 @@ def trans_op(self):
@trans_op.setter
def trans_op(self, operator):
-
self._trans_op = check_callable(operator)
@property
@@ -157,7 +153,6 @@ def get_grad(self):
@get_grad.setter
def get_grad(self, method):
-
self._get_grad = check_callable(method)
@property
@@ -167,7 +162,6 @@ def grad(self):
@grad.setter
def grad(self, input_value):
-
if self._grad_data_type in {float, np.floating}:
input_value = check_float(input_value)
self._grad = input_value
@@ -179,7 +173,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
def trans_op_op(self, input_data):
@@ -243,7 +236,6 @@ class GradBasic(GradParent):
"""
def __init__(self, *args, **kwargs):
-
super().__init__(*args, **kwargs)
self.get_grad = self._get_grad_method
self.cost = self._cost_method
@@ -289,7 +281,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - DATA FIDELITY (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - DATA FIDELITY (X):", cost_val)
return cost_val
diff --git a/modopt/opt/linear.py b/modopt/opt/linear.py
index 83241625..0bc1f3f0 100644
--- a/modopt/opt/linear.py
+++ b/modopt/opt/linear.py
@@ -39,7 +39,6 @@ class LinearParent(object):
"""
def __init__(self, op, adj_op):
-
self.op = op
self.adj_op = adj_op
@@ -50,7 +49,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -60,7 +58,6 @@ def adj_op(self):
@adj_op.setter
def adj_op(self, operator):
-
self._adj_op = check_callable(operator)
@@ -76,7 +73,6 @@ class Identity(LinearParent):
"""
def __init__(self):
-
self.op = lambda input_data: input_data
self.adj_op = self.op
@@ -118,8 +114,7 @@ class WaveletConvolve(LinearParent):
"""
- def __init__(self, filters, method='scipy'):
-
+ def __init__(self, filters, method="scipy"):
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
@@ -171,7 +166,6 @@ class LinearCombo(LinearParent):
"""
def __init__(self, operators, weights=None):
-
operators, weights = self._check_inputs(operators, weights)
self.operators = operators
self.weights = weights
@@ -204,14 +198,13 @@ def _check_type(self, input_val):
"""
if not isinstance(input_val, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, input must be a list, tuple or numpy '
- + 'array.',
+ "Invalid input type, input must be a list, tuple or numpy " + "array.",
)
input_val = np.array(input_val)
if not input_val.size:
- raise ValueError('Input list is empty.')
+ raise ValueError("Input list is empty.")
return input_val
@@ -244,11 +237,10 @@ def _check_inputs(self, operators, weights):
operators = self._check_type(operators)
for operator in operators:
-
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'adj_op'):
+ if not hasattr(operator, "adj_op"):
raise ValueError('Operators must contain "adj_op" method.')
operator.op = check_callable(operator.op)
@@ -259,12 +251,11 @@ def _check_inputs(self, operators, weights):
if weights.size != operators.size:
raise ValueError(
- 'The number of weights must match the number of '
- + 'operators.',
+ "The number of weights must match the number of " + "operators.",
)
if not np.issubdtype(weights.dtype, np.floating):
- raise TypeError('The weights must be a list of float values.')
+ raise TypeError("The weights must be a list of float values.")
return operators, weights
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index e8492367..bf7b6140 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -47,7 +47,6 @@ class ProximityParent(object):
"""
def __init__(self, op, cost):
-
self.op = op
self.cost = cost
@@ -58,7 +57,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -78,7 +76,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
@@ -98,7 +95,6 @@ class IdentityProx(ProximityParent):
"""
def __init__(self):
-
self.op = lambda x_val: x_val
self.cost = lambda x_val: 0
@@ -116,7 +112,6 @@ class Positivity(ProximityParent):
"""
def __init__(self):
-
self.op = lambda input_data: positive(input_data)
self.cost = self._cost_method
@@ -139,8 +134,8 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - Min (X):', np.min(args[0]))
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - Min (X):", np.min(args[0]))
return 0
@@ -166,8 +161,7 @@ class SparseThreshold(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self._thresh_type = thresh_type
@@ -217,8 +211,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0])))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L1 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L1 NORM (X):", cost_val)
return cost_val
@@ -269,12 +263,11 @@ class LowRankMatrix(ProximityParent):
def __init__(
self,
threshold,
- thresh_type='soft',
- lowr_type='standard',
+ thresh_type="soft",
+ lowr_type="standard",
initial_rank=None,
operator=None,
):
-
self.thresh = threshold
self.thresh_type = thresh_type
self.lowr_type = lowr_type
@@ -311,13 +304,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor
- if self.lowr_type == 'standard' and self.rank is None and rank is None:
+ if self.lowr_type == "standard" and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
- elif self.lowr_type == 'standard':
+ elif self.lowr_type == "standard":
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
@@ -327,7 +320,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
)
self.rank = update_rank # save for future use
- elif self.lowr_type == 'ngole':
+ elif self.lowr_type == "ngole":
data_matrix = svd_thresh_coef(
cube2matrix(input_data),
self.operator,
@@ -335,7 +328,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
thresh_type=self.thresh_type,
)
else:
- raise ValueError('lowr_type should be standard or ngole')
+ raise ValueError("lowr_type should be standard or ngole")
# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
@@ -361,8 +354,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - NUCLEAR NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -466,7 +459,6 @@ class ProximityCombo(ProximityParent):
"""
def __init__(self, operators):
-
operators = self._check_operators(operators)
self.operators = operators
self.op = self._op_method
@@ -502,19 +494,19 @@ def _check_operators(self, operators):
"""
if not isinstance(operators, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, operators must be a list, tuple or '
- + 'numpy array.',
+ "Invalid input type, operators must be a list, tuple or "
+ + "numpy array.",
)
operators = np.array(operators)
if not operators.size:
- raise ValueError('Operator list is empty.')
+ raise ValueError("Operator list is empty.")
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'cost'):
+ if not hasattr(operator, "cost"):
raise ValueError('Operators must contain "cost" method.')
operator.op = check_callable(operator.op)
operator.cost = check_callable(operator.cost)
@@ -569,10 +561,12 @@ def _cost_method(self, *args, **kwargs):
Combinded cost components
"""
- return np.sum([
- operator.cost(input_data)
- for operator, input_data in zip(self.operators, args[0])
- ])
+ return np.sum(
+ [
+ operator.cost(input_data)
+ for operator, input_data in zip(self.operators, args[0])
+ ]
+ )
class OrderedWeightedL1Norm(ProximityParent):
@@ -613,16 +607,16 @@ class OrderedWeightedL1Norm(ProximityParent):
def __init__(self, weights):
if not import_sklearn: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Learn package not found see '
- + 'documentation for details: '
- + 'https://cea-cosmic.github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Learn package not found see "
+ + "documentation for details: "
+ + "https://cea-cosmic.github.io/ModOpt/#optional-packages",
)
if np.max(np.diff(weights)) > 0:
- raise ValueError('Weights must be non increasing')
+ raise ValueError("Weights must be non increasing")
self.weights = weights.flatten()
if (self.weights < 0).any():
raise ValueError(
- 'The weight values must be provided in descending order',
+ "The weight values must be provided in descending order",
)
self.op = self._op_method
self.cost = self._cost_method
@@ -660,7 +654,9 @@ def _op_method(self, input_data, extra_factor=1.0):
# Projection onto the monotone non-negative cone using
# isotonic_regression
data_abs = isotonic_regression(
- data_abs - threshold, y_min=0, increasing=False,
+ data_abs - threshold,
+ y_min=0,
+ increasing=False,
)
# Unsorting the data
@@ -668,7 +664,7 @@ def _op_method(self, input_data, extra_factor=1.0):
data_abs_unsorted[data_abs_sort_idx] = data_abs
# Putting the sign back
- with np.errstate(invalid='ignore'):
+ with np.errstate(invalid="ignore"):
sign_data = data_squeezed / np.abs(data_squeezed)
# Removing NAN caused by the sign
@@ -698,8 +694,8 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - OWL NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -730,8 +726,7 @@ class Ridge(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self.op = self._op_method
@@ -782,8 +777,8 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L2 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -818,7 +813,6 @@ class ElasticNet(ProximityParent):
"""
def __init__(self, linear, alpha, beta):
-
self._linear = linear
self.alpha = alpha
self.beta = beta
@@ -844,8 +838,8 @@ def _op_method(self, input_data, extra_factor=1.0):
"""
soft_threshold = self.beta * extra_factor
- normalization = (self.alpha * 2 * extra_factor + 1)
- return thresh(input_data, soft_threshold, 'soft') / normalization
+ normalization = self.alpha * 2 * extra_factor + 1
+ return thresh(input_data, soft_threshold, "soft") / normalization
def _cost_method(self, *args, **kwargs):
"""Calculate Ridge component of the cost.
@@ -871,8 +865,8 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - ELASTIC NET (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -938,7 +932,7 @@ def k_value(self):
def k_value(self, k_val):
if k_val < 1:
raise ValueError(
- 'The k parameter should be greater or equal than 1',
+ "The k parameter should be greater or equal than 1",
)
self._k_value = k_val
@@ -983,7 +977,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0):
alpha_beta = alpha_input - self.beta * extra_factor
theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0))
theta = np.nan_to_num(theta)
- theta += (alpha_input > (self.beta * extra_factor + 1))
+ theta += alpha_input > (self.beta * extra_factor + 1)
return theta
def _interpolate(self, alpha0, alpha1, sum0, sum1):
@@ -1074,12 +1068,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
midpoint = 0
while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]):
-
midpoint = (first_idx + last_idx) // 2
cnt += 1
if prev_midpoint == midpoint:
-
# Particular case
sum0 = self._compute_theta(
data_abs,
@@ -1092,11 +1084,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
extra_factor,
).sum()
- if (np.abs(sum0 - self._k_value) <= tolerance):
+ if np.abs(sum0 - self._k_value) <= tolerance:
found = True
midpoint = first_idx
- if (np.abs(sum1 - self._k_value) <= tolerance):
+ if np.abs(sum1 - self._k_value) <= tolerance:
found = True
midpoint = last_idx - 1
# -1 because output is index such that
@@ -1141,13 +1133,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
if found:
return (
- midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1,
+ midpoint,
+ alpha[midpoint],
+ alpha[midpoint + 1],
+ sum0,
+ sum1,
)
raise ValueError(
- 'Cannot find the coordinate of alpha (i) such '
- + 'that sum(theta(alpha[i])) =< k and '
- + 'sum(theta(alpha[i+1])) >= k ',
+ "Cannot find the coordinate of alpha (i) such "
+ + "that sum(theta(alpha[i])) =< k and "
+ + "sum(theta(alpha[i+1])) >= k ",
)
def _find_alpha(self, input_data, extra_factor=1.0):
@@ -1173,13 +1169,11 @@ def _find_alpha(self, input_data, extra_factor=1.0):
# Computes the alpha^i points line 1 in Algorithm 1.
alpha = np.zeros((data_size * 2))
data_abs = np.abs(input_data)
- alpha[:data_size] = (
- (self.beta * extra_factor)
- / (data_abs + sys.float_info.epsilon)
+ alpha[:data_size] = (self.beta * extra_factor) / (
+ data_abs + sys.float_info.epsilon
)
- alpha[data_size:] = (
- (self.beta * extra_factor + 1)
- / (data_abs + sys.float_info.epsilon)
+ alpha[data_size:] = (self.beta * extra_factor + 1) / (
+ data_abs + sys.float_info.epsilon
)
alpha = np.sort(np.unique(alpha))
@@ -1216,8 +1210,8 @@ def _op_method(self, input_data, extra_factor=1.0):
k_max = np.prod(data_shape)
if self._k_value > k_max:
warn(
- 'K value of the K-support norm is greater than the input '
- + 'dimension, its value will be set to {0}'.format(k_max),
+ "K value of the K-support norm is greater than the input "
+ + "dimension, its value will be set to {0}".format(k_max),
)
self._k_value = k_max
@@ -1229,8 +1223,7 @@ def _op_method(self, input_data, extra_factor=1.0):
# Computes line 5. in Algorithm 1.
rslt = np.nan_to_num(
- (input_data.flatten() * theta)
- / (theta + self.beta * extra_factor),
+ (input_data.flatten() * theta) / (theta + self.beta * extra_factor),
)
return rslt.reshape(data_shape)
@@ -1271,25 +1264,20 @@ def _find_q(self, sorted_data):
found = True
q_val = 0
- elif (
- (sorted_data[self._k_value - 1:].sum())
- <= sorted_data[self._k_value - 1]
- ):
+ elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]:
found = True
q_val = self._k_value - 1
while (
- not found and not cnt == self._k_value
+ not found
+ and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
-
q_val = (first_idx + last_idx) // 2
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
- if (
- sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]
- ):
+ if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]:
found = True
else:
@@ -1324,15 +1312,12 @@ def _cost_method(self, *args, **kwargs):
data_abs = data_abs[ix] # Sorted absolute value of the data
q_val = self._find_q(data_abs)
cost_val = (
- (
- np.sum(data_abs[:q_val] ** 2) * 0.5
- + np.sum(data_abs[q_val:]) ** 2
- / (self._k_value - q_val)
- ) * self.beta
- )
+ np.sum(data_abs[:q_val] ** 2) * 0.5
+ + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
+ ) * self.beta
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - K-SUPPORT NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
diff --git a/modopt/opt/reweight.py b/modopt/opt/reweight.py
index 8c4f2449..d0daa841 100644
--- a/modopt/opt/reweight.py
+++ b/modopt/opt/reweight.py
@@ -45,7 +45,6 @@ class cwbReweight(object):
"""
def __init__(self, weights, thresh_factor=1.0, verbose=False):
-
self.weights = check_float(weights)
self.original_weights = np.copy(self.weights)
self.thresh_factor = check_float(thresh_factor)
@@ -81,7 +80,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(' - Reweighting: {0}'.format(self._rw_num))
+ print(" - Reweighting: {0}".format(self._rw_num))
self._rw_num += 1
@@ -89,7 +88,7 @@ def reweight(self, input_data):
if input_data.shape != self.weights.shape:
raise ValueError(
- 'Input data must have the same shape as the initial weights.',
+ "Input data must have the same shape as the initial weights.",
)
thresh_weights = self.thresh_factor * self.original_weights
diff --git a/modopt/plot/__init__.py b/modopt/plot/__init__.py
index 28d60be6..da6e096c 100644
--- a/modopt/plot/__init__.py
+++ b/modopt/plot/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost_plot']
+__all__ = ["cost_plot"]
diff --git a/modopt/plot/cost_plot.py b/modopt/plot/cost_plot.py
index aa855eaa..36958450 100644
--- a/modopt/plot/cost_plot.py
+++ b/modopt/plot/cost_plot.py
@@ -37,20 +37,20 @@ def plotCost(cost_list, output=None):
"""
if import_fail:
- raise ImportError('Matplotlib package not found')
+ raise ImportError("Matplotlib package not found")
else:
if isinstance(output, type(None)):
- file_name = 'cost_function.png'
+ file_name = "cost_function.png"
else:
- file_name = '{0}_cost_function.png'.format(output)
+ file_name = "{0}_cost_function.png".format(output)
plt.figure()
- plt.plot(np.log10(cost_list), 'r-')
- plt.title('Cost Function')
- plt.xlabel('Iteration')
- plt.ylabel(r'$\log_{10}$ Cost')
+ plt.plot(np.log10(cost_list), "r-")
+ plt.title("Cost Function")
+ plt.xlabel("Iteration")
+ plt.ylabel(r"$\log_{10}$ Cost")
plt.savefig(file_name)
plt.close()
- print(' - Saving cost function data to:', file_name)
+ print(" - Saving cost function data to:", file_name)
diff --git a/modopt/signal/__init__.py b/modopt/signal/__init__.py
index dbc6d053..09b2d2c4 100644
--- a/modopt/signal/__init__.py
+++ b/modopt/signal/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet']
+__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 84dd8160..2c7d8626 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -81,7 +81,7 @@ def mex_hat(data_point, sigma):
sigma = check_float(sigma)
xs = (data_point / sigma) ** 2
- factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25
+ factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
return factor * (1 - xs) * np.exp(-0.5 * xs)
diff --git a/modopt/signal/noise.py b/modopt/signal/noise.py
index a59d5553..fadf5308 100644
--- a/modopt/signal/noise.py
+++ b/modopt/signal/noise.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type='gauss'):
+def add_noise(input_data, sigma=1.0, noise_type="gauss"):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -70,7 +70,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
"""
input_data = np.array(input_data)
- if noise_type not in {'gauss', 'poisson'}:
+ if noise_type not in {"gauss", "poisson"}:
raise ValueError(
'Invalid noise type. Options are "gauss" or "poisson"',
)
@@ -78,14 +78,13 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
if isinstance(sigma, (list, tuple, np.ndarray)):
if len(sigma) != input_data.shape[0]:
raise ValueError(
- 'Number of sigma values must match first dimension of input '
- + 'data',
+ "Number of sigma values must match first dimension of input " + "data",
)
- if noise_type == 'gauss':
+ if noise_type == "gauss":
random = np.random.randn(*input_data.shape)
- elif noise_type == 'poisson':
+ elif noise_type == "poisson":
random = np.random.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
@@ -96,7 +95,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
return input_data + noise
-def thresh(input_data, threshold, threshold_type='hard'):
+def thresh(input_data, threshold, threshold_type="hard"):
r"""Threshold data.
This method perfoms hard or soft thresholding on the input data.
@@ -169,12 +168,12 @@ def thresh(input_data, threshold, threshold_type='hard'):
input_data = xp.array(input_data)
- if threshold_type not in {'hard', 'soft'}:
+ if threshold_type not in {"hard", "soft"}:
raise ValueError(
'Invalid threshold type. Options are "hard" or "soft"',
)
- if threshold_type == 'soft':
+ if threshold_type == "soft":
denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
max_value = xp.maximum((1.0 - threshold / denominator), 0)
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index c19ba62c..5c4b795b 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -47,7 +47,7 @@ def pos_recursive(input_data):
Positive coefficients
"""
- if input_data.dtype == 'O':
+ if input_data.dtype == "O":
res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
@@ -97,15 +97,15 @@ def positive(input_data, ragged=False):
"""
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid data type, input must be `int`, `float`, `list`, '
- + '`tuple` or `np.ndarray`.',
+ "Invalid data type, input must be `int`, `float`, `list`, "
+ + "`tuple` or `np.ndarray`.",
)
if isinstance(input_data, (int, float)):
return pos_thresh(input_data)
if ragged:
- input_data = np.array(input_data, dtype='object')
+ input_data = np.array(input_data, dtype="object")
else:
input_data = np.array(input_data)
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index f3d40a51..cc204817 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -52,8 +52,8 @@ def find_n_pc(u_vec, factor=0.5):
"""
if np.sqrt(u_vec.shape[0]) % 1:
raise ValueError(
- 'Invalid left singular vector. The size of the first '
- + 'dimenion of ``u_vec`` must be perfect square.',
+ "Invalid left singular vector. The size of the first "
+ "dimension of ``u_vec`` must be perfect square.",
)
# Get the shape of the array
@@ -69,13 +69,12 @@ def find_n_pc(u_vec, factor=0.5):
]
# Return the required number of principal components.
- return np.sum([
- (
- u_val[tuple(zip(array_shape // 2))] ** 2 <= factor
- * np.sum(u_val ** 2),
- )
- for u_val in u_auto
- ])
+ return np.sum(
+ [
+ (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),)
+ for u_val in u_auto
+ ]
+ )
def calculate_svd(input_data):
@@ -101,17 +100,17 @@ def calculate_svd(input_data):
"""
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise TypeError('Input data must be a 2D np.ndarray.')
+ raise TypeError("Input data must be a 2D np.ndarray.")
return svd(
input_data,
check_finite=False,
- lapack_driver='gesvd',
+ lapack_driver="gesvd",
full_matrices=False,
)
-def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
+def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
"""Threshold the singular values.
This method thresholds the input data using singular value decomposition.
@@ -156,16 +155,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
"""
less_than_zero = isinstance(n_pc, int) and n_pc <= 0
- str_not_all = isinstance(n_pc, str) and n_pc != 'all'
+ str_not_all = isinstance(n_pc, str) and n_pc != "all"
- if (
- (not isinstance(n_pc, (int, str, type(None))))
- or less_than_zero
- or str_not_all
- ):
+ if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
raise ValueError(
- 'Invalid value for "n_pc", specify a positive integer value or '
- + '"all"',
+ 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"',
)
# Get SVD of input data.
@@ -176,15 +170,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
- print('xxxx', n_pc, u_vec)
+ print("xxxx", n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
- if (
- (isinstance(n_pc, int) and n_pc >= s_values.size)
- or (isinstance(n_pc, str) and n_pc == 'all')
+ if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
+ isinstance(n_pc, str) and n_pc == "all"
):
n_pc = s_values.size
- warn('Using all singular values.')
+ warn("Using all singular values.")
threshold = s_values[n_pc - 1]
@@ -192,7 +185,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
s_new = thresh(s_values, threshold, thresh_type)
if np.all(s_new == s_values):
- warn('No change to singular values.')
+ warn("No change to singular values.")
# Diagonalize the svd
s_new = np.diag(s_new)
@@ -206,7 +199,7 @@ def svd_thresh_coef_fast(
threshold,
n_vals=-1,
extra_vals=5,
- thresh_type='hard',
+ thresh_type="hard",
):
"""Threshold the singular values coefficients.
@@ -241,7 +234,7 @@ def svd_thresh_coef_fast(
ok = False
while not ok:
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
- ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
+ ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
n_vals = min(n_vals + extra_vals, *input_data.shape)
s_values = thresh(
@@ -259,7 +252,7 @@ def svd_thresh_coef_fast(
)
-def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
+def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
"""Threshold the singular values coefficients.
This method thresholds the input data using singular value decomposition.
@@ -287,7 +280,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""
if not callable(operator):
- raise TypeError('Operator must be a callable function.')
+ raise TypeError("Operator must be a callable function.")
# Get SVD of data matrix
u_vec, s_values, v_vec = calculate_svd(input_data)
@@ -302,10 +295,9 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
- ti = np.array([
- np.linalg.norm(elem)
- for elem in operator(matrix2cube(u_vec, array_shape))
- ])
+ ti = np.array(
+ [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
+ )
threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)
# Threshold coefficients.
diff --git a/modopt/signal/validation.py b/modopt/signal/validation.py
index 422a987b..68c1e726 100644
--- a/modopt/signal/validation.py
+++ b/modopt/signal/validation.py
@@ -54,7 +54,7 @@ def transpose_test(
"""
if not callable(operator) or not callable(operator_t):
- raise TypeError('The input operators must be callable functions.')
+ raise TypeError("The input operators must be callable functions.")
if isinstance(y_shape, type(None)):
y_shape = x_shape
@@ -73,4 +73,4 @@ def transpose_test(
x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))
# Test the difference between the two.
- print(' - | - | =', np.abs(mx_y - x_mty))
+ print(" - | - | =", np.abs(mx_y - x_mty))
diff --git a/modopt/signal/wavelet.py b/modopt/signal/wavelet.py
index bc4ffc70..72d608e7 100644
--- a/modopt/signal/wavelet.py
+++ b/modopt/signal/wavelet.py
@@ -58,20 +58,20 @@ def execute(command_line):
"""
if not isinstance(command_line, str):
- raise TypeError('Command line must be a string.')
+ raise TypeError("Command line must be a string.")
command = command_line.split()
process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
def call_mr_transform(
input_data,
- opt='',
- path='./',
+ opt="",
+ path="./",
remove_files=True,
): # pragma: no cover
"""Call ``mr_transform``.
@@ -127,26 +127,23 @@ def call_mr_transform(
"""
if not import_astropy:
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise ValueError('Input data must be a 2D numpy array.')
+ raise ValueError("Input data must be a 2D numpy array.")
- executable = 'mr_transform'
+ executable = "mr_transform"
# Make sure mr_transform is installed.
is_executable(executable)
# Create a unique string using the current date and time.
- unique_string = (
- datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
- + str(getrandbits(128))
- )
+ unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
# Set the ouput file names.
- file_name = '{0}mr_temp_{1}'.format(path, unique_string)
- file_fits = '{0}.fits'.format(file_name)
- file_mr = '{0}.mr'.format(file_name)
+ file_name = "{0}mr_temp_{1}".format(path, unique_string)
+ file_fits = "{0}.fits".format(file_name)
+ file_mr = "{0}.mr".format(file_name)
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -155,15 +152,15 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = ' '.join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable] + opt + [file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
- if any(word in stdout for word in ('bad', 'Error', 'Sorry')):
+ if any(word in stdout for word in ("bad", "Error", "Sorry")):
remove(file_fits)
message = '{0} raised following exception: "{1}"'
raise RuntimeError(
- message.format(executable, stdout.rstrip('\n')),
+ message.format(executable, stdout.rstrip("\n")),
)
# Retrieve wavelet transformed data.
@@ -198,12 +195,12 @@ def trim_filter(filter_array):
min_idx = np.min(non_zero_indices, axis=-1)
max_idx = np.max(non_zero_indices, axis=-1)
- return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1]
+ return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1]
def get_mr_filters(
data_shape,
- opt='',
+ opt="",
coarse=False,
trim=False,
): # pragma: no cover
@@ -256,7 +253,7 @@ def get_mr_filters(
return mr_filters[:-1]
-def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
+def filter_convolve(input_data, filters, filter_rot=False, method="scipy"):
"""Filter convolve.
This method convolves the input image with the wavelet filters.
@@ -315,16 +312,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
axis=0,
)
- return np.array([
- convolve(input_data, filt, method=method) for filt in filters
- ])
+ return np.array([convolve(input_data, filt, method=method) for filt in filters])
def filter_convolve_stack(
input_data,
filters,
filter_rot=False,
- method='scipy',
+ method="scipy",
):
"""Filter convolve.
@@ -366,7 +361,9 @@ def filter_convolve_stack(
"""
# Return the convolved data cube.
- return np.array([
- filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
- for elem in input_data
- ])
+ return np.array(
+ [
+ filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
+ for elem in input_data
+ ]
+ )
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index 202e541b..b3787fc6 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -17,6 +17,7 @@
class TestFilter:
"""Test filter module"""
+
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
)
@@ -24,7 +25,6 @@ def test_gaussian_filter(self, norm, result):
"""Test gaussian filter."""
npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
-
def test_mex_hat(self):
"""Test mexican hat filter."""
npt.assert_almost_equal(
@@ -32,7 +32,6 @@ def test_mex_hat(self):
-0.35213905225713371,
)
-
def test_mex_hat_dir(self):
"""Test directional mexican hat filter."""
npt.assert_almost_equal(
@@ -86,13 +85,16 @@ def test_thresh(self, threshold_type, result):
noise.thresh(self.data1, 5, threshold_type=threshold_type), result
)
+
class TestPositivity:
"""Test positivity module."""
+
data1 = np.arange(9).reshape(3, 3).astype(float)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
data5 = np.array(
[[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
)
+
@pytest.mark.parametrize(
("value", "expected"),
[
@@ -231,6 +233,7 @@ def test_svd_thresh_coef(self, data, operator):
# TODO test_svd_thresh_coef_fast
+
class TestValidation:
"""Test validation Module."""
diff --git a/notebooks/.gitkeep b/notebooks/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..ea25e887
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,96 @@
+[project]
+name = "modopt"
+readme = "README.md"
+requires-python = ">=3.7"
+authors = [
+ { "name" = "Samuel Farrens", "email" = "samuel.farrens@cea.fr"},
+]
+maintainers = [
+ { "name" = "Samuel Farrens", "email" = "samuel.farrens@cea.fr" },
+]
+description = 'Modular Optimisation tools for soliving inverse problems.'
+dynamic = ["dependencies", "version"]
+keywords = ["Image processing", "Wavelets", "Sparsity", "MRI", "Astronomy", "Electron tomography"]
+license = { "file" = "LICENSE.txt" }
+classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Operating System :: POSIX :: Linux",
+ "Operating System :: MacOS",
+ "Topic :: Scientific/Engineering",
+]
+
+[project.optional-dependencies]
+docs = [
+ "importlib_metadata",
+ "myst-parser",
+ "numpydoc",
+ "sphinx",
+ "sphinxcontrib-bibtex",
+ "sphinxawesome-theme",
+ "sphinx-gallery",
+]
+extra = [
+ "astropy",
+ "matplotlib",
+ "scikit-image<0.20",
+ "scikit-learn",
+]
+lint = [
+ "black",
+]
+release = [
+ "build",
+ "twine",
+]
+test = [
+ "pytest",
+ "pytest-black",
+ "pytest-cases",
+ "pytest-cov",
+ "pytest-emoji",
+ "pytest-pydocstyle",
+ "pytest-raises",
+ "pytest-xdist",
+]
+
+# Install for development
+dev = ["modopt[docs,lint,release,test]"]
+
+# Install with all possible dependencies
+all = ["modopt[dev,extra]"]
+
+[project.urls]
+Source = "https://github.com/CEA-COSMIC/modopt"
+Documentation = "https://cea-cosmic.github.io/ModOpt/"
+Tracker = "https://github.com/CEA-COSMIC/modopt/issues"
+
+[tool.black]
+line-length = 88
+
+[tool.pydocstyle]
+convention = "numpy"
+
+[tool.pytest.ini_options]
+addopts = [
+ "--verbose",
+ "--black",
+ "--emoji",
+ "--pydocstyle",
+ "--cov=modopt",
+ "--cov-report=term-missing",
+ "--cov-report=xml",
+ "--junitxml=pytest.xml",
+]
+norecursedirs = ["tests/test_helpers"]
+testpaths = ["modopt"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+version = {attr = "modopt.__version__"}
diff --git a/requirements.txt b/requirements.txt
index 1f44de13..0b25c79f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-importlib_metadata>=3.7.0
numpy>=1.19.5
scipy>=1.5.4
tqdm>=4.64.0
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 8d8e821b..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,95 +0,0 @@
-[aliases]
-test=pytest
-
-[metadata]
-description_file = README.rst
-
-[darglint]
-docstring_style = numpy
-strictness = short
-
-[flake8]
-ignore =
- D107, #Justification: Don't need docstring for __init__ in numpydoc style
- RST304, #Justification: Need to use :cite: role for citations
- RST210, #Justification: RST210, RST213 Inconsistent with numpydoc
- RST213, # documentation for handling *args and **kwargs
- W503, #Justification: Have to choose one multiline operator format
- WPS202, #Todo: Rethink module size, possibly split large modules
- WPS337, #Todo: Consider simplifying multiline conditions.
- WPS338, #Todo: Consider changing method order
- WPS403, #Todo: Rethink no cover lines
- WPS421, #Todo: Review need for print statements
- WPS432, #Justification: Mathematical codes require "magic numbers"
- WPS433, #Todo: Rethink conditional imports
- WPS463, #Todo: Rename get_ methods
- WPS615, #Todo: Rename get_ methods
-per-file-ignores =
- #Justification: Needed for keeping package version and current API
- *__init__.py*: F401,F403,WPS347,WPS410,WPS412
- #Todo: Rethink conditional imports
- #Todo: How can we bypass mutable constants?
- modopt/base/backend.py: WPS229, WPS420, WPS407
- #Todo: Rethink conditional imports
- modopt/base/observable.py: WPS420,WPS604
- #Todo: Check string for log formatting
- modopt/interface/log.py: WPS323
- #Todo: Rethink conditional imports
- modopt/math/convolve.py: WPS301,WPS420
- #Todo: Rethink conditional imports
- modopt/math/matrix.py: WPS420
- #Todo: import has bad parenthesis
- modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
- #Todo: x is a too short name.
- modopt/opt/algorithms/forward_backward.py: WPS111
- #Todo: Check need for del statement
- modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
- #multiline parameters bug with tuples
- modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317
- #Todo: Consider changing costObj name
- modopt/opt/cost.py: N801,
- #Todo:
- # - Rethink subscript slice assignment
- # - Reduce complexity of KSupportNorm
- # - Check bitwise operations
- modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
- #Todo: Consider changing cwbReweight name
- modopt/opt/reweight.py: N801
- #Justification: Needed to import matplotlib.pyplot
- modopt/plot/cost_plot.py: N802,WPS301
- #Todo: Investigate possible bug in find_n_pc function
- #Todo: Investigate darglint error
- modopt/signal/svd.py: WPS345, DAR000
- #Todo: Check security of using system executable call
- modopt/signal/wavelet.py: S404,S603
- #Todo: Clean up tests
- modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604
- #Todo: Import has bad parenthesis
- modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301
-#WPS Settings
-max-arguments = 25
-max-attributes = 40
-max-cognitive-score = 20
-max-function-expressions = 20
-max-line-complexity = 30
-max-local-variables = 10
-max-methods = 20
-max-module-expressions = 20
-max-string-usages = 20
-max-raises = 5
-
-[tool:pytest]
-norecursedirs=tests/test_helpers
-testpaths =
- modopt
-addopts =
- --verbose
- --cov=modopt
- --cov-report=term-missing
- --cov-report=xml
- --junitxml=pytest.xml
- --pydocstyle
-
-[pydocstyle]
-convention=numpy
-add-ignore=D107
diff --git a/setup.py b/setup.py
deleted file mode 100644
index c95e5984..00000000
--- a/setup.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from setuptools import setup, find_packages
-import os
-
-# Set the package release version
-major = 1
-minor = 6
-patch = 1
-
-# Set the package details
-name = 'modopt'
-version = '.'.join(str(value) for value in (major, minor, patch))
-author = 'Samuel Farrens'
-email = 'samuel.farrens@cea.fr'
-gh_user = 'cea-cosmic'
-url = 'https://github.com/{0}/{1}'.format(gh_user, name)
-description = 'Modular Optimisation tools for soliving inverse problems.'
-license = 'MIT'
-
-# Set the package classifiers
-python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11']
-os_platforms_supported = ['Unix', 'MacOS']
-
-lc_str = 'License :: OSI Approved :: {0} License'
-ln_str = 'Programming Language :: Python'
-py_str = 'Programming Language :: Python :: {0}'
-os_str = 'Operating System :: {0}'
-
-classifiers = (
- [lc_str.format(license)]
- + [ln_str]
- + [py_str.format(ver) for ver in python_versions_supported]
- + [os_str.format(ops) for ops in os_platforms_supported]
-)
-
-# Source package description from README.md
-this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
- long_description = f.read()
-
-# Source package requirements from requirements.txt
-with open('requirements.txt') as open_file:
- install_requires = open_file.read()
-
-# Source test requirements from develop.txt
-with open('develop.txt') as open_file:
- tests_require = open_file.read()
-
-# Source doc requirements from docs/requirements.txt
-with open('docs/requirements.txt') as open_file:
- docs_require = open_file.read()
-
-
-setup(
- name=name,
- author=author,
- author_email=email,
- version=version,
- license=license,
- url=url,
- description=description,
- long_description=long_description,
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
- python_requires='>=3.6',
- setup_requires=['pytest-runner'],
- tests_require=tests_require,
- extras_require={'develop': tests_require + docs_require},
- classifiers=classifiers,
-)
From bdba31f831f5e64df3651b2203767e6d66bbc672 Mon Sep 17 00:00:00 2001
From: Samuel Farrens
Date: Mon, 20 Mar 2023 17:07:08 +0100
Subject: [PATCH 2/3] Update modopt/signal/svd.py
Co-authored-by: Pierre-Antoine Comby
---
modopt/signal/svd.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index cc204817..7d94e89d 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -170,7 +170,6 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
- print("xxxx", n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
From 5783bd8517cca96a59125a7dfbaf3aa32bba0328 Mon Sep 17 00:00:00 2001
From: Samuel Farrens
Date: Mon, 20 Mar 2023 17:10:00 +0100
Subject: [PATCH 3/3] Update pyproject.toml
Co-authored-by: Pierre-Antoine Comby
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index ea25e887..a33d3018 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ maintainers = [
]
description = 'Modular Optimisation tools for soliving inverse problems.'
dynamic = ["dependencies", "version"]
-keywords = ["Image processing", "Wavelets", "Sparsity", "MRI", "Astronomy", "Electron tomography"]
+keywords = ["Image processing", "Wavelets", "Sparsity", "Convex Optimisation", "Proximal Operator"]
license = { "file" = "LICENSE.txt" }
classifiers = [
"Development Status :: 5 - Production/Stable",