Lecture 4: Swin Transformer from Scratch in PyTorch - Window Attention & Cyclic Shift
Understanding Window Attention and Cyclic Shift in Swing Transfer Plot
Introduction to Window Attention
- The video discusses the final component of the Swing transfer plot, focusing on window attention and the cyclic shift padding technique related to shifted windows.
Overview of Student Block Class
- The speaker introduces the student block class, highlighting its creation of residual parameters and feed-forward mechanisms before delving into window attention.
Key Inputs for Window Attention
- Important inputs include:
- Window Size: Set at 7.
- Shifted Windows: Differentiates between window MSA (Multi-Head Self-Attention) and shifted window MSA.
- Dimensions: Hidden dimensions are defined as channels with values like 96, 192, 384, and 768.
Hierarchical Structure Variables
- The hierarchical structure includes:
- Downscaling Factor: Values are four, two, two, two across layers.
- Hidden Dimensions: Varying from stage one (96) to stage four (768).
- Number of Heads: Ranges from three to twenty-four; head dimension is consistently set at thirty-two.
Code Implementation Insights
- In code:
- Inner dimension correlates with channel numbers based on stages.
- Scale factor is derived from head dimension using a specific formula involving negative powers.
Shifting Windows Technique
- When
self.shiftedis true:
- All windows shift right and down simultaneously by half their size.
Padding Techniques Explained
- Two types of padding discussed:
- Naive Padding: Adding zeros to empty locations.
- Cyclic Padding: Faster method where sections are copied from edges to fill gaps after shifting.
Performance Comparison of Padding Methods
- Table comparisons show that cyclic padding outperforms naive padding across various stages in terms of speed during self-attention computations.
Details on Cyclic Padding Implementation
- Cyclic padding involves copying sections from top/bottom or left/right to fill empty spaces post-shift. This ensures continuity in data representation.
Coding the Cyclic Shift Functionality
- To implement cyclic shifts in code:
Cycle Shift Function Implementation
Understanding the Cycle Shift Function
- The discussion begins with the need to create a cycle shift function, emphasizing the importance of shifting elements back to their original positions.
- The implementation will utilize
torch.roll, specifying how tensors are rolled across dimensions 1 and 2 based on a given displacement input.
Numerical Example for Clarity
- A numerical example is introduced using a 9x9 matrix filled with numbers from 1 to 81, illustrating how inputs (X) are shifted by -1 in both the first and second dimensions.
- The example highlights specific sections of the matrix:
- The green section (first row) shifts downwards.
- The blue section (left column) shifts rightward.
Shifting Mechanism Explained
- It is noted that negative values are used for downward and rightward shifts, while positive values revert elements back to their original positions.