Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
# 初始化词向量层 | |
A = nn.Embedding(100, 3) # ID范围: 0-99 | |
B = nn.Embedding(100, 3) # ID范围: 100-199(需确保输入ID不超过199) | |
# 输入数据(假设ID在0-199之间) | |
batch = torch.randint(0, 200, (5, 4)) # 形状: (5,18) | |
print(batch) | |
# 生成掩码 | |
mask_a = batch < 100 # 调用A的条件 | |
mask_b = batch >= 100 # 调用B的条件 | |
batch_a = batch[mask_a] # 取出A的部分 | |
print(batch_a) | |
embedding_a = A(batch_a) # 调用A的embedding | |
print(embedding_a) | |
print(embedding_a.shape) | |
batch_b = batch[mask_b] - 100 # 取出B的部分 | |
print(batch_b) | |
embedding_b = B(batch_b) # 调用B的embedding | |
print(embedding_b) | |
print(embedding_b.shape) | |
output = torch.zeros(5, 4, 3) # 输出的形状 | |
output[mask_a] = embedding_a # 填充A的部分 | |
output[mask_b] = embedding_b # 填充B的部分 | |
print(output) | |
print(output.shape) |