From 159bb63229c3d315c25843d5b4a4e31ec620202f Mon Sep 17 00:00:00 2001 From: Pete Peterson Date: Thu, 19 Dec 2019 09:17:43 -0500 Subject: [PATCH] Move PeakFitFactory into correct subpackage refs #219 --- pyrs/core/pyrscore.py | 16 +++++----------- pyrs/peaks/__init__.py | 6 +++++- pyrs/peaks/mantid_fit_peak.py | 10 ++++++---- pyrs/peaks/peak_collection.py | 2 ++ pyrs/peaks/peak_fit_engine.py | 2 ++ pyrs/{core => peaks}/peak_fit_factory.py | 18 +++++++++--------- pyrs/peaks/scipypeakfitengine.py | 4 ++++ tests/integration/test_peak_fitting.py | 6 +++--- tests/unit/test_peak_fit_engine.py | 2 +- 9 files changed, 37 insertions(+), 29 deletions(-) rename pyrs/{core => peaks}/peak_fit_factory.py (60%) diff --git a/pyrs/core/pyrscore.py b/pyrs/core/pyrscore.py index ad0652aab..125e26958 100644 --- a/pyrs/core/pyrscore.py +++ b/pyrs/core/pyrscore.py @@ -2,7 +2,7 @@ from pyrs.utilities import checkdatatypes from pyrs.core import instrument_geometry from pyrs.utilities import file_util -from pyrs.core import peak_fit_factory +from pyrs.peaks import PeakFitEngineFactory, SupportedPeakProfiles, SupportedBackgroundTypes from pyrs.utilities.rs_project_file import HidraConstants, HidraProjectFile, HidraProjectFileMode from pyrs.core import strain_stress_calculator from pyrs.core import reduction_manager @@ -10,9 +10,6 @@ import os import numpy -# Define Constants -SUPPORTED_PEAK_TYPES = ['PseudoVoigt', 'Gaussian', 'Voigt'] # 'Lorentzian': No a profile of HB2B - class PyRsCore(object): """ @@ -108,8 +105,7 @@ def init_peak_fit_engine(self, fit_tag): # get workspace workspace = self.reduction_service.get_hidra_workspace(fit_tag) # create a controller from factory - self._peak_fitting_dict[fit_tag] = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(workspace, - None) + self._peak_fitting_dict[fit_tag] = PeakFitEngineFactory.getInstance('Mantid')(workspace, None) # set wave length: TODO - #81+ - shall be a way to use calibrated or non-calibrated wave_length_dict = workspace.get_wavelength(calibrated=False, throw_if_not_set=False) if wave_length_dict is not None: @@ -154,10 +150,8 @@ def fit_peaks(self, project_name="", # Check Inputs checkdatatypes.check_dict('Peak fitting (information) parameters', peaks_fitting_setup) - checkdatatypes.check_string_variable('Peak type', peak_type, - peak_fit_factory.SupportedPeakProfiles) - checkdatatypes.check_string_variable('Background type', background_type, - peak_fit_factory.SupportedBackgroundTypes) + checkdatatypes.check_string_variable('Peak type', peak_type, SupportedPeakProfiles) + checkdatatypes.check_string_variable('Background type', background_type, SupportedBackgroundTypes) # Deal with sub runs if sub_run_list is None: @@ -647,4 +641,4 @@ def supported_peak_types(self): list of supported peaks' types for fitting :return: """ - return SUPPORTED_PEAK_TYPES[:] + return SupportedPeakProfiles[:] diff --git a/pyrs/peaks/__init__.py b/pyrs/peaks/__init__.py index d99b82b60..e388566e0 100644 --- a/pyrs/peaks/__init__.py +++ b/pyrs/peaks/__init__.py @@ -1,3 +1,7 @@ +# flake8: noqa from __future__ import (absolute_import, division, print_function) # python3 compatibility -from .peak_collection import PeakCollection +from .peak_collection import * +from .peak_fit_factory import * + +__all__ = peak_collection.__all__ + peak_fit_factory.__all__ diff --git a/pyrs/peaks/mantid_fit_peak.py b/pyrs/peaks/mantid_fit_peak.py index 6fdd45f61..45cf43ee9 100644 --- a/pyrs/peaks/mantid_fit_peak.py +++ b/pyrs/peaks/mantid_fit_peak.py @@ -8,6 +8,7 @@ from mantid.api import AnalysisDataService from mantid.simpleapi import CreateWorkspace, FitPeaks +__all__ = ['MantidPeakFitEngine'] DEBUG = False # Flag for debugging mode @@ -122,25 +123,26 @@ def _set_default_peak_params_value(self, peak_function_name, peak_range): # Make the difference between peak profiles if peak_function_name == 'Gaussian': # Gaussian - peak_param_names = '{}, {}'.format('Height', 'Sigma', 'A0') + peak_param_names = ','.join([str(value) for value in ['Height', 'Sigma', 'A0']]) # sigma instrument_sigma = Gaussian.cal_sigma(hidra_fwhm) # set value - peak_param_values = "{}, {}".format(max_estimated_height, instrument_sigma, flat_bkgd) + peak_param_values = ','.join([str(value) for value in (max_estimated_height, instrument_sigma, flat_bkgd)]) elif peak_function_name == 'PseudoVoigt': # Pseudo-voig default_mixing = 0.6 - peak_param_names = '{}, {}, {}'.format('Mixing', 'Intensity', 'FWHM', 'A0') + peak_param_names = ','.join([str(value) for value in ('Mixing', 'Intensity', 'FWHM', 'A0')]) # intensity max_intensity = PseudoVoigt.cal_intensity(max_estimated_height, hidra_fwhm, default_mixing) # set values - peak_param_values = "{}, {}, {}".format(default_mixing, max_intensity, hidra_fwhm, flat_bkgds) + peak_param_values = ','.join([str(value) for value in (default_mixing, max_intensity, + hidra_fwhm, flat_bkgds)]) else: # Non-supported case diff --git a/pyrs/peaks/peak_collection.py b/pyrs/peaks/peak_collection.py index dbb401dc3..ac2a0f5df 100644 --- a/pyrs/peaks/peak_collection.py +++ b/pyrs/peaks/peak_collection.py @@ -5,6 +5,8 @@ from pyrs.utilities import checkdatatypes from pyrs.core.peak_profile_utility import get_effective_parameters_converter, PeakShape, BackgroundFunction +__all__ = ['PeakCollection'] + class PeakCollection(object): """ diff --git a/pyrs/peaks/peak_fit_engine.py b/pyrs/peaks/peak_fit_engine.py index 7290386c8..d155e6116 100644 --- a/pyrs/peaks/peak_fit_engine.py +++ b/pyrs/peaks/peak_fit_engine.py @@ -6,6 +6,8 @@ from pyrs.core.peak_profile_utility import PeakShape from pyrs.utilities import checkdatatypes +__all__ = ['PeakFitEngine'] + class PeakFitEngine(object): """ diff --git a/pyrs/core/peak_fit_factory.py b/pyrs/peaks/peak_fit_factory.py similarity index 60% rename from pyrs/core/peak_fit_factory.py rename to pyrs/peaks/peak_fit_factory.py index 2b239ca05..d35ae443e 100644 --- a/pyrs/core/peak_fit_factory.py +++ b/pyrs/peaks/peak_fit_factory.py @@ -1,27 +1,27 @@ # Peak fitting engine -from pyrs.peaks.mantid_fit_peak import MantidPeakFitEngine from pyrs.utilities import checkdatatypes SupportedPeakProfiles = ['Gaussian', 'PseudoVoigt', 'Voigt'] SupportedBackgroundTypes = ['Flat', 'Linear', 'Quadratic'] +__all__ = ['PeakFitEngineFactory', 'SupportedPeakProfiles', 'SupportedBackgroundTypes'] + class PeakFitEngineFactory(object): """ Peak fitting engine factory """ @staticmethod - def getInstance(engine_name): + def getInstance(name): """ Get instance of Peak fitting engine - :param engine_name: - :return: """ - checkdatatypes.check_string_variable('Peak fitting engine', engine_name, ['Mantid', 'PyRS']) + checkdatatypes.check_string_variable('Peak fitting engine', name, ['Mantid', 'PyRS']) + + # this must be here for now to stop circular imports + from .mantid_fit_peak import MantidPeakFitEngine - if engine_name == 'Mantid': - engine_class = MantidPeakFitEngine + if name == 'Mantid': + return MantidPeakFitEngine else: raise RuntimeError('Implement general scipy peak fitting engine') - - return engine_class diff --git a/pyrs/peaks/scipypeakfitengine.py b/pyrs/peaks/scipypeakfitengine.py index a7cfc1404..b5991537e 100644 --- a/pyrs/peaks/scipypeakfitengine.py +++ b/pyrs/peaks/scipypeakfitengine.py @@ -3,6 +3,8 @@ import numpy as np from pyrs.utilities import checkdatatypes +__all__ = ['ScipyPeakFitEngine'] + class ScipyPeakFitEngine(PeakFitEngine): """peak fitting engine class for mantid @@ -72,6 +74,7 @@ def calculate_peak(X, Data, TTH, peak_function_name, background_function_name, R else: return Data - model_y + # TODO signature doesn't match base class def fit_peaks(self, peak_function_name, background_function_name, scan_index=None): """ fit peaks @@ -149,6 +152,7 @@ def fit_peaks(self, peak_function_name, background_function_name, scan_index=Non return + # TODO arguments don't match base class def calculate_fitted_peaks(self, scan_index): """ get the calculated peak's value diff --git a/tests/integration/test_peak_fitting.py b/tests/integration/test_peak_fitting.py index 7c702d61d..95662f45b 100644 --- a/tests/integration/test_peak_fitting.py +++ b/tests/integration/test_peak_fitting.py @@ -7,7 +7,7 @@ from pyrs.core.summary_generator import SummaryGenerator from pyrs.dataobjects import SampleLogs from pyrs.utilities.rs_project_file import HidraProjectFile -from pyrs.core import peak_fit_factory +from pyrs.peaks import PeakFitEngineFactory import h5py from pyrs.core import peak_profile_utility from matplotlib import pyplot as plt @@ -225,7 +225,7 @@ def test_retrieve_fit_metadata(source_project_file, output_project_file, peak_ty # Set peak fitting engine # create a controller from factory - fit_engine = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None) + fit_engine = PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None) # Fit peak fit_engine.fit_multiple_peaks(sub_run_range=(None, None), # default is all sub runs @@ -348,7 +348,7 @@ def test_improve_quality(): # Set peak fitting engine # create a controller from factory - fit_engine = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None) + fit_engine = PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None) peak_type = 'Gaussian' diff --git a/tests/unit/test_peak_fit_engine.py b/tests/unit/test_peak_fit_engine.py index 0fbb014b0..3fb8faf27 100644 --- a/tests/unit/test_peak_fit_engine.py +++ b/tests/unit/test_peak_fit_engine.py @@ -1,5 +1,5 @@ import numpy as np -from pyrs.core.peak_fit_factory import PeakFitEngineFactory +from pyrs.peaks import PeakFitEngineFactory from pyrs.core.workspaces import HidraWorkspace from pyrs.core.peak_profile_utility import pseudo_voigt, PeakShape, BackgroundFunction from pyrs.core.peak_profile_utility import Gaussian, PseudoVoigt