class ModelParam:
    """Split a model's parameters into fixed (valued) and free (named) sets.

    Fixed parameters carry a value. Free parameters are only named; their
    values are supplied later, positionally, through ``param_list_to_dict``.
    ``free`` and ``fix`` return new instances and never modify the receiver.
    A name is never both fixed and free at the same time.

    Example::

        p0 = ModelParam({'a': 1, 'b': 2})
        p1 = p0.free('a', 'b').fix(b=1, c=2)
        # p1.names_free == ['a']
        # p1.param_fixed == {'b': 1, 'c': 2}
    """

    def __init__(self, param_fixed, names_free=None):
        """
        param_fixed: dict mapping parameter name -> fixed value (copied, the
            caller's dict is not modified).
        names_free: iterable of free parameter names; its order is the
            positional order used by ``param_list_to_dict``. Any of these
            names that also appear in ``param_fixed`` are dropped from the
            fixed set. Defaults to no free names.
        """
        # Materialise once: the original iterated this twice, which broke
        # when a generator was passed.
        self.names_free = [] if names_free is None else list(names_free)
        free_set = set(self.names_free)
        self.param_fixed = {
            name: value
            for name, value in param_fixed.items()
            if name not in free_set
        }

    def __repr__(self):
        return f'{type(self).__name__}({self.param_fixed!r}, {self.names_free!r})'

    def free(self, *names_free):
        """Return a new ModelParam with ``names_free`` appended to the free names.

        Raises:
            ValueError: if a name is already free, or is repeated in this call.
        """
        repeated_in_call = len(set(names_free)) != len(names_free)
        if repeated_in_call or set(self.names_free).intersection(names_free):
            raise ValueError('duplicated params')
        return ModelParam(self.param_fixed, [*self.names_free, *names_free])

    def fix(self, **param_fixed):
        """Return a new ModelParam with the given name=value pairs fixed.

        Names that are currently free are removed from the free list.
        Otherwise the constructor would drop the new value again, since a
        free name always wins over a fixed one.
        """
        names_free = [name for name in self.names_free if name not in param_fixed]
        return ModelParam({**self.param_fixed, **param_fixed}, names_free)

    def param_list_to_dict(self, param_list):
        """Merge positional free values with the fixed params into one dict.

        param_list: values for the free params, in ``names_free`` order.

        Raises:
            ValueError: if the number of values differs from the number of
                free names (``zip`` would otherwise truncate silently).
        """
        values = list(param_list)
        if len(values) != len(self.names_free):
            raise ValueError(
                f'expected {len(self.names_free)} free param values, got {len(values)}'
            )
        return {**self.param_fixed, **dict(zip(self.names_free, values))}