Chapter 5

Multistage Optimization Theory

Explore the theoretical foundations and mathematical frameworks for optimizing cascaded language model programs.

Introduction

Multistage optimization addresses the challenge of optimizing programs built from multiple interconnected modules. Each stage's output becomes the input to subsequent stages, introducing dependencies and potential error propagation that simple, single-stage optimization strategies often overlook.

Theoretical Foundations

The Optimization Problem

The goal is to find parameters $\theta$ that maximize the expected performance of a program $P$ composed of stages $f_1, ..., f_K$:

$P(x; \theta) = f_K(\cdots f_1(x; \theta_1) \cdots; \theta_K)$, where $\theta = (\theta_1, \ldots, \theta_K)$ collects the per-stage parameters.
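
As a toy illustration of this composition, the sketch below wires two stages together in Python. The stage functions and the call_lm helper are hypothetical stand-ins for real LM-backed modules, not an API defined in this chapter.

def call_lm(prompt: str) -> str:
    # Placeholder for a real language model call.
    return f"<LM output for: {prompt!r}>"

def stage_1(x: str, theta_1: str) -> str:
    # f_1: e.g., decompose or reformulate the input, parameterized by its instruction theta_1.
    return call_lm(f"{theta_1}\n{x}")

def stage_2(h: str, theta_2: str) -> str:
    # f_2: e.g., produce the final answer from the intermediate output, parameterized by theta_2.
    return call_lm(f"{theta_2}\n{h}")

def program(x: str, theta: dict) -> str:
    # P(x; theta) = f_2(f_1(x; theta_1); theta_2)
    return stage_2(stage_1(x, theta["theta_1"]), theta["theta_2"])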

Optimization Landscapes

Optimizing these programs is difficult due to:

  • Non-convexity: Many local optima exist, especially with discrete prompt parameters.
  • Curse of Dimensionality: The search space grows exponentially with the number of stages (see the worked example after this list).
  • Stage-wise Dependencies: Changes in early stages ripple through the entire pipeline.
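
To make the dimensionality point concrete, here is a small worked example with illustrative numbers: if each of $K$ stages has $M$ candidate instructions, the joint configuration space contains $M^K$ combinations, so $M = 10$ candidates across $K = 4$ stages already yields $10^4$ (10,000) joint configurations, far more than the $K \cdot M = 40$ evaluations a purely stage-wise sweep would make.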

Optimization Frameworks

1. Coordinate Descent

Optimizes one stage at a time while keeping others fixed. It is simple and effective but may get stuck in local optima.

def coordinate_descent(program, trainset, num_rounds=3):
    # Repeatedly sweep over the stages, tuning one at a time.
    for _ in range(num_rounds):
        for stage in program.stages:
            # Tune this stage's parameters; all other stages stay frozen.
            stage.optimize(trainset)
    return program
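
Because each stage is tuned against outputs produced by the current versions of the other stages, both the order in which stages are visited and the number of rounds can affect the solution this procedure converges to.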

2. Decomposition

Breaks the problem into independent subproblems. This simplifies optimization but ignores cross-stage interactions.
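
A minimal sketch of the idea, assuming each stage has its own labelled examples (the stage_trainsets mapping below is hypothetical) and that stages expose a name and an optimize method as in the earlier sketch:

def decomposed_optimization(program, stage_trainsets):
    # Optimize every stage against its own examples, in any order,
    # ignoring how its outputs feed the downstream stages.
    for stage in program.stages:
        stage.optimize(stage_trainsets[stage.name])
    return program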

3. End-to-End Differentiable

Uses gradients to optimize all stages simultaneously. This requires differentiable components (e.g., soft prompts rather than discrete instructions) but can reach better joint solutions than stage-wise methods.
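
A toy, self-contained sketch of the soft-prompt variant using PyTorch: frozen linear layers stand in for the stage models and learned vectors stand in for soft prompts, so every name here is illustrative rather than a method prescribed by this chapter.

import torch

torch.manual_seed(0)

# Frozen "stage models" standing in for LM calls; only the soft prompts are trained.
stage_1 = torch.nn.Linear(8, 8)
stage_2 = torch.nn.Linear(8, 4)
for p in list(stage_1.parameters()) + list(stage_2.parameters()):
    p.requires_grad_(False)

# Continuous prompt parameters theta_1 and theta_2 (the "soft prompts").
theta_1 = torch.nn.Parameter(torch.zeros(8))
theta_2 = torch.nn.Parameter(torch.zeros(8))

x = torch.randn(16, 8)       # toy inputs
target = torch.randn(16, 4)  # toy supervision on the final output only

optimizer = torch.optim.Adam([theta_1, theta_2], lr=0.05)
for _ in range(200):
    optimizer.zero_grad()
    h = stage_1(x + theta_1)               # f_1(x; theta_1)
    y = stage_2(h + theta_2)               # f_2(h; theta_2)
    loss = torch.nn.functional.mse_loss(y, target)
    loss.backward()                        # gradients flow through both stages at once
    optimizer.step()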

Convergence & Empirical Analysis

Different algorithms offer different trade-offs in terms of convergence speed and solution quality.

Algorithm               Convergence Rate   Sample Efficiency
Coordinate Descent      O(1/t)             Low
Bayesian Optimization   O(√n)              High
Gradient-based          O(1/t²)            Highest

Case Study: HotpotQA

In multi-hop question answering over HotpotQA, joint optimization strategies such as MIPRO significantly outperform independent per-stage optimization because they account for error propagation between hops.
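
For orientation only, the sketch below shows what joint optimization of a two-hop pipeline might look like with DSPy's MIPROv2 teleprompter. The program, metric, and data are hypothetical placeholders, the retrieval step is elided, and the exact MIPROv2 arguments vary across DSPy versions (a configured LM is also required before compiling).

import dspy
from dspy.teleprompt import MIPROv2

class TwoHopQA(dspy.Module):
    # Hypothetical two-stage pipeline: generate a search query, then answer.
    def __init__(self):
        super().__init__()
        self.generate_query = dspy.ChainOfThought("question -> search_query")
        self.answer = dspy.ChainOfThought("question, context -> answer")

    def forward(self, question):
        query = self.generate_query(question=question).search_query
        context = query  # a real pipeline would retrieve passages for this query
        return self.answer(question=question, context=context)

def exact_match(example, prediction, trace=None):
    return example.answer.strip().lower() == prediction.answer.strip().lower()

# Illustrative training data; real examples would hold actual questions and answers.
trainset = [dspy.Example(question="...", answer="...").with_inputs("question")]

optimizer = MIPROv2(metric=exact_match, auto="light")
optimized_program = optimizer.compile(TwoHopQA(), trainset=trainset)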