hashes.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import hashlib
  2. import os.path
  3. from modules import shared
  4. import modules.cache
  5. dump_cache = modules.cache.dump_cache
  6. cache = modules.cache.cache
  7. def calculate_sha256(filename):
  8. hash_sha256 = hashlib.sha256()
  9. blksize = 1024 * 1024
  10. with open(filename, "rb") as f:
  11. for chunk in iter(lambda: f.read(blksize), b""):
  12. hash_sha256.update(chunk)
  13. return hash_sha256.hexdigest()
  14. def sha256_from_cache(filename, title, use_addnet_hash=False):
  15. hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
  16. ondisk_mtime = os.path.getmtime(filename)
  17. if title not in hashes:
  18. return None
  19. cached_sha256 = hashes[title].get("sha256", None)
  20. cached_mtime = hashes[title].get("mtime", 0)
  21. if ondisk_mtime > cached_mtime or cached_sha256 is None:
  22. return None
  23. return cached_sha256
  24. def sha256(filename, title, use_addnet_hash=False):
  25. hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
  26. sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
  27. if sha256_value is not None:
  28. return sha256_value
  29. if shared.cmd_opts.no_hashing:
  30. return None
  31. print(f"Calculating sha256 for {filename}: ", end='')
  32. if use_addnet_hash:
  33. with open(filename, "rb") as file:
  34. sha256_value = addnet_hash_safetensors(file)
  35. else:
  36. sha256_value = calculate_sha256(filename)
  37. print(f"{sha256_value}")
  38. hashes[title] = {
  39. "mtime": os.path.getmtime(filename),
  40. "sha256": sha256_value,
  41. }
  42. dump_cache()
  43. return sha256_value
  44. def addnet_hash_safetensors(b):
  45. """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
  46. hash_sha256 = hashlib.sha256()
  47. blksize = 1024 * 1024
  48. b.seek(0)
  49. header = b.read(8)
  50. n = int.from_bytes(header, "little")
  51. offset = n + 8
  52. b.seek(offset)
  53. for chunk in iter(lambda: b.read(blksize), b""):
  54. hash_sha256.update(chunk)
  55. return hash_sha256.hexdigest()