| from torch.utils.checkpoint import checkpoint |
|
|
| import ldm.modules.attention |
| import ldm.modules.diffusionmodules.openaimodel |
|
|
|
|
def BasicTransformerBlock_forward(self, x, context=None):
    """Run ``self._forward(x, context)`` through ``checkpoint``.

    ``add()`` installs this function as
    ``ldm.modules.attention.BasicTransformerBlock.forward``.

    Args:
        self: Object exposing a ``_forward`` callable.
        x: First positional argument forwarded to ``self._forward``.
        context: Second positional argument forwarded to ``self._forward``;
            defaults to ``None``.

    Returns:
        Whatever ``checkpoint(self._forward, x, context)`` returns.
    """
    inner_forward = self._forward
    checkpointed_output = checkpoint(inner_forward, x, context)
    return checkpointed_output
|
|
|
|
def AttentionBlock_forward(self, x):
    """Run ``self._forward(x)`` through ``checkpoint``.

    ``add()`` installs this function as
    ``ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward``.

    Args:
        self: Object exposing a ``_forward`` callable.
        x: The single positional argument forwarded to ``self._forward``.

    Returns:
        Whatever ``checkpoint(self._forward, x)`` returns.
    """
    inner_forward = self._forward
    checkpointed_output = checkpoint(inner_forward, x)
    return checkpointed_output
|
|
|
|
def ResBlock_forward(self, x, emb):
    """Run ``self._forward(x, emb)`` through ``checkpoint``.

    ``add()`` installs this function as
    ``ldm.modules.diffusionmodules.openaimodel.ResBlock.forward``.

    Args:
        self: Object exposing a ``_forward`` callable.
        x: First positional argument forwarded to ``self._forward``.
        emb: Second positional argument forwarded to ``self._forward``.

    Returns:
        Whatever ``checkpoint(self._forward, x, emb)`` returns.
    """
    forward_args = (x, emb)
    return checkpoint(self._forward, *forward_args)
|
|
|
|
# Original ``forward`` attributes saved by add() so remove() can restore them.
# Order: BasicTransformerBlock, ResBlock, AttentionBlock. The list is empty
# whenever the checkpointing patches are not installed.
stored: list = []
|
|
|
|
def add():
    """Patch the three ldm block classes to use the checkpointed forwards.

    Saves the current ``forward`` attributes of ``BasicTransformerBlock``,
    ``ResBlock`` and ``AttentionBlock`` into ``stored`` and then replaces
    them with the ``*_forward`` functions defined in this module. Does
    nothing when ``stored`` is already non-empty, so repeated calls do not
    overwrite the saved originals.
    """
    if stored:
        # Patches are already installed.
        return

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # The order here is relied upon by remove(), which restores by index.
    stored.extend((
        attention.BasicTransformerBlock.forward,
        openaimodel.ResBlock.forward,
        openaimodel.AttentionBlock.forward,
    ))

    attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    openaimodel.ResBlock.forward = ResBlock_forward
    openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
|
|
|
def remove():
    """Undo ``add()`` by restoring the saved ``forward`` attributes.

    Reassigns ``stored[0]``, ``stored[1]`` and ``stored[2]`` to the
    ``forward`` attribute of ``BasicTransformerBlock``, ``ResBlock`` and
    ``AttentionBlock`` respectively, then empties ``stored``. Does nothing
    when ``stored`` is empty.
    """
    if not stored:
        # Nothing was saved, so there is nothing to restore.
        return

    attention = ldm.modules.attention
    openaimodel = ldm.modules.diffusionmodules.openaimodel

    # Indices follow the order in which add() saved the originals.
    attention.BasicTransformerBlock.forward = stored[0]
    openaimodel.ResBlock.forward = stored[1]
    openaimodel.AttentionBlock.forward = stored[2]

    # Emptying the list lets a later add() save and patch again.
    stored.clear()
|
|
|
|