sd_hijack_utils.py 1.3 KB

12345678910111213141516171819202122232425262728
  1. import importlib
  class CondFunc:
      """Conditionally substitute a function, optionally monkey-patching it in place.

      ``CondFunc(orig_func, sub_func, cond_func)`` does not return a ``CondFunc``
      instance: ``__new__`` returns a plain function wrapper instead. Because that
      return value is not an instance of the class, Python does not call
      ``__init__`` implicitly; ``__new__`` calls it explicitly.

      When the wrapper is invoked with ``*args, **kwargs``:

      * if ``cond_func`` is falsy (e.g. ``None``), or
        ``cond_func(orig_func, *args, **kwargs)`` is truthy, the result is
        ``sub_func(orig_func, *args, **kwargs)``;
      * otherwise the result is ``orig_func(*args, **kwargs)``.

      If ``orig_func`` is a dotted string such as ``"pkg.mod.Class.method"``,
      the longest importable module prefix is imported, the remaining middle
      components are resolved with ``getattr``, the final attribute is taken as
      the original function, and that attribute is replaced on its owner with
      the wrapper.

      Raises:
          ImportError: if ``orig_func`` is a string and no non-empty module
              prefix of it can be imported (including a string with no dot).
          AttributeError: if a component after the imported module prefix
              does not exist.
      """

      def __new__(cls, orig_func, sub_func, cond_func):
          self = super().__new__(cls)

          # A plain function (rather than the CondFunc instance itself) is used as
          # the wrapper: functions are descriptors, so when the patched attribute
          # is a method on a class, attribute access through an instance binds that
          # instance and passes it on as the first positional argument.
          def wrapper(*args, **kwargs):
              return self(*args, **kwargs)

          if isinstance(orig_func, str):
              owner, attr_name = cls._resolve_owner(orig_func)
              orig_func = getattr(owner, attr_name)
              setattr(owner, attr_name, wrapper)
          self.__init__(orig_func, sub_func, cond_func)
          return wrapper

      @staticmethod
      def _resolve_owner(path):
          """Split dotted *path* into ``(owner_object, final_attribute_name)``.

          Prefixes of *path* (excluding its last component) are tried from
          longest to shortest with ``importlib.import_module``; the first one
          that imports becomes the starting object, and the remaining middle
          components are looked up with ``getattr``.
          """
          parts = path.split('.')
          last_error = None
          # For 'a.b.c.d' try 'a.b.c', then 'a.b', then 'a'. Stop before the empty
          # prefix: import_module('') raises ValueError, not ImportError.
          for i in range(len(parts) - 1, 0, -1):
              try:
                  owner = importlib.import_module('.'.join(parts[:i]))
                  break
              except ImportError as err:
                  last_error = err
          else:
              raise ImportError(
                  f"cannot import any module prefix of {path!r}"
              ) from last_error
          for name in parts[i:-1]:
              owner = getattr(owner, name)
          return owner, parts[-1]

      def __init__(self, orig_func, sub_func, cond_func):
          """Store the original, substitute, and condition callables."""
          # Double-underscore names are mangled to _CondFunc__*, keeping them
          # private to this class.
          self.__orig_func = orig_func
          self.__sub_func = sub_func
          self.__cond_func = cond_func

      def __call__(self, *args, **kwargs):
          """Call sub_func or orig_func depending on cond_func (see class docstring)."""
          if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
              return self.__sub_func(self.__orig_func, *args, **kwargs)
          return self.__orig_func(*args, **kwargs)