Chapter 5

Joint Optimization

Optimize model weights and prompts together so that improvements in one reinforce the other.

Introduction

Joint optimization in DSPy represents a paradigm shift from treating fine-tuning and prompt optimization as separate processes. Instead, it recognizes that these two optimization dimensions are deeply interconnected and can be optimized together to achieve superior performance.

Theoretical Foundations

Why Joint Optimization Matters

Traditional sequential optimization (fine-tune first, then optimize prompts) often converges to a combination that is only locally optimal. Joint optimization addresses this by:

  • Simultaneous Exploration: Exploring the combined space of parameters and prompts.
  • Coordinated Updates: Ensuring parameter and prompt updates complement each other.
  • Global Optimum Seeking: Searching for a stronger combined optimum across both dimensions, rather than the best point along each dimension taken separately.

Mathematical Framework

The objective is to maximize the regularized joint log-likelihood:

L(θ, p) = Σ_i log P(y_i | x_i; θ, p) − λ1 R1(θ) − λ2 R2(p)

where θ denotes the model parameters, p denotes the prompts (or prompt parameters), R1 and R2 are regularization terms on the weights and prompts, and λ1, λ2 control their strength.
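
To make this concrete, here is a minimal sketch of the corresponding training loss: the negative of L(θ, p), which is what a gradient-based optimizer would minimize. It assumes a PyTorch model whose forward pass yields per-example log-likelihoods and a trainable soft-prompt tensor; the function and argument names are illustrative, not part of DSPy.

import torch

def joint_loss(log_likelihoods, model, soft_prompt, lam1=1e-4, lam2=1e-4):
    # Negative of L(θ, p): minimizing this maximizes the regularized log-likelihood
    nll = -log_likelihoods.sum()                           # -Σ_i log P(y_i | x_i; θ, p)
    r1 = sum(p.pow(2).sum() for p in model.parameters())   # R1(θ): L2 penalty on weights
    r2 = soft_prompt.pow(2).sum()                          # R2(p): L2 penalty on the prompt
    return nll + lam1 * r1 + lam2 * r2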

Joint Optimization Strategies

1. Alternating Optimization

The most common approach where parameters and prompts are optimized in alternating phases:

class AlternatingJointOptimizer(JointOptimizationFramework):
    def optimize(self, train_data, val_data, num_epochs=10):
        best_metric = float("-inf")
        for epoch in range(num_epochs):
            # Phase 1: update model parameters with the current prompts held fixed
            self._optimize_parameters(train_data, val_data, steps=5)

            # Phase 2: update prompts against the newly fine-tuned parameters
            self._optimize_prompts(train_data, val_data, steps=1)

            # Evaluate the combined (parameters + prompts) configuration
            combined_metric = self._evaluate(val_data)
            best_metric = max(best_metric, combined_metric)

        return best_metric
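
Usage is a thin wrapper around the framework. The constructor arguments below are assumptions about how JointOptimizationFramework is initialized, not a fixed API.

# Hypothetical wiring: the constructor arguments are assumptions about
# how JointOptimizationFramework sets up the model and initial prompts.
optimizer = AlternatingJointOptimizer(model=base_model, prompts=initial_prompts)
best_score = optimizer.optimize(train_data, val_data, num_epochs=10)
print(f"Best validation metric: {best_score:.3f}")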

2. Simultaneous Gradient-Based Optimization

For soft prompts that can be optimized with gradients alongside model weights:

class SimultaneousJointOptimizer(JointOptimizationFramework):
    def optimize(self, train_data, val_data):
        # train_data is assumed to yield mini-batches
        for batch in train_data:
            # Forward pass that produces gradients for both weights and soft prompts
            outputs = self.forward(batch)
            loss = self.compute_joint_loss(outputs, batch)

            # Backward pass: clear stale gradients on both optimizers first
            self.param_optimizer.zero_grad()
            self.prompt_optimizer.zero_grad()
            loss.backward()

            # Update model weights and soft-prompt embeddings in the same step
            self.param_optimizer.step()
            self.prompt_optimizer.step()

        # Report performance of the jointly updated model and prompts
        return self._evaluate(val_data)
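
In practice the two optimizers usually carry different learning rates, since soft prompts tolerate much larger steps than pretrained weights. A minimal PyTorch setup might look like the sketch below; the toy model, prompt length, and learning rates are illustrative assumptions, not values prescribed by DSPy.

import torch

# Toy stand-in for a pretrained model (hidden size 768 is an assumption)
model = torch.nn.Linear(768, 768)

# Trainable soft-prompt embeddings for 20 virtual tokens (illustrative)
soft_prompt = torch.nn.Parameter(torch.randn(20, 768) * 0.02)

# Pretrained weights get a small learning rate; the soft prompt tolerates a larger one
param_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
prompt_optimizer = torch.optim.AdamW([soft_prompt], lr=1e-3)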

This is Experimental

It is important to note that full joint optimization (simultaneous updates) is an advanced and often experimental technique. In many practical DSPy workflows, alternating, coordinate-descent-style optimization (for example, using DSPy's COPRO for the prompt phase) is more stable and easier to implement.

DSPy Joint Optimizer Example

import dspy

class DSPyJointOptimizer(dspy.Module):
    def __init__(self, base_model, coordinator):
        super().__init__()
        self.base_model = base_model    # weights to fine-tune (θ)
        self.coordinator = coordinator  # drives the alternating/joint schedule

    def optimize(self, trainset, valset, metric=None):
        # Bundle everything the coordinator needs into a single optimization state
        state = OptimizationState(
            model=self.base_model,
            prompts=self._initialize_prompts(),
            trainset=trainset
        )

        # Run the coordinated optimization loop, selecting the best state on valset
        best_state = self.coordinator.optimize(state, valset=valset, metric=metric)
        return best_state.model, best_state.prompts
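
A usage sketch follows, assuming a coordinator object that implements the alternating schedule. AlternatingCoordinator and its arguments are placeholders for the chapter's framework, not a DSPy built-in.

# Hypothetical wiring: AlternatingCoordinator stands in for whichever
# coordinator class the framework provides.
coordinator = AlternatingCoordinator(num_rounds=5)
joint_opt = DSPyJointOptimizer(base_model=student_lm, coordinator=coordinator)
tuned_model, tuned_prompts = joint_opt.optimize(trainset, valset, metric=my_metric)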

Advanced Techniques

  • Curriculum Joint Optimization: Gradually increasing the difficulty of the training data, or how much of the weights and prompts is allowed to change, as optimization progresses (see the sketch after this list).
  • Meta-Learning: Using meta-learning to find initializations for both prompts and weights that adapt quickly to new tasks.
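
As a concrete illustration of the curriculum idea, the sketch below grows the training pool from easy to hard examples and widens what is allowed to change from prompts-only to full weights. The thresholds, the difficulty callable, and the unfreeze labels are assumptions made for illustration.

def curriculum_schedule(epoch, num_epochs, train_data, difficulty):
    # Start with the easiest ~30% of examples and grow the pool each epoch
    fraction = min(1.0, 0.3 + 0.7 * epoch / max(1, num_epochs - 1))
    ranked = sorted(train_data, key=difficulty)
    subset = ranked[: int(len(ranked) * fraction)]

    # Unfreeze more of the system as training progresses: prompts only at first,
    # then the top layers, then everything (thresholds are illustrative)
    if epoch < num_epochs // 3:
        unfreeze = "prompts_only"
    elif epoch < 2 * num_epochs // 3:
        unfreeze = "prompts_and_top_layers"
    else:
        unfreeze = "all"
    return subset, unfreeze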