
Rwkv Architecture
Give your coding agent accurate RWKV time-mixing, channel-mixing, and WKV recurrence details when designing or explaining linear-time sequence models.
Install
npx skills add https://github.com/orchestra-research/ai-research-skills --skill rwkv-architecture
What is this skill?
- Contrasts O(n²) attention with O(n) WKV recurrence in Time-Mixing blocks
- Documents Time-Mix and Channel-Mix alternation with receptance/key/value projections
- Includes full RWKV_TimeMix-style module sketch with time-decay and bonus parameters
- Explains linear-time state updates (aa/ab recurrence) for long sequences
- Serves as implementation-oriented reference for custom or efficient inference stacks
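The O(n) claim and the aa/ab state updates listed above can be sanity-checked in a few lines. The sketch below assumes an RWKV-4 style formulation for a single channel (positive decay rate `w`, current-token bonus `u`). It evaluates the WKV weighted average directly in O(T²) and again with the linear-time aa/ab recurrence, then compares the two. The function names `wkv_quadratic` and `wkv_recurrent` are hypothetical, not part of the skill or any library, and the code deliberately omits the overflow protection a real kernel needs.

```python
import math


def wkv_quadratic(w, u, k, v):
    """O(T^2) reference: rebuild the decayed weighted average at every step.

    Single channel. w > 0 is the decay rate, u is the bonus for the current token.
    """
    out = []
    for t in range(len(k)):
        # Current token: weight exp(u + k_t)
        num = math.exp(u + k[t]) * v[t]
        den = math.exp(u + k[t])
        # Past tokens: weight exp(-(t - 1 - i) * w + k_i)
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + k[i])
            num += weight * v[i]
            den += weight
        out.append(num / den)
    return out


def wkv_recurrent(w, u, k, v):
    """O(T) version: carry the running numerator (aa) and denominator (ab)."""
    aa, ab = 0.0, 0.0
    out = []
    for kt, vt in zip(k, v):
        euk = math.exp(u + kt)
        out.append((aa + euk * vt) / (ab + euk))
        # Decay the state once, then fold in the current token
        aa = math.exp(-w) * aa + math.exp(kt) * vt
        ab = math.exp(-w) * ab + math.exp(kt)
    return out


k = [0.1, -0.3, 0.5, 0.2]
v = [1.0, 2.0, -1.0, 0.5]
slow = wkv_quadratic(0.5, 0.3, k, v)
fast = wkv_recurrent(0.5, 0.3, k, v)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(slow, fast)))  # True
```

Both functions produce the same outputs, but the recurrent one visits each token once and keeps only two numbers of state per channel, which is what allows constant-memory, per-token inference.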
Adoption & trust: 1 install on skills.sh; 9.4k GitHub stars; 3/3 security scanners passed (skills.sh audits).
Recommended Skills
- Paper Context Resolver (lllllllama/ai-paper-reproduction-skill)
- Repo Intake And Plan (lllllllama/ai-paper-reproduction-skill)
- Env And Assets Bootstrap (lllllllama/ai-paper-reproduction-skill)
- Minimal Run And Audit (lllllllama/ai-paper-reproduction-skill)
- Analyze Project (lllllllama/rigorpilot-skills)
- Ai Research Reproduction (lllllllama/rigorpilot-skills)
Journey fit
Primary fit
Architecture reference material is consumed when exploring model choices before implementation, which maps to idea-phase technical research. RWKV block-level math and pseudocode are deep research artifacts, not shipping or growth tasks.
Common Questions / FAQ
Is Rwkv Architecture safe to install?
skills.sh reports 3 of 3 security scanners passed. Review the Security Audits panel on this page before installing in production.
SKILL.md - Rwkv Architecture
# RWKV Architecture Details

## Time-Mixing and Channel-Mixing Blocks

RWKV alternates between **Time-Mixing** (sequence processing) and **Channel-Mixing** (feature processing) blocks.

### Time-Mixing Block (WKV Operation)

The core innovation is the **WKV (Weighted Key-Value)** mechanism. Instead of building an n×n attention matrix, each channel keeps an exponentially decayed running sum of past keys and values, so the work per token is constant:

```python
# Traditional attention (O(n²)): every token scores every other token
scores = Q @ K.T / sqrt(d)  # n×n matrix
attention = softmax(scores)
output = attention @ V

# RWKV Time-Mixing (O(n)): compute WKV in linear time using a recurrence.
# All operations are elementwise, one independent recurrence per channel.
# w: per-channel decay rate (> 0), u: per-channel bonus for the current token
aa, ab = 0, 0  # running numerator and denominator
for t in range(T):
    wkv[t] = (aa + exp(u + k[t]) * v[t]) / (ab + exp(u + k[t]))
    aa = exp(-w) * aa + exp(k[t]) * v[t]
    ab = exp(-w) * ab + exp(k[t])
```

**Full Time-Mixing implementation**:

```python
import torch
import torch.nn as nn


class RWKV_TimeMix(nn.Module):
    def __init__(self, d_model, n_layer):
        super().__init__()
        self.d_model = d_model

        # Linear projections
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

        # Time-mixing parameters (interpolation weights between x_t and x_{t-1})
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))

        # Time-decay and bonus
        self.time_decay = nn.Parameter(torch.ones(d_model))  # w (per-step decay is exp(-exp(w)))
        self.time_first = nn.Parameter(torch.ones(d_model))  # u

    def forward(self, x, state=None):
        B, T, C = x.shape
        if state is None:
            state = x.new_zeros(B, C, 3)  # [aa, ab, x_prev]

        # Time-shift mixing: each position is interpolated with the token before it;
        # the first position uses the last token of the previous call (from state)
        x_prev = torch.cat([state[:, :, 2].unsqueeze(1), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + x_prev * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + x_prev * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + x_prev * (1 - self.time_mix_r)

        # Compute k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        # WKV computation, also returning the updated numerator/denominator state
        wkv, state_aa, state_ab = self.wkv(k, v, state[:, :, 0], state[:, :, 1])

        # Apply receptance gate and output projection
        out = self.output(torch.sigmoid(r) * wkv)

        # Update state
        new_state = torch.stack([state_aa, state_ab, x[:, -1]], dim=2)
        return out, new_state

    def wkv(self, k, v, aa, ab):
        # Sequential reference version over the T positions of this call.
        # See the training and inference variants below.
        decay = torch.exp(-torch.exp(self.time_decay))
        outs = []
        for t in range(k.shape[1]):
            euk = torch.exp(self.time_first + k[:, t])  # bonus u for the current token
            outs.append((aa + euk * v[:, t]) / (ab + euk))
            ek = torch.exp(k[:, t])
            aa = decay * aa + ek * v[:, t]
            ab = decay * ab + ek
        return torch.stack(outs, dim=1), aa, ab
```

### WKV Parallel Algorithm (Training)

```python
import torch


def wkv_forward(w, u, k, v):
    """
    WKV over a full sequence for training (reference version).
    Loops over time but is vectorized across batch and channels.
    w: time_decay (d_model,)
    u: time_first (d_model,)
    k: keys (batch, seq_len, d_model)
    v: values (batch, seq_len, d_model)
    """
    B, T, C = k.shape
    decay = torch.exp(-torch.exp(w))  # per-step decay factor in (0, 1)

    # Note: no overflow protection for exp(k); stable kernels track a running max exponent
    aa = torch.zeros(B, C, device=k.device, dtype=k.dtype)  # numerator state
    ab = torch.zeros(B, C, device=k.device, dtype=k.dtype)  # denominator state
    outs = []
    for t in range(T):
        ek = torch.exp(k[:, t])
        euk = torch.exp(u + k[:, t])  # current token gets the extra bonus u
        outs.append((aa + euk * v[:, t]) / (ab + euk))
        aa = decay * aa + ek * v[:, t]
        ab = decay * ab + ek
    return torch.stack(outs, dim=1)
```

### WKV Sequential Algorithm (Inference)

```python
import torch


def wkv_inference(w, u, k, v, state):
    """
    Sequential WKV for O(1) per-token inference.
    k, v: (batch, d_model) for the current token
    state: (aa, ab) from previous step
    """
    decay = torch.exp(-torch.exp(w))  # time_decay -> per-step decay factor

    # Unpack state
    aa, ab = state  # aa = numerator, ab = denominator

    # Compute WKV for current token (bonus u applies only here)
    euk = torch.exp(u + k)
    wkv = (aa + euk * v) / (ab + euk)

    # Update state for next token
    ek = torch.exp(k)
    new_aa = decay * aa + ek * v
    new_ab = decay * ab + ek
    return wkv, (new_aa, new_ab)
```

### Channel-Mixing Block

Replaces Tra