Skip to content

base_param

base_param

BaseParam

Source code in federatedml/param/base_param.py
class BaseParam(metaclass=_StaticDefaultMeta):
    """Base class for component parameter objects.

    Parameters are plain instance attributes: ``update`` only accepts keys
    that already exist in the instance ``__dict__``, and attributes holding
    non-builtin objects are treated as nested parameter objects.

    Besides ``update`` this class provides introspection helpers
    (``as_dict``, ``extract_not_builtin``), optional JSON-rule validation
    (``validate``), bookkeeping of user-supplied and deprecated parameters,
    and a family of static ``check_*`` helpers that raise ``ValueError`` on
    invalid values.  Subclasses must implement ``check``.
    """

    def __init__(self):
        pass

    def set_name(self, name: str):
        """Store the component name used in error messages and return ``self``."""
        self._name = name
        return self

    def check(self):
        """Validate the parameters; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Parameter Object should be checked.")

    @classmethod
    def _get_or_init_deprecated_params_set(cls):
        """Return the mutable set of deprecated parameter names for ``cls``.

        The set is looked up in the class's *own* namespace.  ``hasattr``
        would also find a parent's set, and a decorated subclass would then
        add its names to the parent's set as well.  A freshly created set is
        seeded with the inherited names so deprecations declared on a parent
        still apply to the subclass.
        """
        if _DEPRECATED_PARAMS not in vars(cls):
            inherited = getattr(cls, _DEPRECATED_PARAMS, ())
            setattr(cls, _DEPRECATED_PARAMS, set(inherited))
        return getattr(cls, _DEPRECATED_PARAMS)

    def _get_or_init_feeded_deprecated_params_set(self, conf=None):
        """Return (creating on first use) the deprecated params the user supplied.

        When the set does not exist yet and ``conf`` is given, it is
        initialized from ``conf[_FEEDED_DEPRECATED_PARAMS]``; otherwise it
        starts empty.
        """
        if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
            if conf is None:
                setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
            else:
                setattr(
                    self,
                    _FEEDED_DEPRECATED_PARAMS,
                    set(conf[_FEEDED_DEPRECATED_PARAMS]),
                )
        return getattr(self, _FEEDED_DEPRECATED_PARAMS)

    def _get_or_init_user_feeded_params_set(self, conf=None):
        """Return (creating on first use) the set of keys the user supplied.

        When the set does not exist yet and ``conf`` is given, it is
        initialized from ``conf[_USER_FEEDED_PARAMS]``; otherwise it starts
        empty.
        """
        if not hasattr(self, _USER_FEEDED_PARAMS):
            if conf is None:
                setattr(self, _USER_FEEDED_PARAMS, set())
            else:
                setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
        return getattr(self, _USER_FEEDED_PARAMS)

    def get_user_feeded(self):
        """Return the set of (dotted) parameter keys the user supplied."""
        return self._get_or_init_user_feeded_params_set()

    def get_feeded_deprecated_params(self):
        """Return the set of deprecated parameter keys the user supplied."""
        return self._get_or_init_feeded_deprecated_params_set()

    @property
    def _deprecated_params_set(self):
        """Mapping ``name -> True`` of the deprecated params the user supplied."""
        return {name: True for name in self.get_feeded_deprecated_params()}

    def as_dict(self):
        """Convert the object into a (nested) plain dict.

        Truthy attributes whose type name is not a builtin are converted
        recursively; every other attribute value is copied as-is.
        """
        builtin_names = dir(builtins)

        def _recursive_convert_obj_to_dict(obj):
            ret_dict = {}
            for attr_name in list(obj.__dict__):
                attr = getattr(obj, attr_name)
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                else:
                    ret_dict[attr_name] = attr

            return ret_dict

        return _recursive_convert_obj_to_dict(self)

    def update(self, conf, allow_redundant=False):
        """Recursively copy values from ``conf`` onto this object and return it.

        Args:
            conf: mapping of attribute name -> value.  A value whose target
                attribute currently holds a non-builtin object (a nested
                parameter object) must itself be a mapping and is applied
                recursively.
            allow_redundant: if False, keys that are not existing attributes
                of the target object raise ``ValueError``.

        When ``conf[_IS_RAW_CONF]`` is missing or truthy, every applied key
        (dotted for nested params) is recorded in the user-feeded set, and
        keys registered via ``deprecated_param`` are also recorded in the
        feeded-deprecated set.  Otherwise both sets are initialized from
        ``conf`` (if not already present) and unknown keys starting with
        ``_`` are copied onto the object instead of being reported as
        redundant.

        Raises:
            ValueError: nesting deeper than ``consts.PARAM_MAXDEPTH`` or,
                when ``allow_redundant`` is False, unknown keys.
        """
        update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
        if update_from_raw_conf:
            deprecated_params_set = self._get_or_init_deprecated_params_set()
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set()
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set()
            setattr(self, _IS_RAW_CONF, False)
        else:
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set(conf)
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

        def _recursive_update_param(param, config, depth, prefix):
            if depth > consts.PARAM_MAXDEPTH:
                raise ValueError("Param define nesting too deep!!!, can not parse it")

            inst_variables = param.__dict__
            redundant_attrs = []
            for config_key, config_value in config.items():
                # Keys that are not attributes of the target object.
                if config_key not in inst_variables:
                    if not update_from_raw_conf and config_key.startswith("_"):
                        setattr(param, config_key, config_value)
                    else:
                        redundant_attrs.append(config_key)
                    continue

                full_config_key = f"{prefix}{config_key}"

                if update_from_raw_conf:
                    user_feeded_params_set.add(full_config_key)
                    if full_config_key in deprecated_params_set:
                        feeded_deprecated_params_set.add(full_config_key)

                attr = getattr(param, config_key)
                if type(attr).__name__ in dir(builtins) or attr is None:
                    setattr(param, config_key, config_value)
                else:
                    # Non-builtin attribute: a nested param object updated in place.
                    sub_params = _recursive_update_param(
                        attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                    )
                    setattr(param, config_key, sub_params)

            if not allow_redundant and redundant_attrs:
                raise ValueError(
                    f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{redundant_attrs}`"
                )

            return param

        return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

    def extract_not_builtin(self):
        """Return a nested dict skeleton of the non-builtin attributes.

        Only truthy attributes whose type name is not a builtin are kept; each
        maps to the same skeleton of its own attributes.
        """
        builtin_names = dir(builtins)

        def _get_not_builtin_types(obj):
            ret_dict = {}
            for variable in obj.__dict__:
                attr = getattr(obj, variable)
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[variable] = _get_not_builtin_types(attr)

            return ret_dict

        return _get_not_builtin_types(self)

    def validate(self):
        """Validate attribute values against ``param_validation/<ClassName>.json``.

        The rule file is looked up in a ``param_validation`` directory next to
        this module.  Validation is best-effort: if the file is missing,
        unreadable or not valid JSON, the method returns without checking.

        Raises:
            ValueError: an attribute value satisfies none of its rules.
        """
        self.builtin_types = dir(builtins)
        self.func = {
            "ge": self._greater_equal_than,
            "le": self._less_equal_than,
            "in": self._in,
            "not_in": self._not_in,
            "range": self._range,
        }
        home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
        param_validation_path = os.path.join(
            home_dir, "param_validation", type(self).__name__ + ".json"
        )

        try:
            with open(param_validation_path, "r") as fin:
                validation_json = json.loads(fin.read())
        except (OSError, ValueError):
            # No readable / parsable rule file: nothing to validate.
            return

        self._validate_param(self, validation_json)

    def _validate_param(self, param_obj, validation_json):
        """Recursively check ``param_obj`` against the rules in ``validation_json``.

        ``validation_json`` maps a class name to ``{attribute: {op: arg}}``,
        where ``op`` is a key of ``self.func``.  A value is legal if at least
        one of its ops accepts it.  Nested (non-builtin) attributes are
        checked against the section named after their own class.
        """
        # Rules are grouped under the class name, so membership is tested there.
        section_rules = validation_json.get(type(param_obj).__name__, {})

        for variable in param_obj.__dict__:
            attr = getattr(param_obj, variable)

            if type(attr).__name__ in self.builtin_types or attr is None:
                if variable not in section_rules:
                    continue

                validation_dict = section_rules[variable]
                value_legal = any(
                    self.func[op_type](attr, validation_dict[op_type])
                    for op_type in validation_dict
                )

                if not value_legal:
                    raise ValueError(
                        "Please check runtime conf, {} = {} does not match user-parameter restriction".format(
                            variable, attr
                        )
                    )

            elif type(attr).__name__ in validation_json:
                self._validate_param(attr, validation_json)

    @staticmethod
    def check_string(param, descr):
        """Raise ValueError unless ``param`` is exactly a ``str``."""
        if type(param).__name__ not in ["str"]:
            raise ValueError(
                descr + " {} not supported, should be string type".format(param)
            )

    @staticmethod
    def check_positive_integer(param, descr):
        """Raise ValueError unless ``param`` is an ``int`` greater than 0."""
        if type(param).__name__ not in ["int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive integer".format(param)
            )

    @staticmethod
    def check_positive_number(param, descr):
        """Raise ValueError unless ``param`` is an ``int``/``float`` greater than 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive numeric".format(param)
            )

    @staticmethod
    def check_nonnegative_number(param, descr):
        """Raise ValueError unless ``param`` is an ``int``/``float`` >= 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param < 0:
            raise ValueError(
                descr
                + " {} not supported, should be non-negative numeric".format(param)
            )

    @staticmethod
    def check_decimal_float(param, descr):
        """Raise ValueError unless ``param`` is an ``int``/``float`` in [0, 1]."""
        if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
            raise ValueError(
                descr
                + " {} not supported, should be a float number in range [0, 1]".format(
                    param
                )
            )

    @staticmethod
    def check_boolean(param, descr):
        """Raise ValueError unless ``param`` is exactly a ``bool``."""
        if type(param).__name__ != "bool":
            raise ValueError(
                descr + " {} not supported, should be bool type".format(param)
            )

    @staticmethod
    def check_open_unit_interval(param, descr):
        """Raise ValueError unless ``param`` is a ``float`` strictly in (0, 1)."""
        if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
            raise ValueError(
                descr + " should be a numeric number between 0 and 1 exclusively"
            )

    @staticmethod
    def check_valid_value(param, descr, valid_values):
        """Raise ValueError unless ``param`` is a member of ``valid_values``."""
        if param not in valid_values:
            raise ValueError(
                descr
                + " {} is not supported, it should be in {}".format(param, valid_values)
            )

    @staticmethod
    def check_defined_type(param, descr, types):
        """Raise ValueError unless the type name of ``param`` is in ``types``."""
        if type(param).__name__ not in types:
            raise ValueError(
                descr + " {} not supported, should be one of {}".format(param, types)
            )

    @staticmethod
    def check_and_change_lower(param, valid_list, descr=""):
        """Return ``param.lower()`` if it is a ``str`` whose lowercase is in ``valid_list``.

        Raises:
            ValueError: ``param`` is not a ``str`` or its lowercase form is
                not in ``valid_list``.
        """
        if type(param).__name__ != "str":
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

        lower_param = param.lower()
        if lower_param in valid_list:
            return lower_param
        else:
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

    @staticmethod
    def _greater_equal_than(value, limit):
        """Return whether ``value >= limit`` within ``consts.FLOAT_ZERO`` tolerance."""
        return value >= limit - consts.FLOAT_ZERO

    @staticmethod
    def _less_equal_than(value, limit):
        """Return whether ``value <= limit`` within ``consts.FLOAT_ZERO`` tolerance."""
        return value <= limit + consts.FLOAT_ZERO

    @staticmethod
    def _range(value, ranges):
        """Return whether ``value`` lies in any ``(left, right)`` pair of ``ranges``.

        Both ends are inclusive, with ``consts.FLOAT_ZERO`` tolerance.
        """
        in_range = False
        for left_limit, right_limit in ranges:
            if (
                left_limit - consts.FLOAT_ZERO
                <= value
                <= right_limit + consts.FLOAT_ZERO
            ):
                in_range = True
                break

        return in_range

    @staticmethod
    def _in(value, right_value_list):
        """Return whether ``value`` is in ``right_value_list``."""
        return value in right_value_list

    @staticmethod
    def _not_in(value, wrong_value_list):
        """Return whether ``value`` is not in ``wrong_value_list``."""
        return value not in wrong_value_list

    def _warn_deprecated_param(self, param_name, descr):
        """Log a warning if the user supplied the deprecated ``param_name``."""
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} is deprecated and ignored in this version."
            )

    def _warn_to_deprecate_param(self, param_name, descr, new_param):
        """Warn that ``param_name`` will be replaced by ``new_param``.

        Returns:
            True if the user supplied the deprecated ``param_name`` (and the
            warning was logged), False otherwise.
        """
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} will be deprecated in future release; "
                f"please use {new_param} instead."
            )
            return True
        return False
__init__(self) special
Source code in federatedml/param/base_param.py
def __init__(self):
    """Initialize nothing: the base class defines no parameters of its own."""
set_name(self, name)
Source code in federatedml/param/base_param.py
def set_name(self, name: str):
    """Remember ``name`` as ``self._name`` and hand back ``self`` for chaining."""
    setattr(self, "_name", name)
    return self
check(self)
Source code in federatedml/param/base_param.py
def check(self):
    """Abstract hook: concrete parameter classes must supply their own check."""
    reason = "Parameter Object should be checked."
    raise NotImplementedError(reason)
get_user_feeded(self)
Source code in federatedml/param/base_param.py
def get_user_feeded(self):
    """Return the set of parameter keys the user supplied (created on first use)."""
    user_feeded = self._get_or_init_user_feeded_params_set()
    return user_feeded
get_feeded_deprecated_params(self)
Source code in federatedml/param/base_param.py
def get_feeded_deprecated_params(self):
    """Return the set of deprecated parameter keys the user supplied."""
    feeded = self._get_or_init_feeded_deprecated_params_set()
    return feeded
as_dict(self)
Source code in federatedml/param/base_param.py
def as_dict(self):
    """Convert the object into a (nested) plain dict.

    Truthy attributes whose type name is not a builtin are converted
    recursively; every other attribute value is copied as-is.
    """
    builtin_names = dir(builtins)

    def _to_plain(node):
        plain = {}
        for key in list(node.__dict__):
            value = getattr(node, key)
            nested = bool(value) and type(value).__name__ not in builtin_names
            plain[key] = _to_plain(value) if nested else value
        return plain

    return _to_plain(self)
update(self, conf, allow_redundant=False)
Source code in federatedml/param/base_param.py
def update(self, conf, allow_redundant=False):
    """Recursively copy values from ``conf`` onto this object and return it.

    Args:
        conf: mapping of attribute name -> value.  A value whose target
            attribute currently holds a non-builtin object (a nested
            parameter object) must itself be a mapping and is applied
            recursively.
        allow_redundant: if False, keys that are not existing attributes of
            the target object raise ``ValueError``.

    When ``conf[_IS_RAW_CONF]`` is missing or truthy, every applied key
    (dotted for nested params) is recorded in the user-feeded set, and keys
    registered via ``deprecated_param`` are also recorded in the
    feeded-deprecated set.  Otherwise both sets are initialized from ``conf``
    (if not already present) and unknown keys starting with ``_`` are copied
    onto the object instead of being reported as redundant.

    Raises:
        ValueError: nesting deeper than ``consts.PARAM_MAXDEPTH`` or, when
            ``allow_redundant`` is False, unknown keys.
    """
    update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
    if update_from_raw_conf:
        deprecated_params_set = self._get_or_init_deprecated_params_set()
        feeded_deprecated_params_set = (
            self._get_or_init_feeded_deprecated_params_set()
        )
        user_feeded_params_set = self._get_or_init_user_feeded_params_set()
        setattr(self, _IS_RAW_CONF, False)
    else:
        feeded_deprecated_params_set = (
            self._get_or_init_feeded_deprecated_params_set(conf)
        )
        user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

    def _recursive_update_param(param, config, depth, prefix):
        if depth > consts.PARAM_MAXDEPTH:
            raise ValueError("Param define nesting too deep!!!, can not parse it")

        inst_variables = param.__dict__
        redundant_attrs = []
        for config_key, config_value in config.items():
            # Keys that are not attributes of the target object.
            if config_key not in inst_variables:
                if not update_from_raw_conf and config_key.startswith("_"):
                    setattr(param, config_key, config_value)
                else:
                    redundant_attrs.append(config_key)
                continue

            full_config_key = f"{prefix}{config_key}"

            if update_from_raw_conf:
                user_feeded_params_set.add(full_config_key)
                if full_config_key in deprecated_params_set:
                    feeded_deprecated_params_set.add(full_config_key)

            attr = getattr(param, config_key)
            if type(attr).__name__ in dir(builtins) or attr is None:
                setattr(param, config_key, config_value)
            else:
                # Non-builtin attribute: a nested param object updated in place.
                sub_params = _recursive_update_param(
                    attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                )
                setattr(param, config_key, sub_params)

        if not allow_redundant and redundant_attrs:
            raise ValueError(
                f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{redundant_attrs}`"
            )

        return param

    return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
extract_not_builtin(self)
Source code in federatedml/param/base_param.py
def extract_not_builtin(self):
    """Return a nested dict skeleton of the non-builtin attributes.

    Only truthy attributes whose type name is not a builtin are kept; each
    maps to the same skeleton of its own attributes, so builtin leaves are
    dropped entirely.
    """
    builtin_names = dir(builtins)

    def _skeleton(node):
        nested = {}
        for name in node.__dict__:
            value = getattr(node, name)
            if not value or type(value).__name__ in builtin_names:
                continue
            nested[name] = _skeleton(value)
        return nested

    return _skeleton(self)
validate(self)
Source code in federatedml/param/base_param.py
def validate(self):
    """Validate attribute values against ``param_validation/<ClassName>.json``.

    The rule file is looked up in a ``param_validation`` directory next to
    this module.  Validation is best-effort: if the file is missing,
    unreadable or not valid JSON, the method returns without checking.

    Raises:
        ValueError: an attribute value satisfies none of its rules.
    """
    self.builtin_types = dir(builtins)
    self.func = {
        "ge": self._greater_equal_than,
        "le": self._less_equal_than,
        "in": self._in,
        "not_in": self._not_in,
        "range": self._range,
    }
    home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
    param_validation_path = os.path.join(
        home_dir, "param_validation", type(self).__name__ + ".json"
    )

    try:
        with open(param_validation_path, "r") as fin:
            validation_json = json.loads(fin.read())
    except (OSError, ValueError):
        # No readable / parsable rule file: nothing to validate.  Catching
        # only I/O and decode errors lets e.g. KeyboardInterrupt propagate.
        return

    self._validate_param(self, validation_json)
check_string(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_string(param, descr):
    """Raise ValueError unless ``param`` is exactly a ``str``."""
    if type(param).__name__ == "str":
        return
    raise ValueError(descr + " {} not supported, should be string type".format(param))
check_positive_integer(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_positive_integer(param, descr):
    """Raise ValueError unless ``param`` is an ``int`` greater than zero."""
    is_integer = type(param).__name__ in ("int", "long")
    if is_integer and not param <= 0:
        return
    raise ValueError(descr + " {} not supported, should be positive integer".format(param))
check_positive_number(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_positive_number(param, descr):
    """Raise ValueError unless ``param`` is an ``int``/``float`` greater than zero."""
    is_numeric = type(param).__name__ in ("float", "int", "long")
    if is_numeric and not param <= 0:
        return
    raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
check_nonnegative_number(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_nonnegative_number(param, descr):
    """Raise ValueError unless ``param`` is an ``int``/``float`` that is >= 0."""
    is_numeric = type(param).__name__ in ("float", "int", "long")
    if is_numeric and not param < 0:
        return
    raise ValueError(
        descr + " {} not supported, should be non-negative numeric".format(param)
    )
check_decimal_float(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_decimal_float(param, descr):
    """Raise ValueError unless ``param`` is an ``int``/``float`` within [0, 1]."""
    is_numeric = type(param).__name__ in ("float", "int")
    if is_numeric and not (param < 0 or param > 1):
        return
    raise ValueError(
        descr
        + " {} not supported, should be a float number in range [0, 1]".format(param)
    )
check_boolean(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_boolean(param, descr):
    """Raise ValueError unless ``param`` is exactly a ``bool``."""
    if type(param).__name__ == "bool":
        return
    raise ValueError(descr + " {} not supported, should be bool type".format(param))
check_open_unit_interval(param, descr) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_open_unit_interval(param, descr):
    """Raise ValueError unless ``param`` is a ``float`` strictly between 0 and 1."""
    is_float = type(param).__name__ in ("float",)
    if is_float and not (param <= 0 or param >= 1):
        return
    raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
check_valid_value(param, descr, valid_values) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_valid_value(param, descr, valid_values):
    """Raise ValueError unless ``param`` is a member of ``valid_values``."""
    if param in valid_values:
        return
    raise ValueError(
        descr + " {} is not supported, it should be in {}".format(param, valid_values)
    )
check_defined_type(param, descr, types) staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_defined_type(param, descr, types):
    """Raise ValueError unless the type name of ``param`` appears in ``types``."""
    if type(param).__name__ in types:
        return
    raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
check_and_change_lower(param, valid_list, descr='') staticmethod
Source code in federatedml/param/base_param.py
@staticmethod
def check_and_change_lower(param, valid_list, descr=""):
    """Return the lowercase form of a ``str`` ``param`` if it is in ``valid_list``.

    Raises:
        ValueError: ``param`` is not a ``str``, or its lowercase form is not
            in ``valid_list``.
    """
    if type(param).__name__ == "str":
        lowered = param.lower()
        if lowered in valid_list:
            return lowered
    raise ValueError(
        descr + " {} not supported, should be one of {}".format(param, valid_list)
    )

deprecated_param(*names)

Source code in federatedml/param/base_param.py
def deprecated_param(*names):
    """Build a class decorator that registers ``names`` as deprecated parameters.

    The decorator adds every name to the class's deprecated-parameter set
    (obtained via ``_get_or_init_deprecated_params_set``) and returns the
    class unchanged otherwise.
    """

    def _register(cls: "BaseParam"):
        registry = cls._get_or_init_deprecated_params_set()
        registry.update(names)
        return cls

    return _register

Last update: 2022-07-07