Step | Description | Result |
---|---|---|
1 | Compute the raw scores | [1, 2, 3] |
2 | Divide by √d_k (here √2) and apply softmax | [0.140, 0.284, 0.576] |
3 | Weighted sum over V | [0.355, 0.617] |
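A quick numeric check of the table above, starting from the raw scores [1, 2, 3] and d_k = 2. The V matrix below is only a placeholder (it is not part of the walkthrough), so the final context vector will not match the table's [0.355, 0.617]:

```python
import math
import torch
import torch.nn.functional as F

scores = torch.tensor([1.0, 2.0, 3.0])              # step 1: raw query-key scores
d_k = 2                                              # key dimension used in the walkthrough
weights = F.softmax(scores / math.sqrt(d_k), dim=-1)
print(weights)                                       # ~[0.140, 0.284, 0.576], matching step 2

# step 3: weighted sum over V; this V is a made-up placeholder,
# so the printed context will differ from the table's [0.355, 0.617]
V = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]])
context = weights @ V
print(context)
```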
```python
def scaled_dot_product_attention(self, Q, K, V, mask=None):
    d_k = Q.size(-1)
    # similarity scores between queries and keys, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # blocked positions (mask == 0) get a large negative score before the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    attention_weights = F.softmax(scores, dim=-1)
    attention_weights = self.dropout(attention_weights)
    context = torch.matmul(attention_weights, V)
    return context, attention_weights

def forward(self, query, key, value, mask=None):
    batch_size, seq_len, d_model = query.size()
    # project and split into heads: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
    Q = self.W_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    K = self.W_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    V = self.W_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
    # merge the heads back: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
    attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    output = self.W_o(attention_output)
    return output, attention_weights
```
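A minimal usage sketch, assuming the `MultiHeadAttention(d_model, n_heads)` constructor from earlier in the post. The mask follows the convention in the code above (positions where the mask is 0 are blocked); the causal mask here is just an illustration:

```python
import torch

batch_size, seq_len, d_model, n_heads = 2, 10, 64, 8
mha = MultiHeadAttention(d_model, n_heads)

x = torch.randn(batch_size, seq_len, d_model)

# causal mask: 1 = may attend, 0 = blocked; broadcasts over batch and heads
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

out, attn = mha(x, x, x, mask=causal_mask)
print(out.shape)    # torch.Size([2, 10, 64])
print(attn.shape)   # torch.Size([2, 8, 10, 10])
```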
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)           # (max_len, 1, d_model)
        self.register_buffer('pe', pe)                 # saved with the model but not trainable

    def forward(self, x):
        seq_len = x.size(1)
        # (seq_len, 1, d_model) -> (1, seq_len, d_model), broadcast over the batch
        return x + self.pe[:seq_len, :].transpose(0, 1)
```
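For reference, the `div_term` trick above implements the standard sinusoidal encoding from the original Transformer paper:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$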
```python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # self-attention sub-layer with residual connection and layer norm
        attn_output, attn_weights = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # position-wise feed-forward sub-layer, also with residual + norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x, attn_weights
```
The residual connections (the `x + ...` terms above) make training deep networks more stable by giving gradients a direct path through each block.
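A quick shape check, assuming the `TransformerBlock` defined above; the residual additions require the block to preserve the `(batch, seq_len, d_model)` shape end to end:

```python
import torch

block = TransformerBlock(d_model=64, n_heads=8, d_ff=256)

x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
out, attn = block(x)

print(out.shape)             # torch.Size([2, 10, 64]), same shape as the input
print(attn.shape)            # torch.Size([2, 8, 10, 10])
```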
```python
class AttentionClassifier(nn.Module):
    def __init__(self, input_dim, d_model, n_heads, n_layers, n_classes):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_model * 4) for _ in range(n_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, n_classes)
        )

    def forward(self, x):
        x = self.input_projection(x)
        x = self.pos_encoding(x)
        attention_weights = []
        for block in self.transformer_blocks:
            x, attn_weights = block(x)
            attention_weights.append(attn_weights)
        # mean-pool over the sequence dimension before classification
        x = torch.mean(x, dim=1)
        output = self.classifier(x)
        return output, attention_weights
```
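A minimal end-to-end sketch with dummy data, assuming the `AttentionClassifier` above; the input dimension, sequence length, and class count here are placeholders:

```python
import torch

model = AttentionClassifier(input_dim=32, d_model=64, n_heads=8, n_layers=2, n_classes=3)

x = torch.randn(4, 20, 32)            # (batch, seq_len, input_dim)
logits, attention_weights = model(x)

print(logits.shape)                   # torch.Size([4, 3])
print(len(attention_weights))         # 2, one attention map per transformer block
print(attention_weights[0].shape)     # torch.Size([4, 8, 20, 20])
```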
Metric | Result |
---|---|
Training Accuracy | 98.3% |
Validation Accuracy | 96.7% |
Test Accuracy | 96.0% |
Cost | Complexity (n = sequence length, d = model dimension) |
---|---|
Computation | O(n² × d) |
Memory | O(n²) (rough estimate below) |
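A rough back-of-the-envelope check of the quadratic memory term in the table above, assuming float32 attention weights, a batch of 1, and 8 heads (all placeholder values):

```python
# memory for the (n x n) attention matrix per head, float32, batch of 1
bytes_per_float = 4
n_heads = 8

for n in [512, 2048, 8192]:
    attn_bytes = n_heads * n * n * bytes_per_float
    print(f"n={n:5d}: {attn_bytes / 1e6:,.1f} MB for the attention weights alone")
```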
Problem | Fix | How to verify |
---|---|---|
Vanishing gradients | Proper weight initialization, residual connections | Monitor gradient norms (see the sketch below)
Overfitting | Increase dropout, shrink the model | Compare training vs. validation loss
Slow convergence | Tune the learning rate, try AdamW | Watch how quickly the loss drops
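As a concrete example of the gradient-norm check, a minimal sketch of a helper that logs the total gradient norm after `loss.backward()`; `model` is a placeholder for any `nn.Module`:

```python
def log_grad_norm(model):
    """Total L2 norm of all parameter gradients; call after loss.backward()."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5

# inside the training loop, after loss.backward():
# grad_norm = log_grad_norm(model)
# print(f"grad norm: {grad_norm:.4f}")   # a norm collapsing toward 0 suggests vanishing gradients
```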
```bash
git clone https://github.com/GruheshKurra/AttentionMechanisms.git
cd AttentionMechanisms
pip install -r requirements.txt
jupyter notebook "Attention Mechanisms.ipynb"
```