12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- from torch.utils.checkpoint import checkpoint
- import ldm.modules.attention
- import ldm.modules.diffusionmodules.openaimodel
def BasicTransformerBlock_forward(self, x, context=None):
    """Replacement ``forward`` that runs ``self._forward`` under checkpointing.

    Assigned to ``ldm.modules.attention.BasicTransformerBlock.forward`` by
    ``add()``.

    Args:
        self: The block instance; must provide ``_forward(x, context)``.
        x: Input passed through to ``self._forward``.
        context: Optional second argument passed through to ``self._forward``.

    Returns:
        Whatever ``self._forward(x, context)`` returns, routed through
        ``torch.utils.checkpoint.checkpoint``.
    """
    # use_reentrant is passed explicitly: PyTorch documents that it should
    # not be left to the implicit default, and 2.x releases warn when it is
    # omitted. True is the variant the implicit default selected, so the
    # behaviour is unchanged.
    return checkpoint(self._forward, x, context, use_reentrant=True)
def AttentionBlock_forward(self, x):
    """Replacement ``forward`` that runs ``self._forward`` under checkpointing.

    Assigned to
    ``ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward`` by
    ``add()``.

    Args:
        self: The block instance; must provide ``_forward(x)``.
        x: Input passed through to ``self._forward``.

    Returns:
        Whatever ``self._forward(x)`` returns, routed through
        ``torch.utils.checkpoint.checkpoint``.
    """
    # use_reentrant is passed explicitly: PyTorch documents that it should
    # not be left to the implicit default, and 2.x releases warn when it is
    # omitted. True is the variant the implicit default selected, so the
    # behaviour is unchanged.
    return checkpoint(self._forward, x, use_reentrant=True)
def ResBlock_forward(self, x, emb):
    """Replacement ``forward`` that runs ``self._forward`` under checkpointing.

    Assigned to ``ldm.modules.diffusionmodules.openaimodel.ResBlock.forward``
    by ``add()``.

    Args:
        self: The block instance; must provide ``_forward(x, emb)``.
        x: First input passed through to ``self._forward``.
        emb: Second input passed through to ``self._forward``.

    Returns:
        Whatever ``self._forward(x, emb)`` returns, routed through
        ``torch.utils.checkpoint.checkpoint``.
    """
    # use_reentrant is passed explicitly: PyTorch documents that it should
    # not be left to the implicit default, and 2.x releases warn when it is
    # omitted. True is the variant the implicit default selected, so the
    # behaviour is unchanged.
    return checkpoint(self._forward, x, emb, use_reentrant=True)
# Original ``forward`` attributes saved by add() so remove() can restore them.
# Order: BasicTransformerBlock, ResBlock, AttentionBlock. Empty means the
# patch is not currently installed.
stored = []
def add():
    """Install the checkpointing ``forward`` replacements on the ldm blocks.

    Does nothing when ``stored`` is non-empty, i.e. a patch is already in
    place. Otherwise the current ``forward`` attributes of
    ``BasicTransformerBlock``, ``ResBlock`` and ``AttentionBlock`` are
    recorded in ``stored`` (in that order) and then overwritten with the
    replacement functions defined in this module.
    """
    if stored:
        return

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # Capture all three originals before touching any class, so a failed
    # lookup leaves both ``stored`` and the classes unmodified.
    originals = [
        attention.BasicTransformerBlock.forward,
        openaimodel.ResBlock.forward,
        openaimodel.AttentionBlock.forward,
    ]
    stored.extend(originals)

    attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    openaimodel.ResBlock.forward = ResBlock_forward
    openaimodel.AttentionBlock.forward = AttentionBlock_forward
def remove():
    """Restore the ``forward`` attributes saved by ``add()``.

    Does nothing when ``stored`` is empty, i.e. no patch is installed.
    Otherwise the three saved callables are written back to
    ``BasicTransformerBlock``, ``ResBlock`` and ``AttentionBlock`` (the
    order in which ``add()`` saved them) and ``stored`` is emptied so that
    ``add()`` can be called again.
    """
    if not stored:
        return

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    attention.BasicTransformerBlock.forward = stored[0]
    openaimodel.ResBlock.forward = stored[1]
    openaimodel.AttentionBlock.forward = stored[2]

    stored.clear()
|