safe.py

# this code is adapted from the script contributed by anon from /h/

import pickle
import collections

import torch
import numpy
import _codecs
import zipfile
import re

# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors

TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
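

# Thin wrapper around _codecs.encode; the restricted unpickler below returns it
# when a pickle references _codecs.encode.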
def encode(*args):
    out = _codecs.encode(*args)
    return out
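

# Unpickler that only resolves a small whitelist of globals commonly found in
# model checkpoints; anything else requested by the pickle raises an exception.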
class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None
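
    # Tensor data is referenced through persistent ids; since this unpickler is
    # only used to verify the pickle, an empty TypedStorage stands in for the
    # actual storage contents.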
    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'

        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res
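
        # Whitelist of globals that ordinary checkpoints are known to reference;
        # anything that does not match one of the checks below is rejected.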
        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
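# Matches the '<directory name>/data.pkl' entry itself, used below to locate the pickle inside the archive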
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")
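

# Verify a checkpoint file by running every pickle it contains through
# RestrictedUnpickler, without loading any tensor data.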
def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()
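

# load() replaces torch.load at the bottom of this module: it verifies the file
# first and only then delegates to the original torch.load (unsafe_torch_load).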
def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None
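

# Keep a reference to the original torch.load and replace it with the safe
# loader; global_extra_handler is set temporarily by the Extra context manager above.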
unsafe_torch_load = torch.load
torch.load = load

global_extra_handler = None
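
# Example usage (a minimal sketch; 'model.ckpt' is a hypothetical path):
#
#     from modules import safe
#
#     state_dict = safe.load('model.ckpt', map_location='cpu')
#
# Because torch.load is patched above, code elsewhere that calls torch.load(...)
# directly goes through the same restricted unpickler.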