
Dynamic Programming Patterns

Dynamic programming (DP) is one of the largest topics in competitive programming. Beyond the basics, a handful of recurring patterns each unlock an entire category of problems. This lesson covers the patterns that appear in intermediate and advanced contests.

Interval DP

State: dp[i][j] = optimal answer for the subproblem on the interval [i, j].

Classic problem — Matrix Chain Multiplication:

Given matrices M₁, M₂, ..., Mₙ, find the order of multiplication minimising total scalar multiplications.

def matrix_chain(dims):
    # matrix i (0-indexed) has shape dims[i] × dims[i+1]
    n = len(dims) - 1
    dp = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):          # interval length
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):           # split point
                cost = dp[i][k] + dp[k+1][j] + dims[i]*dims[k+1]*dims[j+1]
                dp[i][j] = min(dp[i][j], cost)

    return dp[0][n-1]
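The function above returns only the minimum cost, while the problem asks for the order. A minimal sketch of recovering it, assuming we record the best split point k of every interval in an extra table. The names `matrix_chain_order` and `split` and the `M1 x M2` output format are illustrative choices, not part of the original:

```python
def matrix_chain_order(dims):
    """Return (min cost, parenthesisation) for matrices of shape dims[i] x dims[i+1]."""
    n = len(dims) - 1
    dp = [[0] * n for _ in range(n)]
    split = [[0] * n for _ in range(n)]   # split[i][j] = best k for interval [i, j]

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):
                cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if cost < dp[i][j]:
                    dp[i][j] = cost
                    split[i][j] = k

    def build(i, j):
        # Rebuild the bracketing of matrices i..j from the stored split points
        if i == j:
            return f"M{i + 1}"
        k = split[i][j]
        return f"({build(i, k)} x {build(k + 1, j)})"

    return dp[0][n - 1], build(0, n - 1)


print(matrix_chain_order([10, 30, 5, 60]))   # (4500, '((M1 x M2) x M3)')
```

For shapes 10×30, 30×5 and 5×60, multiplying the first pair first costs 1500 + 3000 = 4500, versus 9000 + 18000 = 27000 the other way.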

Template:

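for length in range(2, n + 1):          # interval length
    for i in range(n - length + 1):
        j = i + length - 1
        # min or max depending on the problem; combine() stands in for the
        # problem-specific merge of the two halves (plus any split cost)
        dp[i][j] = min(combine(dp[i][k], dp[k + 1][j], i, k, j)
                       for k in range(i, j))   # split point k in [i, j-1]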

Other interval DP problems: Burst Balloons, Optimal BST, Stone Merging, Zuma game.
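As a further illustration of the template, here is a sketch of Stone Merging, assuming the common formulation: adjacent piles are merged two at a time, each merge costs the size of the new pile, and we want the minimum total cost. Merging [i, j] costs the same for every split point, so a prefix-sum array (an implementation choice of this sketch) keeps each transition O(1). The function name `stone_merging` is illustrative.

```python
def stone_merging(piles):
    """Minimum total cost to merge adjacent piles into one; a merge costs the new pile's size."""
    n = len(piles)
    if n == 0:
        return 0
    prefix = [0] * (n + 1)            # prefix[i] = sum of piles[0..i-1]
    for i, p in enumerate(piles):
        prefix[i + 1] = prefix[i] + p

    dp = [[0] * n for _ in range(n)]  # dp[i][j] = min cost to merge piles i..j
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            merge_cost = prefix[j + 1] - prefix[i]   # same for every split point
            dp[i][j] = min(dp[i][k] + dp[k + 1][j] for k in range(i, j)) + merge_cost
    return dp[0][n - 1]


print(stone_merging([4, 1, 1, 4]))   # 18
```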

Bitmask DP

State: dp[mask] where mask is a bitmask representing a subset of elements.

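There are 2ⁿ masks, so a DP doing O(n) work per mask runs in O(2ⁿ × n); adding a "last element" index to the state, as in TSP, gives O(2ⁿ × n²). Use when n ≤ 20 (sometimes up to 25 with pruning).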

Classic — Travelling Salesman Problem (TSP):

Minimum cost Hamiltonian cycle on n cities.

def tsp(dist, n):
    INF = float('inf')
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0   # start at city 0, visited = {0} = bitmask 1

    for mask in range(1, 1 << n):
        for u in range(n):
            if not (mask >> u & 1) or dp[mask][u] == INF:
                continue
            for v in range(n):
                if mask >> v & 1:
                    continue   # already visited
                new_mask = mask | (1 << v)
                dp[new_mask][v] = min(dp[new_mask][v], dp[mask][u] + dist[u][v])

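    full = (1 << n) - 1
    if n == 1:
        return 0   # a single city: the tour has no edges
    # close the cycle by returning to city 0
    return min(dp[full][u] + dist[u][0] for u in range(1, n))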

Enumerating submasks of a mask:

submask = mask
while submask > 0:
    # process submask
    submask = (submask - 1) & mask   # next submask

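This visits every non-empty submask of mask in decreasing order, in O(2^popcount(mask)) time; the empty submask 0 is not visited, so handle it separately if needed. Summed over all masks of n bits the total is O(3ⁿ), because each element is independently outside mask, in mask but not in submask, or in submask.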
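A typical O(3ⁿ) use of this loop is dp[mask] = min over non-empty submasks sub of (dp[mask ^ sub] + cost(sub)). The sketch below applies it to a hypothetical packing problem: split items with the given weights into the fewest groups whose total weight is at most `cap`. Both the problem and the name `min_groups` are illustrative assumptions.

```python
def min_groups(weights, cap):
    """Fewest groups whose weights each sum to <= cap (inf if impossible). O(3^n)."""
    n = len(weights)
    size = 1 << n
    # total[mask] = sum of weights in mask, built by removing the lowest set bit
    total = [0] * size
    for mask in range(1, size):
        low = mask & -mask
        total[mask] = total[mask ^ low] + weights[low.bit_length() - 1]

    INF = float('inf')
    dp = [INF] * size
    dp[0] = 0
    for mask in range(1, size):
        sub = mask
        while sub > 0:                   # every non-empty submask of mask
            if total[sub] <= cap:        # sub fits in one group
                dp[mask] = min(dp[mask], dp[mask ^ sub] + 1)
            sub = (sub - 1) & mask
    return dp[size - 1]


print(min_groups([1, 2, 3, 4], 5))   # 2, e.g. {1, 4} and {2, 3}
```

Masks are processed in increasing order, and mask ^ sub is always smaller than mask, so every dp value read is already final.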

Other bitmask DP problems: minimum cost perfect matching, covering all elements with minimum cost subsets, assignment problems.
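For the assignment problem (n workers, n jobs, cost[w][j] for giving job j to worker w), a one-dimensional dp[mask] is enough if we assume workers are assigned in index order: the number of set bits in mask then tells us which worker is next. A minimal O(2ⁿ × n) sketch, with `min_assignment` as an illustrative name:

```python
def min_assignment(cost):
    """Min total cost of giving each of n workers a distinct job; cost[w][j]. O(2^n * n)."""
    n = len(cost)
    INF = float('inf')
    # dp[mask] = best cost with the jobs in mask given to workers 0..popcount(mask)-1
    dp = [INF] * (1 << n)
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == INF:
            continue
        w = bin(mask).count("1")   # the next worker to assign
        if w == n:
            continue
        for j in range(n):
            if not (mask >> j) & 1:   # job j is still free
                new_mask = mask | (1 << j)
                dp[new_mask] = min(dp[new_mask], dp[mask] + cost[w][j])
    return dp[(1 << n) - 1]


print(min_assignment([[9, 2, 7, 8], [6, 4, 3, 7], [5, 8, 1, 8], [7, 6, 9, 4]]))   # 13
```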

Digit DP

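Count the integers in [0, N] that satisfy some property by building the number digit by digit, most significant first. Shorter numbers are treated as if padded with leading zeros, so the DP also counts 0 itself.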

State: dp[pos][tight][...extra state...]

  • pos: current digit position
  • tight: whether the digits chosen so far equal N's prefix; if so, the current digit cannot exceed N's digit at this position
  • Extra state: whatever property we're tracking (digit sum, last digit, count of zeros, etc.)

Template:

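from functools import lru_cache

def digit_dp(N, start, step, accept):
    # Generic template. start, step(state, d) and accept(state) are hypothetical
    # placeholders for the problem-specific extra state (which must be hashable).
    digits = list(map(int, str(N)))
    n = len(digits)

    @lru_cache(maxsize=None)
    def dp(pos, tight, state):
        if pos == n:
            return 1 if accept(state) else 0
        limit = digits[pos] if tight else 9
        result = 0
        for d in range(limit + 1):
            new_tight = tight and (d == limit)
            result += dp(pos + 1, new_tight, step(state, d))
        return result

    return dp(0, True, start)   # counts integers in [0, N]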

Example — count numbers in [1, N] whose digit sum is divisible by k:

from functools import lru_cache

def count_digit_sum_div_k(N, k):
    digits = list(map(int, str(N)))
    n = len(digits)

    @lru_cache(maxsize=None)
    def dp(pos, tight, remainder):
        if pos == n:
            return 1 if remainder == 0 else 0
        limit = digits[pos] if tight else 9
        result = 0
        for d in range(limit + 1):
            result += dp(pos + 1, tight and d == limit, (remainder + d) % k)
        return result

    # the DP ranges over [0, N]; subtract 1 to exclude 0 (digit sum 0 is divisible by k)
    return dp(0, True, 0) - 1
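Some properties need a "started" flag in the extra state so that leading zeros are not treated as real digits. A sketch under that assumption, counting integers in [1, N] with no two adjacent equal digits (a hypothetical example; `started` and `last` play the role of the extra state from the bullet list above):

```python
from functools import lru_cache


def count_no_adjacent_equal(N):
    """Count integers in [1, N] with no two adjacent equal digits."""
    digits = list(map(int, str(N)))
    n = len(digits)

    @lru_cache(maxsize=None)
    def dp(pos, tight, started, last):
        # started: a non-zero digit has been placed (leading zeros are ignored)
        # last: previous real digit, or -1 before the number has started
        if pos == n:
            return 1 if started else 0   # the all-zero prefix is the number 0, excluded
        limit = digits[pos] if tight else 9
        result = 0
        for d in range(limit + 1):
            if started and d == last:
                continue                   # would create two equal adjacent digits
            now_started = started or d != 0
            result += dp(pos + 1, tight and d == limit, now_started,
                         d if now_started else -1)
        return result

    return dp(0, True, False, -1)


print(count_no_adjacent_equal(100))   # 90: excludes 11, 22, ..., 99 and 100
```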

DP with Data Structures

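When each dp[i] takes an optimum over all earlier states, a naive 1D DP spends O(n) per state and O(n²) in total. If that optimum can be answered by a data structure query (binary search, a segment tree), each transition drops to O(log n) and the whole DP to O(n log n).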

Longest Increasing Subsequence (LIS) — O(n log n):

import bisect

def lis(arr):
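    tails = []   # tails[i] = smallest possible tail of an increasing subsequence of length i+1
    for x in arr:
        pos = bisect.bisect_left(tails, x)   # bisect_left: strictly increasing; bisect_right allows equal values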
        if pos == len(tails):
            tails.append(x)
        else:
            tails[pos] = x
    return len(tails)
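`lis` returns only the length. If the subsequence itself is needed, one option (a sketch, not the only approach) is to remember, for each element, the index of the tail it extended and then walk those links backwards. `lis_sequence`, `tails_idx` and `parent` are names introduced here for illustration.

```python
import bisect


def lis_sequence(arr):
    """Return one longest strictly increasing subsequence of arr in O(n log n)."""
    tails = []                 # tails[i] = smallest tail value of an IS of length i+1
    tails_idx = []             # tails_idx[i] = index in arr of that tail
    parent = [-1] * len(arr)   # parent[i] = previous element's index in the IS ending at i
    for i, x in enumerate(arr):
        pos = bisect.bisect_left(tails, x)
        if pos > 0:
            parent[i] = tails_idx[pos - 1]   # x extends the best IS of length pos
        if pos == len(tails):
            tails.append(x)
            tails_idx.append(i)
        else:
            tails[pos] = x
            tails_idx[pos] = i
    # Walk parent links back from the tail of the longest subsequence
    seq = []
    i = tails_idx[-1] if tails_idx else -1
    while i != -1:
        seq.append(arr[i])
        i = parent[i]
    return seq[::-1]


print(lis_sequence([10, 9, 2, 5, 3, 7, 101, 18]))   # [2, 3, 7, 18]
```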

DP with segment tree — range maximum query in transitions:

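When dp[i] = h(i) + max(dp[j]) over j in a contiguous index range (for example j in [i - K, i - 1]), a segment tree over dp answers each maximum in O(log n), so the whole DP runs in O(n log n). If the transition cost also has a part that depends only on j, fold it into the value stored at index j; a cost that couples i and j in another way does not fit this scheme.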

# Segment tree max query
class SegTree:
    def __init__(self, n):
        self.n = n
        self.tree = [-float('inf')] * (2 * n)

    def update(self, i, val):
        i += self.n
        self.tree[i] = val
        while i > 1:
            i //= 2
            self.tree[i] = max(self.tree[2*i], self.tree[2*i+1])

    def query(self, l, r):   # [l, r)
        res = -float('inf')
        l += self.n; r += self.n
        while l < r:
            if l & 1: res = max(res, self.tree[l]); l += 1
            if r & 1: r -= 1; res = max(res, self.tree[r])
            l //= 2; r //= 2
        return res
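A usage sketch, assuming a hypothetical problem: pick a non-empty subsequence of a with maximum sum, where consecutive chosen indices are at most k apart. Here dp[i] = a[i] + max(0, max(dp[j] for j in [i - k, i - 1])), and the window maximum is one range query. The class is repeated so the block runs on its own; `max_sum_gap_k` is an illustrative name.

```python
class SegTree:
    """Iterative segment tree: point updates and range-max queries on [l, r)."""
    def __init__(self, n):
        self.n = n
        self.tree = [-float('inf')] * (2 * n)

    def update(self, i, val):
        i += self.n
        self.tree[i] = val
        while i > 1:
            i //= 2
            self.tree[i] = max(self.tree[2 * i], self.tree[2 * i + 1])

    def query(self, l, r):
        res = -float('inf')
        l += self.n
        r += self.n
        while l < r:
            if l & 1:
                res = max(res, self.tree[l])
                l += 1
            if r & 1:
                r -= 1
                res = max(res, self.tree[r])
            l //= 2
            r //= 2
        return res


def max_sum_gap_k(a, k):
    """Max sum of a non-empty subsequence whose consecutive chosen indices differ by <= k."""
    st = SegTree(len(a))
    best = -float('inf')
    for i, x in enumerate(a):
        prev = st.query(max(0, i - k), i)   # max dp[j] over j in [i - k, i - 1]
        dp_i = x + max(0, prev)             # start a new subsequence if that is better
        st.update(i, dp_i)
        best = max(best, dp_i)
    return best


print(max_sum_gap_k([10, 2, -10, 5, 20], 2))   # 37
```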

Divide and Conquer Optimisation

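When the DP has the layered form (i = number of groups used, j = prefix length):

dp[i][j] = min over k < j of (dp[i-1][k] + cost(k+1, j))

and the optimal split point is monotone in j (opt[i][j] ≤ opt[i][j+1], which typically holds when cost satisfies the quadrangle inequality), each layer can be filled in O(n log n) instead of O(n²): compute the middle j first by scanning its allowed range of k, then recurse on the left half with k ≤ opt[i][mid] and on the right half with k ≥ opt[i][mid]. With m layers the total is O(m × n log n).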

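def solve(lo, hi, opt_lo, opt_hi, prev_dp, curr_dp, cost):
    # Fill curr_dp[lo..hi], knowing each optimal k lies in [opt_lo, opt_hi]
    if lo > hi:
        return
    mid = (lo + hi) // 2
    best_k = opt_lo
    best_val = float('inf')
    # the new group covers k+1..mid, so k stops at mid - 1 (non-empty groups)
    for k in range(opt_lo, min(opt_hi, mid - 1) + 1):
        val = prev_dp[k] + cost(k + 1, mid)
        if val < best_val:
            best_val = val
            best_k = k
    curr_dp[mid] = best_val
    # monotone opt: the left half needs k <= best_k, the right half k >= best_k
    solve(lo, mid - 1, opt_lo, best_k, prev_dp, curr_dp, cost)
    solve(mid + 1, hi, best_k, opt_hi, prev_dp, curr_dp, cost)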
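A driver sketch, assuming an illustrative cost: split a non-negative array into m contiguous non-empty groups, minimising the sum of (group sum)². For non-negative values that cost satisfies the quadrangle inequality, which is what makes the optimal split point monotone here. `solve` is repeated (bounding k by mid - 1 so every group is non-empty) so the block is self-contained; dp arrays are indexed by prefix length, with prev_dp[0] = 0 for the empty prefix. `min_square_partition` is a hypothetical name.

```python
def solve(lo, hi, opt_lo, opt_hi, prev_dp, curr_dp, cost):
    """Fill curr_dp[lo..hi]; the optimal k for each j lies in [opt_lo, opt_hi]."""
    if lo > hi:
        return
    mid = (lo + hi) // 2
    best_k, best_val = opt_lo, float('inf')
    for k in range(opt_lo, min(opt_hi, mid - 1) + 1):   # new group is k+1..mid
        val = prev_dp[k] + cost(k + 1, mid)
        if val < best_val:
            best_val, best_k = val, k
    curr_dp[mid] = best_val
    solve(lo, mid - 1, opt_lo, best_k, prev_dp, curr_dp, cost)
    solve(mid + 1, hi, best_k, opt_hi, prev_dp, curr_dp, cost)


def min_square_partition(a, m):
    """Split non-negative a into m non-empty contiguous groups minimising sum of (group sum)^2."""
    n = len(a)
    prefix = [0] * (n + 1)
    for i, x in enumerate(a):
        prefix[i + 1] = prefix[i] + x

    def cost(l, r):   # elements l..r, 1-indexed and inclusive
        s = prefix[r] - prefix[l - 1]
        return s * s

    INF = float('inf')
    prev_dp = [0] + [INF] * n   # layer 0: only the empty prefix is reachable
    for _ in range(m):
        curr_dp = [INF] * (n + 1)
        solve(1, n, 0, n - 1, prev_dp, curr_dp, cost)
        prev_dp = curr_dp
    return prev_dp[n]


print(min_square_partition([1, 2, 3, 4], 2))   # 52: groups [1, 2, 3] and [4]
```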

Convex Hull Trick (CHT)

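When the DP has the form dp[i] = min over j < i of (dp[j] + b[j] * a[i]), a minimum over linear functions of a[i], CHT cuts the O(n²) transitions to O(n) amortised when lines are added in sorted slope order and the queries a[i] are also sorted, or to O(n log n) when the queries are not sorted (binary search on the hull, or a Li Chao tree if the slopes are not sorted either).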

The key insight: f_j(x) = b[j] * x + dp[j] is a family of lines. For each query x = a[i], we want the minimum over all lines. The lower envelope of a set of lines forms a convex hull — hence the name.

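from collections import deque

class CHT:
    """Lower envelope of lines for min queries (monotone variant).

    Assumes slopes are added in non-increasing order and query points x are
    non-decreasing, so each line is pushed and popped at most once: O(n) amortised.
    """
    def __init__(self):
        self.lines = deque()   # (slope, intercept), slopes strictly decreasing

    def add(self, m, b):
        if self.lines and self.lines[-1][0] == m:
            if self.lines[-1][1] <= b:
                return               # parallel and never lower: useless
            self.lines.pop()         # parallel and always lower: replaces the old line
        while len(self.lines) >= 2:
            (m1, b1), (m2, b2) = self.lines[-2], self.lines[-1]
            # Remove l2 if the new line overtakes l1 no later than l2 does.
            # Cross-multiplied (m1 > m2 > m) to avoid division and float error.
            if (b - b1) * (m1 - m2) <= (b2 - b1) * (m1 - m):
                self.lines.pop()
            else:
                break
        self.lines.append((m, b))

    def query(self, x):
        # Drop front lines that can no longer be optimal for this or any larger x
        while len(self.lines) > 1 and \
              self.eval(self.lines[0], x) >= self.eval(self.lines[1], x):
            self.lines.popleft()
        return self.eval(self.lines[0], x)

    def eval(self, line, x):
        return line[0] * x + line[1]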
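A usage sketch, assuming the hypothetical recurrence dp[i] = min over j < i of (dp[j] + (h[i] - h[j])² + C) with strictly increasing heights h. Expanding the square gives lines with slope -2·h[j] and intercept dp[j] + h[j]², queried at x = h[i], plus h[i]² + C added afterwards. Slopes then strictly decrease and queries increase, which is exactly the monotone case the deque handles. The class is restated with division-free intersection checks so the block runs on its own; `min_jump_cost` is an illustrative name.

```python
from collections import deque


class CHT:
    """Min-query lower envelope; slopes added non-increasing, queries non-decreasing."""
    def __init__(self):
        self.lines = deque()   # (slope, intercept), slopes strictly decreasing

    def add(self, m, b):
        if self.lines and self.lines[-1][0] == m:
            if self.lines[-1][1] <= b:
                return
            self.lines.pop()
        while len(self.lines) >= 2:
            (m1, b1), (m2, b2) = self.lines[-2], self.lines[-1]
            if (b - b1) * (m1 - m2) <= (b2 - b1) * (m1 - m):   # l2 never optimal
                self.lines.pop()
            else:
                break
        self.lines.append((m, b))

    def query(self, x):
        while len(self.lines) > 1 and \
              self.eval(self.lines[0], x) >= self.eval(self.lines[1], x):
            self.lines.popleft()
        return self.eval(self.lines[0], x)

    @staticmethod
    def eval(line, x):
        return line[0] * x + line[1]


def min_jump_cost(h, C):
    """dp[i] = min_{j<i} dp[j] + (h[i] - h[j])^2 + C, for strictly increasing h."""
    dp = [0] * len(h)
    cht = CHT()
    cht.add(-2 * h[0], dp[0] + h[0] ** 2)
    for i in range(1, len(h)):
        dp[i] = cht.query(h[i]) + h[i] ** 2 + C
        cht.add(-2 * h[i], dp[i] + h[i] ** 2)   # make i available to later states
    return dp[-1]


print(min_jump_cost([1, 2, 3, 4, 5], 6))   # 20 (jumps 1 -> 3 -> 5)
```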

Key Takeaways

  • Interval DP: dp[i][j] over subintervals; fill by increasing length; split point k iterates [i, j-1]. O(n³).
  • Bitmask DP: dp[mask] for subset-based problems. O(2ⁿ × n), or O(2ⁿ × n²) with a last-element index as in TSP. Submask enumeration over all masks costs O(3ⁿ).
  • Digit DP: count integers with a property by building digit by digit, tracking tight and extra state. Implement with top-down memoisation.
  • LIS in O(n log n) using binary search on a tails array.
  • Divide and conquer optimisation: O(n log n) per layer instead of O(n²) when the optimal split point is monotone.
  • Convex Hull Trick: O(n) amortised (sorted slopes and queries) or O(n log n) (unsorted queries) for DP of the form dp[i] = min(dp[j] + b[j]*a[i]).