Source code for rocketpy.stochastic.stochastic_model

"""
Defines the `StochasticModel` class, which is used as a base class for all other
Stochastic classes.
"""

from random import choice

import numpy as np

from rocketpy.mathutils.function import Function
from rocketpy.stochastic.custom_sampler import CustomSampler

from ..tools import get_distribution

# TODO: Stop using assert in production code. Use exceptions instead.
# TODO: Each validation method should have a test case.


class StochasticModel:
    """
    Base class for all Stochastic classes. This class validates input
    arguments, saves them as attributes, and generates a dictionary with
    randomly generated input arguments.

    See also
    --------
    :ref:`Working with Stochastic Models <stochastic_usage>`

    Notes
    -----
    Please notice that the methods starting with an underscore are not meant
    to be called directly by the user. These methods may receive breaking
    changes without notice, so use them at your own risk.

    Validation failures are reported with ``AssertionError`` raised
    explicitly (not through ``assert`` statements), so the checks stay active
    when Python runs with the ``-O`` flag.
    """

    # Arguments that are validated only in child classes
    exception_list = [
        "initial_solution",
        "terminate_on_apogee",
        "date",
        "ensemble_member",
    ]

    def __init__(self, obj, seed=None, **kwargs):
        """
        Initialize the StochasticModel class with validated input arguments.

        Parameters
        ----------
        obj : object
            The main object of the class.
        seed : int, optional
            Seed for the random number generator. The default is None so that
            a new ``numpy.random.Generator`` object is created.
        **kwargs : dict
            Dictionary of input arguments for the class. Valid argument types
            include tuples, lists, ints, floats, or None. Arguments will be
            validated and saved as class attributes in a specific format,
            which is described in the
            ":ref:`Working with Stochastic Models <stochastic_usage>`" page.

        Raises
        ------
        AssertionError
            If the input arguments do not conform to the specified formats.
        """
        self.obj = obj
        self.last_rnd_dict = {}
        self.__stochastic_dict = kwargs
        self._set_stochastic(seed)

    def _set_stochastic(self, seed=None):
        """Set the stochastic attributes from the input dictionary.
        This method is useful to reset or reseed the attributes of the
        instance.

        Parameters
        ----------
        seed : int, optional
            Seed for the random number generator.

        Raises
        ------
        AssertionError
            If an input value is of an unsupported type or fails validation.
        """
        self.__random_number_generator = np.random.default_rng(seed)
        self.last_rnd_dict = {}

        # TODO: Resetting a instance should not require re-validation.
        for input_name, input_value in self.__stochastic_dict.items():
            if input_name in self.exception_list:
                # These arguments are validated (and stored) by child classes.
                continue

            if input_value is None:
                # No uncertainty given: keep the nominal value of the object
                # as a single-element list (a "constant" attribute).
                attr_value = [getattr(self.obj, input_name)]
            elif "factor" in input_name:
                attr_value = self._validate_factors(input_name, input_value, seed)
            elif isinstance(input_value, tuple):
                attr_value = self._validate_tuple(input_name, input_value)
            elif isinstance(input_value, list):
                attr_value = self._validate_list(input_name, input_value)
            elif isinstance(input_value, (int, float)):
                attr_value = self._validate_scalar(input_name, input_value)
            elif isinstance(input_value, CustomSampler):
                attr_value = self._validate_custom_sampler(
                    input_name, input_value, seed
                )
            else:
                raise AssertionError(
                    f"'{input_name}' must be a tuple, list, int, float, "
                    "or a custom sampler"
                )
            setattr(self, input_name, attr_value)

    def __repr__(self):
        return f"'{self.__class__.__name__}() object'"

    def _validate_tuple(self, input_name, input_value, getattr=getattr):  # pylint: disable=redefined-builtin
        """
        Validate tuple arguments.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : tuple
            Value of the input argument. This is the tuple to be validated.
        getattr : function
            Function used to get the attribute value from the object.

        Returns
        -------
        tuple
            Validated tuple in the format (nominal value, standard deviation,
            distribution function).

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        if len(input_value) not in (2, 3):
            raise AssertionError(f"'{input_name}': tuple must have length 2 or 3")
        if not isinstance(input_value[0], (int, float)):
            raise AssertionError(
                f"'{input_name}': First item of tuple must be an int or float"
            )

        if len(input_value) == 2:
            return self._validate_tuple_length_two(input_name, input_value, getattr)
        return self._validate_tuple_length_three(input_name, input_value, getattr)

    def _validate_tuple_length_two(self, input_name, input_value, getattr=getattr):  # pylint: disable=redefined-builtin
        """
        Validate tuples with length 2.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : tuple
            Value of the input argument.
        getattr : function
            Function to get the attribute value from the object.

        Returns
        -------
        tuple
            Validated tuple in the format (nominal value, standard deviation,
            distribution function).

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        second_item = input_value[1]
        if not isinstance(second_item, (int, float, str)):
            raise AssertionError(
                f"'{input_name}': second item of tuple must be an int, float, "
                "or string."
            )

        if isinstance(second_item, str):
            # if second item is a string, then it is assumed that the first
            # item is the standard deviation, and the second item is the
            # distribution function. In this case, the nominal value will be
            # taken from the object passed.
            dist_func = get_distribution(second_item, self.__random_number_generator)
            return (getattr(self.obj, input_name), input_value[0], dist_func)

        # if second item is an int or float, then it is assumed that the
        # first item is the nominal value and the second item is the
        # standard deviation. The distribution function will be set to
        # "normal".
        return (
            input_value[0],
            second_item,
            get_distribution("normal", self.__random_number_generator),
        )

    def _validate_tuple_length_three(self, input_name, input_value, getattr=getattr):  # pylint: disable=redefined-builtin,unused-argument
        """
        Validate tuples with length 3.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : tuple
            Value of the input argument.
        getattr : function
            Function to get the attribute value from the object. Unused here;
            kept so the signature matches the other tuple validators.

        Returns
        -------
        tuple
            Validated tuple in the format (nominal value, standard deviation,
            distribution function).

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        if not isinstance(input_value[1], (int, float)):
            raise AssertionError(
                f"'{input_name}': Second item of a tuple with length 3 must be an "
                "int or float."
            )
        if not isinstance(input_value[2], str):
            raise AssertionError(
                f"'{input_name}': Third item of tuple must be a string containing the "
                "name of a valid numpy.random distribution function."
            )
        dist_func = get_distribution(input_value[2], self.__random_number_generator)
        return (input_value[0], input_value[1], dist_func)

    def _validate_list(self, input_name, input_value, getattr=getattr):  # pylint: disable=redefined-builtin
        """
        Validate list arguments.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : list
            Value of the input argument.
        getattr : function
            Function to get the attribute value from the object.

        Returns
        -------
        list
            The input list itself when it is non-empty; otherwise a
            single-element list holding the attribute value read from the
            object passed.
        """
        if not input_value:
            return [getattr(self.obj, input_name)]
        return input_value

    def _validate_scalar(self, input_name, input_value, getattr=getattr):  # pylint: disable=redefined-builtin
        """
        Validate scalar arguments. If the input is a scalar, the nominal value
        will be taken from the object passed, and the standard deviation will
        be the scalar value. The distribution function will be set to
        "normal".

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : float
            Value of the input argument.
        getattr : function
            Function to get the attribute value from the object.

        Returns
        -------
        tuple
            Validated tuple in the format (nominal value, standard deviation,
            distribution function).
        """
        return (
            getattr(self.obj, input_name),
            input_value,
            get_distribution("normal", self.__random_number_generator),
        )

    def _validate_factors(self, input_name, input_value, seed):
        """
        Validate factor arguments.

        As a side effect, the attribute of the wrapped object that the factor
        refers to (``input_name`` without ``"_factor"``) is stored on this
        instance under the same name prefixed with an underscore.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : tuple, list or CustomSampler
            Value of the input argument.
        seed : int or None
            Seed forwarded to the custom sampler when ``input_value`` is a
            ``CustomSampler``.

        Returns
        -------
        tuple, list or CustomSampler
            Validated tuple, list or sampler.

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        attribute_name = input_name.replace("_factor", "")
        setattr(self, f"_{attribute_name}", getattr(self.obj, attribute_name))

        if isinstance(input_value, tuple):
            return self._validate_tuple_factor(input_name, input_value)
        if isinstance(input_value, list):
            return self._validate_list_factor(input_name, input_value)
        if isinstance(input_value, CustomSampler):
            return self._validate_custom_sampler(input_name, input_value, seed)
        raise AssertionError(
            f"`{input_name}`: must be either a tuple, a list, or a custom sampler"
        )

    def _validate_tuple_factor(self, input_name, factor_tuple):
        """
        Validate tuple factors.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        factor_tuple : tuple
            Value of the input argument.

        Returns
        -------
        tuple
            Validated tuple in the format (first item, second item,
            distribution function). The distribution defaults to "normal"
            when the tuple has only two items.

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        if len(factor_tuple) not in (2, 3):
            raise AssertionError(
                f"'{input_name}`: Factors tuple must have length 2 or 3"
            )
        if not all(isinstance(item, (int, float)) for item in factor_tuple[:2]):
            raise AssertionError(
                f"'{input_name}`: First and second items of Factors tuple must be "
                "either an int or float"
            )

        if len(factor_tuple) == 2:
            dist_func = get_distribution("normal", self.__random_number_generator)
        else:
            if not isinstance(factor_tuple[2], str):
                raise AssertionError(
                    f"'{input_name}`: Third item of tuple must be a string containing "
                    "the name of a valid numpy.random distribution function"
                )
            dist_func = get_distribution(
                factor_tuple[2], self.__random_number_generator
            )
        return (factor_tuple[0], factor_tuple[1], dist_func)

    def _validate_list_factor(self, input_name, factor_list):
        """
        Validate list factors.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        factor_list : list
            Value of the input argument.

        Returns
        -------
        list
            Validated list (the same object that was passed in).

        Raises
        ------
        AssertionError
            If any item of the list is not an int or float.
        """
        if not all(isinstance(item, (int, float)) for item in factor_list):
            raise AssertionError(
                f"'{input_name}`: Items in list must be either ints or floats"
            )
        return factor_list

    def _validate_1d_array_like(self, input_name, input_value):
        """
        Validate 1D array-like arguments.

        ``None`` is accepted without checks. Otherwise the value must be a
        list whose members are path strings, ``Function`` objects, or lists
        with shape (n, 2).

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : list
            Value of the input argument.

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        if input_value is None:
            return

        error_msg = (
            f"`{input_name}` must be a list of path strings, lists "
            "with shape (n,2), or Functions."
        )
        if not isinstance(input_value, list):
            raise AssertionError(error_msg)

        for member in input_value:
            if isinstance(member, list):
                shape = np.shape(member)
                if len(shape) != 2 or shape[1] != 2:
                    raise AssertionError(error_msg)
            elif not isinstance(member, (str, Function)):
                raise AssertionError(error_msg)

    def _validate_positive_int_list(self, input_name, input_value):
        """
        Validate lists of positive integers.

        ``None`` is accepted without checks. Note that, despite the wording of
        the error message, zero is accepted as well (the check is
        ``member >= 0``).

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        input_value : list
            Value of the input argument.

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        if input_value is None:
            return

        is_valid = isinstance(input_value, list) and all(
            isinstance(member, int) and member >= 0 for member in input_value
        )
        if not is_valid:
            raise AssertionError(f"`{input_name}` must be a list of positive integers")

    def _validate_custom_sampler(self, input_name, sampler, seed=None):
        """
        Validate a custom sampler by resetting its seed.

        Parameters
        ----------
        input_name : str
            Name of the input argument.
        sampler : CustomSampler object
            Custom sampler provided by the user
        seed : int, optional
            Seed for the random number generator. The default is None

        Returns
        -------
        CustomSampler
            The same sampler object, after ``reset_seed(seed)`` was called.

        Raises
        ------
        RuntimeError
            If the sampler's ``reset_seed`` method raises a ``RuntimeError``.
        """
        try:
            sampler.reset_seed(seed)
        except RuntimeError as e:
            raise RuntimeError(
                f"An error occurred in the 'reset_seed' method of {input_name} CustomSampler"
            ) from e

        return sampler

    def _validate_airfoil(self, airfoil):
        """
        Validate airfoil input.

        ``None`` is accepted without checks. Otherwise each member must be a
        tuple of length 2 whose second item is a string and whose first item
        is a list with shape (n, 2), a string, or a callable.

        Parameters
        ----------
        airfoil : list[tuple]
            List of tuples with two items.

        Raises
        ------
        AssertionError
            If the input is not in a valid format.
        """
        # TODO: The _validate_airfoil should be defined in a child class.
        if airfoil is None:
            return

        if not (
            isinstance(airfoil, list)
            and all(isinstance(member, tuple) for member in airfoil)
        ):
            raise AssertionError("`airfoil` must be a list of tuples")

        for member in airfoil:
            if len(member) != 2:
                raise AssertionError("`airfoil` tuples must have length 2")
            if not isinstance(member[1], str):
                raise AssertionError(
                    "`airfoil` tuples must have a string as the second item"
                )
            source = member[0]
            if isinstance(source, list):
                shape = np.shape(source)
                # Reject when EITHER the rank or the column count is wrong.
                # Checking the rank first also guards the shape[1] access.
                if len(shape) != 2 or shape[1] != 2:
                    raise AssertionError("`airfoil` tuples must have shape (n,2)")
            elif not isinstance(source, str) and not callable(source):
                raise AssertionError(
                    "`airfoil` tuples must have a string or Function as "
                    "the first item"
                )

    def dict_generator(self):
        """
        Generate a dictionary with randomly generated input arguments.
        The last generated dictionary is saved as a class attribute called
        `last_rnd_dict`.

        Yields
        ------
        dict
            Dictionary with the randomly generated input arguments.

        Raises
        ------
        RuntimeError
            If the ``sample`` method of a custom sampler raises a
            ``RuntimeError``.

        Notes
        -----
        1. The dictionary is generated by iterating over the class attributes
           and:

           a. If the attribute is a tuple, the value is generated using the
              distribution function specified in the tuple.
           b. If the attribute is a list, the value is randomly chosen from
              the list.
           c. If the attribute is a custom sampler, the value is the first
              element returned by ``sample(n_samples=1)``.
        """
        generated_dict = {}
        for arg, value in self.__dict__.items():
            if isinstance(value, tuple):
                # Last item of a validated tuple is the distribution function.
                dist_sampler = value[-1]
                generated_dict[arg] = dist_sampler(value[0], value[1])
            elif isinstance(value, list):
                # An empty list is passed through unchanged.
                generated_dict[arg] = choice(value) if value else value
            elif isinstance(value, CustomSampler):
                try:
                    generated_dict[arg] = value.sample(n_samples=1)[0]
                except RuntimeError as e:
                    raise RuntimeError(
                        f"An error occurred in the 'sample' method of {arg} CustomSampler"
                    ) from e
        self.last_rnd_dict = generated_dict
        yield generated_dict

    # pylint: disable=too-many-statements
    def visualize_attributes(self):
        """
        This method prints a report of the attributes stored in the Stochastic
        Model object. The report includes the variable name, the nominal
        value, the standard deviation, and the distribution function used to
        generate the random attributes.
        """

        def format_attribute(attr, value):
            """Return the report line for one attribute, or None to skip it."""
            if isinstance(value, list):
                return (
                    f"\t{attr.ljust(max_str_length)} {value[0]}"
                    if len(value) == 1
                    else f"\t{attr} {value}"
                )
            elif isinstance(value, tuple):
                nominal_value, std_dev, dist_func = value
                if callable(dist_func) and dist_func.__name__ == "uniform":
                    # For a uniform distribution the two numbers are shown as
                    # bounds rather than as "nominal ± deviation".
                    lower_bound = nominal_value
                    upper_bound = std_dev
                    return (
                        f"\t{attr.ljust(max_str_length)} "
                        f"{lower_bound:.5f}, {upper_bound:.5f} ({dist_func.__name__})"
                    )
                else:
                    return (
                        f"\t{attr.ljust(max_str_length)} "
                        f"{nominal_value:.5f} ± "
                        f"{std_dev:.5f} ({dist_func.__name__})"
                    )
            elif isinstance(value, CustomSampler):
                sampler_name = type(value).__name__
                return (
                    f"\t{attr.ljust(max_str_length)} "
                    f"\t{sampler_name.ljust(max_str_length)} "
                )
            return None

        # Only public attributes are reported.
        attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
        # default=0 keeps this safe on an instance without public attributes.
        max_str_length = max((len(var) for var in attributes), default=0) + 2

        report = [
            f"Reporting the attributes of the `{self.__class__.__name__}` object:"
        ]

        # Sorting alphabetically makes the report more readable
        items = attributes.items()
        items = sorted(items, key=lambda x: x[0])

        to_exclude = ["object", "last_rnd_dict", "exception_list", "parachutes"]
        items = [item for item in items if item[0] not in to_exclude]

        constant_attributes = [
            attr for attr, val in items if isinstance(val, list) and len(val) == 1
        ]
        tuple_attributes = [attr for attr, val in items if isinstance(val, tuple)]
        list_attributes = [
            attr for attr, val in items if isinstance(val, list) and len(val) > 1
        ]
        custom_attributes = [
            attr for attr, val in items if isinstance(val, CustomSampler)
        ]

        if constant_attributes:
            report.append("\nConstant Attributes:")
            report.extend(
                format_attribute(attr, attributes[attr])
                for attr in constant_attributes
            )

        if tuple_attributes:
            report.append("\nStochastic Attributes:")
            report.extend(
                format_attribute(attr, attributes[attr]) for attr in tuple_attributes
            )

        if list_attributes:
            report.append("\nStochastic Attributes with choice of values:")
            report.extend(
                format_attribute(attr, attributes[attr]) for attr in list_attributes
            )

        if custom_attributes:
            report.append("\nStochastic Attributes with Custom user samplers:")
            report.extend(
                format_attribute(attr, attributes[attr]) for attr in custom_attributes
            )

        print("\n".join(filter(None, report)))