# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import functools
import threading

import torch


def recursive_fn_factory(fn):
    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Writing out an explicit allow-list of trivial types is tedious,
        # but so are the bugs that come from silently skipping values that
        # fn was expected to be applied to.
        if b is None:
            return b
        trivial_types = [bool, int]
        for t in trivial_types:
            if isinstance(b, t):
                return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn


recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
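
# A minimal usage sketch (illustrative, not part of the original module):
# recursive_fn_factory returns a function that walks dicts, lists, and tuples,
# applies `fn` to every tensor leaf, and passes None/bool/int through unchanged.
#
#   batch = {"x": torch.randn(4, 3).t(), "mask": None, "step": 7}
#   out = recursive_contiguous(batch)
#   assert out["x"].is_contiguous()
#   assert out["mask"] is None and out["step"] == 7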


def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    # Compile once up front; the wrapper below only adds a profiler range and
    # input/output normalization around the compiled callable.
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    @functools.wraps(fn)
    def compiled_fn_wrapper(*args, **kwargs):
        with torch.autograd.profiler.record_function(
            f"compiled {fn}" if name is None else name
        ):
            # Contiguous inputs help avoid stride-specialized recompiles.
            cont_args = recursive_contiguous(args)
            cont_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*cont_args, **cont_kwargs)
            # Clone outputs so callers never alias buffers that the compiled
            # function (e.g. under CUDA graphs) may reuse on the next call.
            cloned_result = recursive_clone(result)
            return cloned_result

    return compiled_fn_wrapper
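
# A minimal usage sketch (the `add_mul` function here is hypothetical, not part
# of the original module): compile_wrapper compiles once and returns a drop-in
# replacement whose calls show up as a single named range in profiler traces.
#
#   def add_mul(a, b):
#       return a + b, a * b
#
#   fast_add_mul = compile_wrapper(add_mul, name="add_mul")
#   s, p = fast_add_mul(torch.randn(8), torch.randn(8))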


def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.
    Only prints when a new combination of shapes is seen.
    Thread-safe.

    Args:
        fn: Function to wrap
        keep_kwargs: Names of keyword arguments whose shapes should be
            included in the logged shape key
        enable_logging: Boolean flag to enable/disable logging
    """
    seen_shapes = set()
    # Guards seen_shapes so the "new shapes" check-and-insert is atomic,
    # backing the thread-safety claim above.
    seen_lock = threading.Lock()

    def get_shape(obj):
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            # Collapse single-element sequences to the element's shape; empty
            # sequences fall through to the generic tuple case below instead
            # of raising IndexError.
            if len(obj) == 1:
                return get_shape(obj[0])
            return tuple(get_shape(x) for x in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            return type(obj).__name__

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list)) and k in keep_kwargs
        )
        with seen_lock:
            is_new = shapes not in seen_shapes
            if is_new:
                seen_shapes.add(shapes)
        if is_new and enable_logging:
            print(f"[ShapeLogger] New input shapes for {fn.__qualname__}: {shapes}")
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime.
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging
    return wrapper
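
# A minimal usage sketch (the `scale` function here is hypothetical, not part
# of the original module): only the first occurrence of a shape combination is
# logged, and logging can be toggled later via set_logging.
#
#   def scale(x, factor=None):
#       return x * factor
#
#   logged_scale = shape_logging_wrapper(
#       scale, keep_kwargs=("factor",), enable_logging=True
#   )
#   logged_scale(torch.ones(2, 2), factor=torch.tensor(2.0))  # logs new shapes
#   logged_scale(torch.ones(2, 2), factor=torch.tensor(3.0))  # same shapes: silent
#   logged_scale.set_logging(False)  # disable logging at runtime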