# One import statement per line; fusing them on a single line is a SyntaxError.
import torch
import torch.nn as nn
# `attention` is a short alias for torch.nn.functional.scaled_dot_product_attention.
from torch.nn.functional import scaled_dot_product_attention as attention