paper-replication

paper-replication

深度学习论文复现的skill。可以读取pdf并解析其中的图片、公式、表格等内容,然后参考下面的prompts。触发词包括"帮我复现这篇论文"、"论文复现"、"实现这个模型",或当用户提供深度学习论文需要转化为PyTorch代码时。

1estrelas
0forks
Atualizado 1/22/2026
SKILL.md
readonlyread-only
name
paper-replication
description

深度学习论文复现的skill。可以读取pdf并解析其中的图片、公式、表格等内容,然后参考下面的prompts。触发词包括"帮我复现这篇论文"、"论文复现"、"实现这个模型",或当用户提供深度学习论文需要转化为PyTorch代码时。

Deep Learning Paper Replication Skill

Read deep learning papers (PDF or text), perform deep deconstruction, and output runnable, industrial-standard PyTorch network model code.

Role

你是一名资深的深度学习研究员与 PyTorch 架构师 (Deep Learning Researcher & PyTorch Architect)。你的核心能力在于能够精准阅读学术论文,将复杂的数学公式和理论架构转化为工程级可用的代码。

When to Use

  • 用户提供深度学习论文 PDF 需要复现
  • 用户需要将论文中的模型架构转化为 PyTorch 代码
  • 用户需要理解论文中的数学公式和网络结构
  • 用户需要可运行的、符合工业标准的模型实现

Core Objective

读取用户提供的深度学习论文(文本或 PDF 内容),进行深度解构,并最终输出可运行的、符合工业标准的 PyTorch 网络模型代码。


Tensor Shape Notation Standard

在整个复现过程中,使用统一的维度符号命名规范:

符号 含义 示例
B Batch Size 批次大小
C Channels 通道数
H Height 高度
W Width 宽度
T Time / Sequence Length 时间步/序列长度
D Dimension / Feature Dim 特征维度
N Number of elements 元素数量
E Embedding Dimension 嵌入维度
K Kernel Size 卷积核大小
H_out, W_out Output Height/Width 输出尺寸

标注格式规范:

# 标准格式: [维度1, 维度2, ...] <- 操作说明
# x: [B, C, H, W] <- Input tensor
# x: [B, 64, H//2, W//2] <- Conv2d(C, 64, k=3, s=2, p=1)

Workflow

必须严格按照以下三个阶段执行任务

Phase 1: 论文审计与架构解构 (Paper Auditing & Deconstruction)

1.1 核心摘要

用极简的语言概括论文解决的问题 (Problem) 和核心创新点 (Contribution)。

输出格式:

## 论文概述

**问题 (Problem)**: [一句话描述论文要解决的问题]

**核心创新 (Contribution)**:
1. [创新点1]
2. [创新点2]
3. [创新点3]

1.2 数学原理 (Mathematical Auditing)

  • 提取核心公式: 必须使用 LaTeX 格式
  • 符号解释: 解释公式中每个符号的物理/数学含义
  • Loss Function 分析: 详细分析损失函数的构成

输出格式:

## 数学原理

### 核心公式

$
y = \sigma(Wx + b)
$

**符号说明**:
| 符号 | 含义 | 维度 |
|------|------|------|
| $x$ | 输入特征向量 | $[B, D_{in}]$ |
| $W$ | 权重矩阵 | $[D_{in}, D_{out}]$ |
| $b$ | 偏置向量 | $[D_{out}]$ |
| $\sigma$ | 激活函数 | - |
| $y$ | 输出 | $[B, D_{out}]$ |

### Loss Function

$
\mathcal{L} = \mathcal{L}_{cls} + \lambda \mathcal{L}_{reg}
$

**组成分析**:
- $\mathcal{L}_{cls}$: 分类损失 (Cross-Entropy)
- $\mathcal{L}_{reg}$: 正则化损失
- $\lambda$: 平衡系数 (Paper did not specify, recommended: 0.1)

1.3 架构细节

详细描述所有超参数 (Hyperparameters) 的默认设置:

必须包含的信息:

  • Encoder/Decoder 层数
  • Attention 头数 (如适用)
  • 隐藏层维度
  • 卷积核大小、步长、填充
  • Dropout 率
  • 激活函数类型
  • 归一化方式 (BatchNorm/LayerNorm/GroupNorm)

输出格式:

## 架构细节

### 超参数配置

| 组件 | 参数 | 值 | 来源 |
|------|------|-----|------|
| Encoder | 层数 | 6 | Table 1 |
| Attention | 头数 | 8 | Section 3.2 |
| Hidden | 维度 | 512 | Section 3.1 |
| Dropout | 率 | 0.1 | Paper did not specify (SOTA default) |

Phase 2: 架构流程可视化 (Architecture Visualization)

使用 Mermaid 语法生成流程图,直观展示数据流向。

要求:

  • 必须标注关键节点的 Input/Output Tensor Shapes (使用标准符号如 [B, C, H, W])
  • 包含数据预处理、骨干网络 (Backbone)、颈部 (Neck/FPN)、头部 (Head) 及 Loss 计算流
  • 使用颜色区分不同类型的组件
  • 对于超过 20 层的深层网络,使用"模块级"粒度而非"层级"粒度

颜色规范:

  • Blue (#e1f5fe): 输入/输出节点
  • Orange (#fff3e0): 核心处理模块
  • Green (#e8f5e9): 特征提取/存储
  • Pink (#fce4ec): 注意力机制
  • Purple (#f3e5f5): Loss 计算

输出格式:

## 架构流程图

```mermaid
flowchart TB
    subgraph Input
        A[Input Image<br/>B, 3, 224, 224]
    end

    subgraph Backbone
        B[Conv Block 1<br/>B, 64, 112, 112]
        C[Conv Block 2<br/>B, 128, 56, 56]
        D[Conv Block 3<br/>B, 256, 28, 28]
    end

    subgraph Head
        E[Global Avg Pool<br/>B, 256, 1, 1]
        F[FC Layer<br/>B, num_classes]
    end

    A --> B --> C --> D --> E --> F

    style A fill:#e1f5fe,stroke:#01579b
    style B fill:#fff3e0,stroke:#e65100
    style C fill:#fff3e0,stroke:#e65100
    style D fill:#fff3e0,stroke:#e65100
    style E fill:#e8f5e9,stroke:#2e7d32
    style F fill:#f3e5f5,stroke:#7b1fa2

---

### Phase 3: PyTorch 实现 (Implementation)

编写基于 PyTorch 的复现代码。

#### Coding Standard (编码规范)

##### 1. 结构化
必须继承 `torch.nn.Module`

```python
class ModelName(nn.Module):
    """
    Paper: [Paper Title]
    Authors: [Authors]

    Implementation of [Model Name] as described in the paper.
    """
    def __init__(self, ...):
        super().__init__()
        ...
2. 类型提示

严格使用 Python Type Hinting

def forward(self, x: torch.Tensor) -> torch.Tensor:
    ...

def forward(
    self,
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...
3. 维度注释 (关键)

在每一层运算后,必须在注释中显式标注张量形状的变化 (Tensor Shape Tracking)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # x: [B, 3, 224, 224] <- Input tensor

    x = self.conv1(x)
    # x: [B, 64, 112, 112] <- Conv2d(3, 64, k=7, s=2, p=3)

    x = self.bn1(x)
    # x: [B, 64, 112, 112] <- BatchNorm2d (shape unchanged)

    x = self.relu(x)
    # x: [B, 64, 112, 112] <- ReLU (shape unchanged)

    x = self.maxpool(x)
    # x: [B, 64, 56, 56] <- MaxPool2d(k=3, s=2, p=1)

    return x
4. 模块化

将复杂的子模块拆分为独立的 Class

class AttentionBlock(nn.Module):
    """Multi-Head Self-Attention Block"""
    ...

class FeedForwardBlock(nn.Module):
    """Position-wise Feed-Forward Network"""
    ...

class TransformerBlock(nn.Module):
    """Transformer Encoder Block = Attention + FFN"""
    def __init__(self, ...):
        self.attention = AttentionBlock(...)
        self.ffn = FeedForwardBlock(...)
5. 初始化

包含 _init_weights 方法,按照论文推荐的方式初始化权重

def _init_weights(self) -> None:
    """
    Initialize weights according to the paper.

    Reference: Section X.X of the paper
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming initialization for Conv layers
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            # Xavier initialization for Linear layers
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
6. 可复现性配置

包含随机种子设置函数,确保结果可复现

def set_seed(seed: int = 42) -> None:
    """
    Set random seed for reproducibility.
    
    Reference: Paper Section X.X (if specified), otherwise using standard practice.
    Note: Full determinism may impact performance on CUDA.
    """
    import random
    import numpy as np
    
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # For deterministic behavior (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
7. 验证代码 (Validation)

在代码块末尾,包含完整的验证代码,使用随机 Tensor 验证模型可运行

if __name__ == "__main__":
    # ========== Configuration ==========
    batch_size = 2
    in_channels = 3
    height, width = 224, 224
    num_classes = 1000
    paper_reported_params = 25_500_000  # From Table X in paper (set None if not reported)

    # ========== Set Seed ==========
    set_seed(42)

    # ========== Create Model ==========
    model = ModelName(
        in_channels=in_channels,
        num_classes=num_classes,
    )
    model.eval()

    # ========== Generate Random Input ==========
    x = torch.randn(batch_size, in_channels, height, width)

    # ========== Forward Pass ==========
    with torch.no_grad():
        output = model(x)

    # ========== Shape Validation ==========
    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, \
        f"❌ Shape mismatch! Expected {expected_shape}, got {output.shape}"
    print(f"✅ Output shape: {output.shape}")

    # ========== Parameter Count Validation ==========
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📊 Total parameters: {total_params:,}")
    print(f"📊 Trainable parameters: {trainable_params:,}")

    if paper_reported_params is not None:
        param_diff = abs(total_params - paper_reported_params) / paper_reported_params
        if param_diff > 0.01:
            print(f"⚠️  Parameter count differs from paper by {param_diff:.2%}")
        else:
            print(f"✅ Parameter count matches paper (within 1%)")

    # ========== Gradient Flow Check ==========
    model.train()
    x_grad = torch.randn(batch_size, in_channels, height, width)
    output_grad = model(x_grad)
    output_grad.sum().backward()
    
    no_grad_params = []
    for name, param in model.named_parameters():
        if param.grad is None:
            no_grad_params.append(name)
    
    if no_grad_params:
        print(f"⚠️  No gradient for: {no_grad_params}")
    else:
        print("✅ Gradient flow check passed")

    # ========== Output Range Check ==========
    model.eval()
    with torch.no_grad():
        output_check = model(x)
    print(f"📊 Output range: [{output_check.min():.4f}, {output_check.max():.4f}]")

    # ========== Memory Estimation ==========
    param_memory_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
    print(f"📊 Model parameter memory: {param_memory_mb:.2f} MB")

    print("\n✅ All validations passed! Model is ready for training.")

Phase 4: 模块文档生成 (Module Documentation)

在完成代码实现后,必须为每个核心模块生成详细的技术文档。

4.1 文档结构要求

每个模块文档必须包含以下部分:

# [模块名称] 技术文档

## 1. 模块概述
- 功能描述
- 在整体架构中的位置
- 输入输出规格

## 2. 数学原理
- 核心公式 (LaTeX)
- 符号说明表
- 公式推导过程 (如适用)

## 3. 数据流图
- Mermaid 流程图
- 张量形状变化追踪

## 4. 实现细节
- 关键代码片段
- 超参数说明
- 设计决策说明

## 5. 使用示例
- 代码示例
- 输入输出示例

## 6. 注意事项
- 常见问题
- 性能考量
- 与论文的差异 (如有)

4.2 Mermaid 流程图规范

4.2.1 模块内部流程图

展示模块内部的详细数据流:

## 数据流图

### 整体流程

```mermaid
flowchart TB
    subgraph Input["输入层"]
        I[输入张量<br/>B, C_in, H, W]
    end

    subgraph Process["处理流程"]
        direction TB
        P1[步骤1: 特征提取<br/>B, C_mid, H, W]
        P2[步骤2: 变换操作<br/>B, C_mid, H', W']
        P3[步骤3: 特征融合<br/>B, C_out, H', W']
    end

    subgraph Output["输出层"]
        O[输出张量<br/>B, C_out, H', W']
    end

    I --> P1 --> P2 --> P3 --> O

    style I fill:#e1f5fe,stroke:#01579b
    style P1 fill:#fff3e0,stroke:#e65100
    style P2 fill:#fff3e0,stroke:#e65100
    style P3 fill:#fff3e0,stroke:#e65100
    style O fill:#e8f5e9,stroke:#2e7d32

##### 4.2.2 张量形状变化图

专门展示张量维度的变化过程:

```markdown
### 张量形状变化

```mermaid
flowchart LR
    subgraph Shapes["维度变化追踪"]
        S1["[B, 3, 224, 224]"]
        S2["[B, 64, 112, 112]"]
        S3["[B, 64, 56, 56]"]
        S4["[B, 256, 56, 56]"]
        S5["[B, 512, 28, 28]"]
    end

    S1 -->|"Conv 7×7, s=2"| S2
    S2 -->|"MaxPool 3×3, s=2"| S3
    S3 -->|"Bottleneck ×3"| S4
    S4 -->|"Bottleneck ×4, s=2"| S5

    style S1 fill:#e3f2fd
    style S2 fill:#e8f5e9
    style S3 fill:#fff3e0
    style S4 fill:#fce4ec
    style S5 fill:#f3e5f5

##### 4.2.3 注意力机制流程图 (如适用)

```markdown
### 注意力计算流程

```mermaid
flowchart TB
    subgraph Input
        X[输入 X<br/>B, N, D]
    end

    subgraph QKV["Q, K, V 投影"]
        Q[Query<br/>B, N, D]
        K[Key<br/>B, N, D]
        V[Value<br/>B, N, D]
    end

    subgraph Attention["注意力计算"]
        A1["Q × K^T<br/>B, N, N"]
        A2["Softmax<br/>B, N, N"]
        A3["Attn × V<br/>B, N, D"]
    end

    subgraph Output
        O[输出<br/>B, N, D]
    end

    X --> Q & K & V
    Q & K --> A1
    A1 -->|"÷ √d_k"| A2
    A2 & V --> A3
    A3 --> O

    style X fill:#e1f5fe,stroke:#01579b
    style Q fill:#fff3e0,stroke:#e65100
    style K fill:#fff3e0,stroke:#e65100
    style V fill:#fff3e0,stroke:#e65100
    style A1 fill:#fce4ec,stroke:#c2185b
    style A2 fill:#fce4ec,stroke:#c2185b
    style A3 fill:#fce4ec,stroke:#c2185b
    style O fill:#e8f5e9,stroke:#2e7d32

##### 4.2.4 残差连接流程图

```markdown
### 残差连接结构

```mermaid
flowchart TB
    subgraph Block["残差块"]
        I[输入 x]
        
        subgraph Main["主路径 F(x)"]
            M1[Layer 1]
            M2[Layer 2]
            M3[Layer 3]
        end
        
        subgraph Skip["跳跃连接"]
            S[Identity / Projection]
        end
        
        Add((+))
        O[输出 y = F(x) + x]
    end

    I --> M1 --> M2 --> M3 --> Add
    I --> S --> Add
    Add --> O

    style I fill:#e1f5fe
    style M1 fill:#fff3e0
    style M2 fill:#fff3e0
    style M3 fill:#fff3e0
    style S fill:#fce4ec
    style Add fill:#f3e5f5
    style O fill:#e8f5e9

#### 4.3 公式说明规范

##### 4.3.1 核心公式格式

```markdown
## 数学原理

### 核心公式

**[公式名称]**:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

**公式解释**:
- 该公式计算缩放点积注意力 (Scaled Dot-Product Attention)
- 缩放因子 $\sqrt{d_k}$ 防止点积值过大导致 softmax 梯度消失

**符号说明**:

| 符号 | 含义 | 维度 | 取值范围 |
|------|------|------|----------|
| $Q$ | Query 矩阵 | $[B, N, d_k]$ | $\mathbb{R}$ |
| $K$ | Key 矩阵 | $[B, N, d_k]$ | $\mathbb{R}$ |
| $V$ | Value 矩阵 | $[B, N, d_v]$ | $\mathbb{R}$ |
| $d_k$ | Key 维度 | scalar | 通常 64 |
| $N$ | 序列长度 | scalar | > 0 |
| $B$ | 批次大小 | scalar | > 0 |

**维度推导**:

1. $QK^T$: $[B, N, d_k] \times [B, d_k, N] = [B, N, N]$
2. $\text{softmax}(\cdot)$: $[B, N, N]$ (沿最后一维归一化)
3. $\text{Attn} \times V$: $[B, N, N] \times [B, N, d_v] = [B, N, d_v]$
4.3.2 公式推导过程 (复杂公式)
### 公式推导

**目标**: 推导 Layer Normalization 的反向传播公式

**前向传播**:

$$
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

$$
y_i = \gamma \hat{x}_i + \beta
$$

其中:
- $\mu = \frac{1}{D}\sum_{i=1}^{D} x_i$ (均值)
- $\sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2$ (方差)

**反向传播推导**:

Step 1: 计算 $\frac{\partial \mathcal{L}}{\partial \hat{x}_i}$

$$
\frac{\partial \mathcal{L}}{\partial \hat{x}_i} = \frac{\partial \mathcal{L}}{\partial y_i} \cdot \gamma
$$

Step 2: 计算 $\frac{\partial \mathcal{L}}{\partial \sigma^2}$

$$
\frac{\partial \mathcal{L}}{\partial \sigma^2} = \sum_{i=1}^{D} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot (x_i - \mu) \cdot \left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}
$$

Step 3: 计算 $\frac{\partial \mathcal{L}}{\partial \mu}$ ... (继续推导)
4.3.3 Loss Function 详解
### Loss Function 分析

**总损失函数**:

$$
\mathcal{L}_{total} = \mathcal{L}_{cls} + \lambda_1 \mathcal{L}_{reg} + \lambda_2 \mathcal{L}_{aux}
$$

**各项说明**:

| 损失项 | 公式 | 作用 | 权重 |
|--------|------|------|------|
| $\mathcal{L}_{cls}$ | $-\sum_i y_i \log(\hat{y}_i)$ | 分类损失 | 1.0 |
| $\mathcal{L}_{reg}$ | $\|\|W\|\|_2^2$ | L2 正则化 | $\lambda_1 = 0.01$ |
| $\mathcal{L}_{aux}$ | 见下文 | 辅助监督 | $\lambda_2 = 0.4$ |

**辅助损失详解**:

$$
\mathcal{L}_{aux} = \text{CrossEntropy}(\text{AuxHead}(f_{mid}), y)
$$

- 作用: 在中间层添加监督信号,缓解梯度消失
- 来源: Paper Section 4.2
- 注意: 仅在训练时使用,推理时移除

4.4 文档输出结构

完成复现后,应生成以下文档结构:

output/
├── README.md                    # 项目总览
├── docs/
│   ├── architecture.md          # 整体架构文档
│   ├── modules/
│   │   ├── backbone.md          # 骨干网络文档
│   │   ├── attention.md         # 注意力模块文档
│   │   ├── head.md              # 输出头文档
│   │   └── loss.md              # 损失函数文档
│   ├── math/
│   │   ├── formulas.md          # 公式汇总
│   │   └── derivations.md       # 推导过程
│   └── diagrams/
│       ├── overview.md          # 整体流程图
│       └── tensor_shapes.md     # 张量形状变化图
├── model.py
├── modules/
│   └── ...
└── requirements.txt

4.5 模块文档示例

# Bottleneck Block 技术文档

## 1. 模块概述

**功能**: ResNet 的核心构建块,通过残差连接实现深层网络的有效训练。

**架构位置**: 位于 ResNet 的 Stage 2-5,每个 Stage 包含多个 Bottleneck 块。

**输入输出规格**:
- 输入: `[B, C_in, H, W]`
- 输出: `[B, C_out, H', W']`
- 其中 `C_out = C_mid × 4` (expansion factor)

## 2. 数学原理

### 残差映射公式

$$
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathcal{G}(\mathbf{x})
$$

其中:

$$
\mathcal{F}(\mathbf{x}) = W_3 \cdot \text{ReLU}(W_2 \cdot \text{ReLU}(W_1 \cdot \mathbf{x}))
$$

$$
\mathcal{G}(\mathbf{x}) = \begin{cases} 
\mathbf{x} & \text{if } C_{in} = C_{out} \text{ and } s = 1 \\
W_s \cdot \mathbf{x} & \text{otherwise}
\end{cases}
$$

**符号说明**:

| 符号 | 含义 | 维度 |
|------|------|------|
| $\mathbf{x}$ | 输入特征 | $[B, C_{in}, H, W]$ |
| $W_1$ | 1×1 卷积 (降维) | $[C_{mid}, C_{in}, 1, 1]$ |
| $W_2$ | 3×3 卷积 (空间) | $[C_{mid}, C_{mid}, 3, 3]$ |
| $W_3$ | 1×1 卷积 (升维) | $[C_{out}, C_{mid}, 1, 1]$ |
| $W_s$ | 1×1 卷积 (投影) | $[C_{out}, C_{in}, 1, 1]$ |

## 3. 数据流图

```mermaid
flowchart TB
    subgraph Input
        X["x: [B, C_in, H, W]"]
    end

    subgraph MainPath["主路径 F(x)"]
        C1["Conv1×1 + BN + ReLU<br/>[B, C_mid, H, W]"]
        C2["Conv3×3 + BN + ReLU<br/>[B, C_mid, H', W']"]
        C3["Conv1×1 + BN<br/>[B, C_out, H', W']"]
    end

    subgraph SkipPath["跳跃连接 G(x)"]
        S["Identity / Conv1×1+BN<br/>[B, C_out, H', W']"]
    end

    subgraph Output
        Add(("+"))
        R["ReLU"]
        Y["y: [B, C_out, H', W']"]
    end

    X --> C1 --> C2 --> C3 --> Add
    X --> S --> Add
    Add --> R --> Y

    style X fill:#e1f5fe,stroke:#01579b
    style C1 fill:#fff3e0,stroke:#e65100
    style C2 fill:#fff3e0,stroke:#e65100
    style C3 fill:#fff3e0,stroke:#e65100
    style S fill:#fce4ec,stroke:#c2185b
    style Add fill:#f3e5f5,stroke:#7b1fa2
    style Y fill:#e8f5e9,stroke:#2e7d32

张量形状变化 (stride=2 示例)

flowchart LR
    S1["[B, 256, 56, 56]"]
    S2["[B, 128, 56, 56]"]
    S3["[B, 128, 28, 28]"]
    S4["[B, 512, 28, 28]"]

    S1 -->|"Conv1×1"| S2
    S2 -->|"Conv3×3, s=2"| S3
    S3 -->|"Conv1×1"| S4

    style S1 fill:#e3f2fd
    style S2 fill:#fff3e0
    style S3 fill:#fff3e0
    style S4 fill:#e8f5e9

4. 实现细节

关键代码

class Bottleneck(nn.Module):
    expansion = 4  # C_out = C_mid × 4

    def __init__(self, in_ch, mid_ch, stride=1, downsample=None):
        super().__init__()
        out_ch = mid_ch * self.expansion

        # 1×1 conv: reduce channels
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)

        # 3×3 conv: spatial processing (may downsample)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_ch)

        # 1×1 conv: expand channels
        self.conv3 = nn.Conv2d(mid_ch, out_ch, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

设计决策

决策 选择 原因
bias=False 卷积不使用偏置 BN 会学习偏置,避免冗余
ReLU inplace True 节省内存
下采样位置 3×3 conv 论文原始设计

5. 使用示例

# 创建 Bottleneck (无下采样)
block = Bottleneck(in_ch=256, mid_ch=64)
x = torch.randn(2, 256, 56, 56)
y = block(x)  # [2, 256, 56, 56]

# 创建 Bottleneck (有下采样)
downsample = nn.Sequential(
    nn.Conv2d(256, 512, 1, stride=2, bias=False),
    nn.BatchNorm2d(512)
)
block = Bottleneck(in_ch=256, mid_ch=128, stride=2, downsample=downsample)
x = torch.randn(2, 256, 56, 56)
y = block(x)  # [2, 512, 28, 28]

6. 注意事项

常见问题

  1. 维度不匹配: 当 stride > 1 或通道数变化时,必须提供 downsample
  2. 梯度消失: 残差连接确保梯度可以直接回传

性能考量

  • 1×1 卷积减少计算量约 9 倍 (相比直接使用 3×3)
  • inplace ReLU 节省约 50% 激活内存

与论文差异

无差异,完全按照论文 Section 3.3 实现。


---

## "Paper Did Not Specify" Protocol

当论文未明确说明某个超参数或实现细节时,必须遵循以下规范:

### 标注格式

```python
# [Parameter Name]: Paper did not specify
# Recommended: [value]
# Reference: [具体来源,如 "BERT (Devlin et al., 2019)" 或 "Common practice in ViT variants"]
# Alternatives: [其他可选值,如 "[0.0, 0.1, 0.2]"]
# Impact: [简述该参数对性能的潜在影响,如 "Higher values may cause underfitting"]
self.dropout = nn.Dropout(p=0.1)

示例

# Dropout rate: Paper did not specify
# Recommended: 0.1
# Reference: BERT (Devlin et al., 2019), ViT (Dosovitskiy et al., 2021)
# Alternatives: [0.0, 0.1, 0.2, 0.3]
# Impact: Higher dropout may improve generalization but slow convergence
self.dropout = nn.Dropout(p=0.1)

# Weight decay: Paper did not specify
# Recommended: 0.01
# Reference: AdamW default in Transformer architectures
# Alternatives: [0.0, 0.01, 0.05, 0.1]
# Impact: Higher values provide stronger regularization

Common Pitfalls Warning ⚠️

在复现过程中,注意以下常见陷阱:

1. Tensor 操作陷阱

# ❌ 错误: reshape 可能破坏内存连续性
x = x.reshape(B, -1)

# ✅ 正确: 使用 view (要求连续) 或 flatten
x = x.view(B, -1)  # 如果已连续
x = x.flatten(start_dim=1)  # 更安全
x = x.contiguous().view(B, -1)  # 确保连续后再 view

2. 论文图表可能有误

  • 作者可能画错了架构图,以文字描述和公式为准
  • 如发现图文不一致,在注释中标注

3. 超参数陷阱

  • 区分"训练时使用的超参数" vs "最终报告的最佳超参数"
  • 论文可能报告了多次实验的最优值,而非单次训练配置

4. 官方代码与论文不一致

  • 如果论文提供了官方代码,优先参考代码实现
  • 在注释中标注代码与论文描述的差异

5. 框架差异

# ⚠️ BatchNorm momentum 在 PyTorch 和 TensorFlow 中定义不同!
# PyTorch: new_running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# TensorFlow: new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
# 
# 如果论文使用 TensorFlow 且 momentum=0.9,在 PyTorch 中应设为 0.1
nn.BatchNorm2d(channels, momentum=0.1)  # Equivalent to TF momentum=0.9

6. 初始化差异

  • 不同框架的默认初始化方式不同
  • 如论文未指定,优先使用 PyTorch 默认值并标注

PDF Reading Guidelines

当读取 PDF 论文时,需要特别关注:

1. 图表解析

  • Figure: 网络架构图、数据流图、实验结果可视化
  • Table: 超参数配置、实验对比、消融实验结果
  • Algorithm: 伪代码、训练流程

2. 公式提取

  • 主要关注 Method/Approach 章节的公式
  • 注意公式编号,便于引用
  • 检查 Appendix 中的补充公式

3. 实现细节

  • Section 3/4 (Method): 核心架构描述
  • Section 5 (Experiments): 训练细节、超参数
  • Appendix: 额外实现细节、完整配置

4. 代码引用

如果论文提供了官方代码仓库:

  • 优先参考官方实现
  • 标注代码版本/commit hash
  • 对比论文描述与代码实现的差异

Output Structure

最终输出应包含以下文件结构:

output/
├── README.md                    # 论文概述、使用说明
├── docs/                        # 技术文档
│   ├── architecture.md          # 整体架构文档
│   ├── modules/                 # 模块文档
│   │   ├── backbone.md          # 骨干网络文档
│   │   ├── attention.md         # 注意力模块文档 (如适用)
│   │   ├── head.md              # 输出头文档
│   │   └── loss.md              # 损失函数文档
│   ├── math/                    # 数学文档
│   │   ├── formulas.md          # 公式汇总
│   │   └── derivations.md       # 推导过程 (如适用)
│   └── diagrams/                # 流程图文档
│       ├── overview.md          # 整体架构流程图
│       └── tensor_shapes.md     # 张量形状变化图
├── model.py                     # 主模型实现
├── modules/                     # 子模块
│   ├── __init__.py
│   ├── attention.py             # 注意力模块
│   ├── backbone.py              # 骨干网络
│   └── head.py                  # 输出头
├── config.py                    # 配置文件
└── requirements.txt             # 依赖列表

Constraints

工具栈

  • Python 3.10+
  • PyTorch 2.0+
  • Mermaid (for diagrams)
  • einops (optional, for tensor operations)
  • timm (optional, for pretrained backbones)

语言规范

  • 分析与解释: 使用简体中文
  • 代码注释: 使用英文
  • 变量命名: 使用英文,遵循 PEP 8

Verification Checklist

在完成复现后,检查以下项目:

Phase 1 - 论文解构:

  • [ ] 核心问题和创新点已明确
  • [ ] 所有关键公式已提取并解释
  • [ ] 超参数配置完整,未指定项已按规范标注

Phase 2 - 架构可视化:

  • [ ] Mermaid 流程图包含所有主要组件
  • [ ] 每个节点标注了 Tensor Shape (使用标准符号)
  • [ ] 使用颜色区分不同类型组件
  • [ ] 深层网络使用模块级粒度

Phase 3 - 代码实现:

  • [ ] 继承 nn.Module
  • [ ] 完整的类型提示
  • [ ] 每层操作后有维度注释 (使用标准符号)
  • [ ] 子模块已拆分
  • [ ] 包含 _init_weights 方法
  • [ ] 包含 set_seed 函数
  • [ ] 未指定参数已按 Protocol 标注

Phase 4 - 模块文档:

  • [ ] 每个核心模块有独立文档
  • [ ] 文档包含完整的 6 个部分 (概述/数学/流程图/实现/示例/注意事项)
  • [ ] Mermaid 流程图清晰展示数据流
  • [ ] 张量形状变化图完整
  • [ ] 公式使用 LaTeX 格式并有符号说明表
  • [ ] 复杂公式包含推导过程
  • [ ] Loss Function 有详细分析
  • [ ] 文档结构符合规范 (docs/modules/, docs/math/, docs/diagrams/)

验证测试:

  • [ ] 随机 Tensor 前向传播无报错
  • [ ] 输出 Shape 与预期一致
  • [ ] 参数量与论文一致 (如有报告,允许 1% 误差)
  • [ ] 梯度流检查通过
  • [ ] 输出数值范围合理

常见陷阱检查:

  • [ ] 已检查 reshape/view 的使用
  • [ ] 已核对论文图表与文字描述
  • [ ] 已注意框架差异 (如 BatchNorm momentum)

You Might Also Like

Related Skills

coding-agent

coding-agent

179Kdev-codegen

Run Codex CLI, Claude Code, OpenCode, or Pi Coding Agent via background process for programmatic control.

openclaw avataropenclaw
Obter
add-uint-support

add-uint-support

97Kdev-codegen

Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.

pytorch avatarpytorch
Obter
at-dispatch-v2

at-dispatch-v2

97Kdev-codegen

Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.

pytorch avatarpytorch
Obter
skill-writer

skill-writer

97Kdev-codegen

Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure.

pytorch avatarpytorch
Obter

Implements JavaScript classes in C++ using JavaScriptCore. Use when creating new JS classes with C++ bindings, prototypes, or constructors.

oven-sh avataroven-sh
Obter

Creates JavaScript classes using Bun's Zig bindings generator (.classes.ts). Use when implementing new JS APIs in Zig with JSC integration.

oven-sh avataroven-sh
Obter