compile.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import torch


def recursive_fn_factory(fn):
    """Return a function that applies fn to every tensor leaf of a nested structure."""

    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Yes, writing out an explicit whitelist of trivial types is tedious,
        # but so are the bugs that come from not applying fn where it was
        # expected to be applied.
        if b is None:
            return b
        trivial_types = [bool, int]
        for t in trivial_types:
            if isinstance(b, t):
                return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn


recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
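
# Example (illustrative sketch, not part of the original module): the recursive
# helpers walk nested dict/list/tuple structures, apply the wrapped op to every
# tensor leaf, and pass None/bool/int values through unchanged.
#
#     batch = {"x": torch.randn(2, 3).t(), "extras": [torch.ones(2), None, 3]}
#     contiguous_batch = recursive_contiguous(batch)
#     assert contiguous_batch["x"].is_contiguous()
#     cloned_batch = recursive_clone(batch)
#     assert cloned_batch["x"] is not batch["x"]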


def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    # Compile once up front; every call below reuses the compiled callable.
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    def compiled_fn_wrapper(*args, **kwargs):
        with torch.autograd.profiler.record_function(
            f"compiled {fn}" if name is None else name
        ):
            # Make every tensor input contiguous before entering the compiled
            # region, and clone every tensor output on the way out.
            cont_args = recursive_contiguous(args)
            cont_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*cont_args, **cont_kwargs)
            cloned_result = recursive_clone(result)
            return cloned_result

    return compiled_fn_wrapper
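
# Example (illustrative sketch; mlp_step and its shapes are made-up names):
# compile_wrapper compiles fn once with torch.compile, then returns a wrapper
# that makes all tensor inputs contiguous before the call and clones all
# tensor outputs afterwards, inside a profiler record_function range.
#
#     def mlp_step(x, w):
#         return torch.relu(x @ w)
#
#     compiled_step = compile_wrapper(mlp_step, name="mlp_step")
#     out = compiled_step(torch.randn(8, 16), torch.randn(16, 4))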


def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.
    Only prints when a new combination of shapes is seen.
    Thread-safe.

    Args:
        fn: Function to wrap
        keep_kwargs: Names of keyword arguments whose shapes should be tracked
        enable_logging: Boolean flag to enable/disable logging
    """
    seen_shapes = set()

    def get_shape(obj):
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            # Unwrap single-element sequences; empty and multi-element
            # sequences map to a tuple of their element shapes.
            if len(obj) == 1:
                return get_shape(obj[0])
            return tuple(get_shape(x) for x in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list)) and k in keep_kwargs
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            if enable_logging:
                print(
                    f"[ShapeLogger] New input shapes for "
                    f"{fn.__qualname__}: {shapes}"
                )
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging

    return wrapper
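
# Example (illustrative sketch; mlp_step, its shapes, and the "mask" kwarg are
# made-up names): the wrapper logs a line the first time it sees a new
# combination of input shapes, and logging can be toggled at runtime via
# set_logging.
#
#     def mlp_step(x, w, mask=None):
#         return torch.relu(x @ w)
#
#     logged_step = shape_logging_wrapper(
#         mlp_step, keep_kwargs=("mask",), enable_logging=True
#     )
#     logged_step(torch.randn(8, 16), torch.randn(16, 4))  # logs the new shapes
#     logged_step(torch.randn(8, 16), torch.randn(16, 4))  # same shapes, no log
#     logged_step.set_logging(False)                       # silence future logs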