
Model Pruning
Apply Wanda-style LLM pruning to cut model size and inference cost while keeping accuracy acceptable for solo-shipped AI features.
Install
npx skills add https://github.com/orchestra-research/ai-research-skills --skill model-pruning
What is this skill?
- Wanda importance metric: |w_ij| × ||X_i|| blending weight magnitude and activation usage
- Documented target: ~50% sparsity with under 1% accuracy loss without retraining (per skill summary)
- One-shot pruning flow with calibration data on causal LM stacks
- Contrasts magnitude-only pruning with activation-aware decisions via concrete numeric examples
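A minimal sketch of the contrast in the bullets above, in plain Python with made-up numbers. The toy weights `W`, the calibration activations `X`, and the helper names `wanda_scores` and `prune_rows` are illustrative assumptions, not part of the skill or the Wanda repository. It scores each weight as |w| times the L2 norm of its input feature, then zeroes the lowest-scoring half of each output row.

```python
import math


def wanda_scores(weights, activations):
    """Score each weight as |w| * L2 norm of its input feature over calibration tokens."""
    n_in = len(weights[0])
    # One norm per input feature, taken across all calibration tokens.
    norms = [math.sqrt(sum(tok[i] ** 2 for tok in activations)) for i in range(n_in)]
    return [[abs(w) * norms[i] for i, w in enumerate(row)] for row in weights]


def prune_rows(weights, scores, sparsity=0.5):
    """Zero the lowest-scoring fraction of weights within each output row."""
    pruned = []
    for row, row_scores in zip(weights, scores):
        n_prune = int(len(row) * sparsity)
        drop = set(sorted(range(len(row)), key=lambda i: row_scores[i])[:n_prune])
        pruned.append([0.0 if i in drop else w for i, w in enumerate(row)])
    return pruned


# Toy layer: 2 outputs x 4 inputs (hypothetical values).
W = [[0.5, 0.3, -0.2, 0.1],
     [-0.4, 0.6, 0.05, -0.3]]
# Two calibration tokens: input 0 is almost never active, input 3 is strongly active.
X = [[0.1, 0.8, 0.5, 2.0],
     [0.0, 0.6, 0.5, 2.0]]

magnitude = prune_rows(W, [[abs(w) for w in row] for row in W])
wanda = prune_rows(W, wanda_scores(W, X))

print("magnitude:", magnitude)
print("wanda:    ", wanda)
```

In this toy layer, magnitude pruning keeps the large weights on input 0 even though that input is almost never active. The activation-aware score drops them and keeps the smaller weights on the strongly active input 3.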
Adoption & trust: 1 install on skills.sh; 9.4k GitHub stars; 2 of 3 security scanners passed (skills.sh audits).
Common Questions / FAQ
Is Model Pruning safe to install?
skills.sh reports 2 of 3 security scanners passed. Review the Security Audits panel on this page before installing in production.
SKILL.md
# Wanda: Pruning by Weights and Activations

Based on the ICLR 2024 paper *A Simple and Effective Pruning Approach for Large Language Models* (arXiv 2306.11695).

## Overview

**Source**: https://arxiv.org/abs/2306.11695
**Conference**: ICLR 2024
**GitHub**: https://github.com/locuslab/wanda

Wanda prunes LLMs by scoring each weight as weight magnitude × input activation norm, achieving 50% sparsity with <1% accuracy loss, no retraining required.

## Core Innovation

### Pruning Criterion

**Key insight**: Weight importance = magnitude × usage

```
importance(w_ij) = |w_ij| × ||X_i||

where:
- w_ij: Weight connecting input i to output j
- X_i: Activations of input dimension i, collected over the calibration tokens
- ||·||: L2 norm
```

**Intuition**:
- Large weight magnitude → important parameter
- High activation → frequently used dimension
- Product captures both factors

### Comparison with Magnitude Pruning

**Magnitude pruning** (baseline):

```python
importance = weight.abs()  # Only considers weight size
```

**Wanda**:

```python
importance = weight.abs() * act_norm  # Considers usage too
```

**Example**:

```
Weight A: magnitude=0.5, activation=0.1 → importance=0.05
Weight B: magnitude=0.3, activation=0.8 → importance=0.24

Magnitude pruning: Keeps A (larger weight)
Wanda: Keeps B (more important overall) ✓
```

## Algorithm

### One-Shot Pruning

```python
import torch
from transformers import AutoModelForCausalLM


def wanda_prune(model, calib_data, sparsity=0.5):
    """
    Wanda pruning algorithm.

    Steps:
    1. Collect activation statistics on calibration data
    2. Compute importance = |weight| × activation norm
    3. Prune the lowest-importance weights within each output row
    4. Return pruned model (no retraining!)
    """
    # Step 1: Collect activations
    sq_sums = {}  # per layer: running sum of squared inputs, one entry per input feature

    def activation_hook(name):
        def hook(module, inputs, output):
            # Flatten batch and sequence dims to (tokens, in_features)
            X = inputs[0].detach().float().reshape(-1, inputs[0].shape[-1])
            sq = X.pow(2).sum(dim=0)
            sq_sums[name] = sq_sums[name] + sq if name in sq_sums else sq
        return hook

    # Register hooks
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(activation_hook(name)))

    # Run calibration
    model.eval()
    with torch.no_grad():
        for batch in calib_data:
            model(**batch)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Steps 2 & 3: Prune based on importance
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and name in sq_sums:
            W = module.weight.data
            # L2 norm per input feature over all calibration tokens
            act_norm = sq_sums[name].sqrt().to(W.device)

            # Importance, shape (out_features, in_features)
            importance = W.abs().float() * act_norm.unsqueeze(0)

            # Per output row, zero the k lowest-importance weights
            k = int(W.shape[1] * sparsity)
            if k > 0:
                prune_idx = importance.topk(k, dim=1, largest=False).indices
                W.scatter_(1, prune_idx, 0)

    return model
```

### Per-Output Pruning

**Key detail**: Pruning is per-output dimension, not global.

```python
# W and act_norm are defined as in wanda_prune above.
# For each output dimension, prune sparsity% of weights
out_features = W.shape[0]
for out_dim in range(out_features):
    # Importance for this output
    importance_out = W[out_dim, :].abs().float() * act_norm

    # Prune sparsity% of this output's weights
    threshold = torch.quantile(importance_out, sparsity)
    mask_out = importance_out >= threshold

    # Apply
    W[out_dim, :] *= mask_out
```

**Reason**: Ensures each output has similar capacity (balanced pruning).

## Calibration Data

### Requirements

**Amount**: 128 samples (from paper)
**Source**: Any text corpus (C4, WikiText, etc.)
**Length**: 2048 tokens per sample

```python
from datasets import load_dataset

# Load calibration dataset