File size: 913 Bytes
841f290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn

# 初始化词向量层
A = nn.Embedding(100, 3)  # ID范围: 0-99
B = nn.Embedding(100, 3)  # ID范围: 100-199(需确保输入ID不超过199)

# 输入数据(假设ID在0-199之间)
batch = torch.randint(0, 200, (5, 4))  # 形状: (5,18)
print(batch)

# 生成掩码
mask_a = batch < 100          # 调用A的条件
mask_b = batch >= 100         # 调用B的条件
batch_a = batch[mask_a]       # 取出A的部分
print(batch_a)
embedding_a = A(batch_a)      # 调用A的embedding
print(embedding_a)
print(embedding_a.shape)
batch_b = batch[mask_b] - 100  # 取出B的部分
print(batch_b)
embedding_b = B(batch_b)      # 调用B的embedding
print(embedding_b)
print(embedding_b.shape)

output = torch.zeros(5, 4, 3)  # 输出的形状
output[mask_a] = embedding_a  # 填充A的部分
output[mask_b] = embedding_b  # 填充B的部分
print(output)
print(output.shape)