syrte
11/15/2018 - 4:55 PM

ModelParam.py

class ModelParam:
    """Split a model's parameters into fixed values and free (fitted) names.

    Every parameter name is either fixed (stored with its value in
    ``param_fixed``) or free (listed, in order, in ``names_free``), never
    both. ``free`` and ``fix`` return new instances and leave ``self``
    unchanged.

    Example::

        p0 = ModelParam({'a': 1, 'b': 2})
        p1 = p0.free('a', 'b').fix(b=1, c=2)
        # p1.names_free == ['a']; p1.param_fixed == {'b': 1, 'c': 2}
        p1.param_list_to_dict([10])  # {'b': 1, 'c': 2, 'a': 10}
    """

    def __init__(self, param_fixed, names_free=()):
        """
        param_fixed: dict
            Mapping of parameter name -> fixed value. Copied, not aliased.
        names_free: iterable of names
            Free parameter names, in the order expected by
            ``param_list_to_dict``. Any of them present in ``param_fixed``
            are dropped from the fixed set.
        """
        # Copy first so a one-shot iterable (e.g. a generator) is not
        # exhausted before it is stored; the tuple default avoids the
        # shared-mutable-default pitfall.
        self.names_free = [*names_free]
        self.param_fixed = {**param_fixed}
        for name in self.names_free:
            self.param_fixed.pop(name, None)

    def __repr__(self):
        return '{}({!r}, {!r})'.format(
            type(self).__name__, self.param_fixed, self.names_free)

    def free(self, *names_free):
        """Return a copy with ``names_free`` moved from fixed to free.

        New free names are appended after the existing ones, so the value
        order expected by ``param_list_to_dict`` is extended, not reshuffled.

        Raises ValueError if a name is already free or is given twice.
        """
        seen = set(self.names_free)
        for name in names_free:
            if name in seen:
                raise ValueError('duplicated params')
            seen.add(name)
        return ModelParam(self.param_fixed, [*self.names_free, *names_free])

    def fix(self, **param_fixed):
        """Return a copy with the given parameters fixed to the given values.

        Names that were free become fixed (they are removed from
        ``names_free``); previously fixed values are overridden.
        """
        # Keep the fixed/free partition disjoint; otherwise a later
        # param_list_to_dict would overwrite the value just fixed here.
        names_free = [
            name for name in self.names_free if name not in param_fixed
        ]
        return ModelParam({**self.param_fixed, **param_fixed}, names_free)

    def param_list_to_dict(self, param_list):
        """Merge free-parameter values with the fixed ones into one dict.

        param_list: iterable
            Values for the free parameters, in ``names_free`` order.

        Returns a new dict of all parameters (fixed and free).

        Raises ValueError if the number of values differs from the number of
        free parameters (``zip`` would otherwise truncate silently).
        """
        values = list(param_list)
        if len(values) != len(self.names_free):
            raise ValueError(
                'expected {} free param values, got {}'.format(
                    len(self.names_free), len(values)))
        param_free = dict(zip(self.names_free, values))
        return {**self.param_fixed, **param_free}