Piecewise Interpolation in TensorFlow

January 29, 2022

NumPy's interp is a handy function for generating an array from a piecewise linear mapping defined by a set of control points. For example, here is a linear plot of today's U.S. Treasury yield curve:

import typing

import hypothesis
import numpy as np
import scipy.signal
import tensorflow as tf
from hypothesis.extra import numpy as hypnum
from IPython.display import Audio, display
from matplotlib import pyplot as plt
from tabulate import tabulate

π = np.pi
xs = [1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360]
ys = [0.05, 0.09, 0.19, 0.39, 0.58, 0.99, 1.25, 1.53, 1.69, 1.75, 2.15, 2.10]
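x = np.linspace(min(xs), max(xs), 100)
y = np.interp(x, xs, ys)

# plot the control points and the interpolated curve
plt.plot(xs, ys, "o", label="control points")
plt.plot(x, y, label="np.interp")
plt.xlabel("maturity (months)")
plt.ylabel("yield (%)")
plt.legend()
plt.show()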

This function can be useful in ML applications for thresholding ROC curves, decaying learning rates, and other piecewise linear calculations. Unfortunately, there is currently no vectorized implementation of np.interp in TensorFlow or PyTorch. However, a piecewise interpolator is straightforward to build from existing primitives.

Vectorized Interpolation

For eager-mode TensorFlow/PyTorch applications, it is easy enough to call np.interp directly, and NumPy will do the right thing with tensor inputs. However, these calls cannot be made in graph contexts, such as tf.function or TorchScript. It is also possible to use tf.numpy_function to wrap np.interp, but doing so limits parallelism, since the TensorFlow runtime must call back into Python.
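For comparison, here is a minimal sketch of the tf.numpy_function route mentioned above. The control points XS and YS are illustrative names chosen here. The wrapper runs under tf.function, but every execution calls back into the Python interpreter. The wrapped Python function is also not serialized with the graph, so it should not be used in a model that must be saved and restored in another environment.

```python
import numpy as np
import tensorflow as tf

# illustrative control points (not from the post)
XS = np.array([0.0, 1.0, 2.0])
YS = np.array([0.0, 10.0, 0.0])


@tf.function
def interp_via_numpy(x):
    # the lambda executes in Python each time the graph runs; np.interp returns float64
    y = tf.numpy_function(lambda v: np.interp(v, XS, YS), [x], tf.float64)
    # numpy_function drops static shape information, so restore it from the input
    y.set_shape(x.shape)
    return y


print(interp_via_numpy(tf.constant([0.5, 1.5, 3.0], tf.float64)).numpy())
```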

Instead, the following function can be used as a fully-vectorized analog of np.interp for TensorFlow that is compatible with graph mode.

def tf_interp(x: typing.Any, xs: typing.Any, ys: typing.Any) -> tf.Tensor:
    # determine the output data type
    ys = tf.convert_to_tensor(ys)
    dtype = ys.dtype
    # normalize data types
    ys = tf.cast(ys, tf.float64)
    xs = tf.cast(xs, tf.float64)
    x = tf.cast(x, tf.float64)

    # pad control points for extrapolation
    xs = tf.concat([[xs.dtype.min], xs, [xs.dtype.max]], axis=0)
    ys = tf.concat([ys[:1], ys, ys[-1:]], axis=0)

    # compute slopes, pad at the edges to flatten
    ms = (ys[1:] - ys[:-1]) / (xs[1:] - xs[:-1])
    ms = tf.pad(ms[:-1], [(1, 1)])

    # solve for intercepts
    bs = ys - ms*xs

    # search for the line parameters at each input data point
    # create a grid of the inputs and piece breakpoints for thresholding
    # rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors
    i = tf.math.argmax(xs[..., tf.newaxis, :] > x[..., tf.newaxis], axis=-1)
    m = tf.gather(ms, i, axis=-1)
    b = tf.gather(bs, i, axis=-1)

    # apply the linear mapping at each input data point
    y = m*x + b
    return tf.cast(tf.reshape(y, tf.shape(x)), dtype)

This function first converts all inputs to float64 (like NumPy), and then pads the control points to cover the whole domain of the mapping. The slope/intercept values of the linear parts are then solved directly, using the familiar y = mx + b equation.
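For reference, the quantities computed in tf_interp can be written per segment. With the (padded) control points sorted so that $x_{k-1} < x_k$, the slope and intercept of the segment ending at $x_k$, and the interpolated value for an input $x$ on that segment, are:

$$
m_k = \frac{y_k - y_{k-1}}{x_k - x_{k-1}}, \qquad b_k = y_k - m_k x_k, \qquad \hat{y}(x) = m_k x + b_k \quad \text{for } x_{k-1} \le x < x_k
$$

The padded segments outside the original control points get $m_k = 0$. As a result, $\hat{y}$ is held at the first and last $y$ values there, as np.interp does.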

At this point, the implementation diverges from the NumPy version. For each input x, NumPy locates the corresponding line equation using a binary search. In the vectorized implementation, we construct the full grid of [xs, x] values and locate the index of the first value in xs that is greater than the input x using the argmax function. After that, we simply apply the linear equation at that index to compute the result.

To see how this works, consider looking up the inputs x = [5, 65] in the Treasury yield example from above. The comparison xs[..., tf.newaxis, :] > x[..., tf.newaxis] broadcasts the two vectors into the following Boolean matrix, with one row per input and one column per control point:

x \ xs    1    2    3    6   12   24   36   60   84  120  240  360
     5    0    0    0    1    1    1    1    1    1    1    1    1
    65    0    0    0    0    0    0    0    0    1    1    1    1

By convention, argmax returns the index of the first occurrence when the maximum value appears more than once. Each row therefore yields the position of its first 1: an index of 3 for an input of 5, and 8 for 65. These indices are then used to select the appropriate slope and intercept for interpolation.
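The lookup can be reproduced in NumPy to check the indices quoted above. This sketch uses NumPy broadcasting in place of the TensorFlow ops. The finfo padding at the end only mimics what tf_interp does with xs.dtype.min and xs.dtype.max.

```python
import numpy as np

# control points (maturities in months) from the yield curve example
xs = np.array([1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360], dtype=np.float64)
x = np.array([5, 65], dtype=np.float64)

# Boolean grid: entry [r, c] is True when xs[c] > x[r]
grid = xs[np.newaxis, :] > x[:, np.newaxis]
print(grid.astype(int))

# argmax returns the first True in each row
idx = np.argmax(grid, axis=-1)
print(idx)  # [3 8]

# with the extrapolation padding used by tf_interp, every index shifts up by one
padded = np.concatenate([[np.finfo(np.float64).min], xs, [np.finfo(np.float64).max]])
padded_idx = np.argmax(padded[np.newaxis, :] > x[:, np.newaxis], axis=-1)
print(padded_idx)  # [4 9]
```

The one-position shift is why tf_interp pads the slope vector on the left before gathering from it.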

The tradeoff in this approach is that the lookup materializes an intermediate tensor with size(x) * size(xs) elements. NumPy's binary search, by contrast, needs only O(log size(xs)) comparisons per input. The intermediate tensor can become large when there are many points to interpolate or many control points. However, all of the comparisons are independent, so a GPU can evaluate them in parallel. This can be much faster than NumPy when there are many points to interpolate.
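If the size(x) * size(xs) grid is a concern, TensorFlow's tf.searchsorted can perform a per-input binary search instead. The sketch below uses that alternative technique; it is not the implementation above. It assumes xs is sorted, unique, and has at least two points. The name tf_interp_searchsorted is hypothetical.

```python
import numpy as np
import tensorflow as tf


def tf_interp_searchsorted(x, xs, ys) -> tf.Tensor:
    """Hypothetical binary-search variant; assumes sorted, unique xs with len >= 2."""
    ys = tf.convert_to_tensor(ys)
    dtype = ys.dtype
    ys = tf.cast(ys, tf.float64)
    xs = tf.cast(xs, tf.float64)
    x = tf.cast(x, tf.float64)
    flat = tf.reshape(x, [-1])

    # binary search: index of the first control point strictly greater than each input
    i = tf.searchsorted(xs, flat, side="right")

    # clamp to a valid segment [lo, hi] so out-of-range inputs use the end segments
    n = tf.shape(xs)[0]
    hi = tf.clip_by_value(i, 1, n - 1)
    lo = hi - 1
    x0, x1 = tf.gather(xs, lo), tf.gather(xs, hi)
    y0, y1 = tf.gather(ys, lo), tf.gather(ys, hi)

    # linear blend within the segment; clipping t holds the end values like np.interp
    t = tf.clip_by_value((flat - x0) / (x1 - x0), 0.0, 1.0)
    y = y0 + t * (y1 - y0)
    return tf.cast(tf.reshape(y, tf.shape(x)), dtype)


xs = np.array([0.0, 1.0, 3.0])
ys = np.array([0.0, 2.0, -2.0])
x = np.array([-1.0, 0.0, 0.5, 2.0, 3.0, 4.0])
print(tf_interp_searchsorted(x, xs, ys).numpy())
print(np.interp(x, xs, ys))
```

Here memory grows with size(x) rather than size(x) * size(xs), at the cost of a binary search per input. Which version is faster will depend on the array sizes and the device.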


How do we know it works? Since we have a baseline implementation in NumPy, we can simply compute our outputs for the same inputs and compare. The Hypothesis property testing framework lets us do this without hard-coding any example arrays.

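@hypothesis.given(
    x=hypnum.arrays(np.float64, hypnum.array_shapes(),
        elements=dict(min_value=-1, max_value=1)),
    xs=hypnum.arrays(np.float64, hypnum.array_shapes(max_dims=1), unique=True,
        elements=dict(min_value=-1, max_value=1)),
    ys=hypnum.arrays(np.float64, hypnum.array_shapes(max_dims=1),
        elements=dict(min_value=-1, max_value=1)),
)
def test_interp(x, xs, ys):
    xs, ys = xs[:len(ys)], ys[:len(xs)]
    xs = np.sort(xs)
    expect = np.interp(x, xs, ys).astype(ys.dtype)
    actual = tf_interp(x, xs, ys).numpy()
    assert np.allclose(expect, actual), f"expect={expect} actual={actual}"

test_interp()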


The test truncates xs and ys to a common length, sorts the interpolation domain (xs), and ensures it contains unique values, which is an assumption made by all of these interp functions. We also restrict all values to the range [-1, 1] to avoid rounding/overflow inconsistencies between the frameworks.

Application: Piecewise Learning Rate Schedule

As an example, we'll implement a piecewise linear learning rate decay for TensorFlow. This learning rate scheduler can be used with any tf.keras optimizer. It is configured with a sorted list of training steps and corresponding learning rates. The learning rate returned to the optimizer will be interpolated based on this schedule.

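class PiecewiseLinearSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, steps: typing.List[int], rates: typing.List[float]):
        super().__init__()
        if steps != sorted(steps):
            raise ValueError("steps should be listed in ascending order")
        self._steps = steps
        self._rates = rates

    def __call__(self, step: tf.Tensor) -> tf.Tensor:
        # interpolate the learning rate at the current step (graph-compatible)
        return tf_interp(step, self._steps, self._rates)

    def get_config(self) -> dict:
        # lets Keras serialize the schedule along with the optimizer
        return dict(steps=self._steps, rates=self._rates)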

learning_rate = PiecewiseLinearSchedule([100, 500, 1000], [3e-4, 5e-5, 1e-5])
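Because tf_interp is meant to agree with np.interp, the schedule above can be previewed in NumPy without building a model. This is only a stand-in for calling learning_rate(step) directly. The probe steps are arbitrary values chosen for illustration.

```python
import numpy as np

steps = [100, 500, 1000]
rates = [3e-4, 5e-5, 1e-5]

# np.interp stands in for tf_interp here
probe = np.array([0, 100, 300, 500, 750, 1000, 2000])
print(np.interp(probe, steps, rates))
```

The rate holds at 3e-4 until step 100. It then falls linearly to 5e-5 at step 500 and 1e-5 at step 1000, and stays at 1e-5 after that.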

As an example, we'll train a simple DNN logistic regression model on the XOR function, a classic "Hello World" problem for neural networks. Note that the training step function train is decorated with tf.function for graph-mode training. The learning rate schedule is therefore evaluated inside the graph, where np.interp cannot be called but tf_interp can.

model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=256, activation="relu"),
    tf.keras.layers.Dense(units=1),  # a single logit; the loss applies the sigmoid
])
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train(inputs: tf.Tensor, targets: tf.Tensor) -> None:
    with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        loss = criterion(targets, outputs)
    optimizer.minimize(loss, model.trainable_variables, tape=tape)

# fit on random batches of binary inputs and XOR targets
for _ in range(1000):
    x = tf.random.uniform([32, 2]) < 0.5
    y = x[:, :1] ^ x[:, 1:]
    train(x, y)

We can then evaluate the trained model on the truth table to verify it has learned XOR.

x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = tf.nn.sigmoid(model(x)).numpy()[..., 0]

for (a, b), c in zip(x, y):
    print(f"{a:^} ⊕ {b:^} = {c > 0.5:^}")
0 ⊕ 0 = 0
0 ⊕ 1 = 1
1 ⊕ 0 = 1
1 ⊕ 1 = 0

Application: ADSR Envelope Generator

Piecewise linear functions are also commonly used to model ADSR (attack, decay, sustain, release) envelopes for music synthesis. For example, an envelope can be used to control the timbre of an instrument as keys are pressed, held, and released, by shaping the amplitude of the audio signal.

The attack, decay, and release parameters specify how long the envelope takes to transition between stages. Accumulated with a running sum, these durations become the xs control points of the interpolator. The sustain parameter indicates the gain that the envelope should decay to and hold while the instrument's key remains pressed.

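The envelope generator below demonstrates how the ADSR parameters map to piecewise interpolator control points. The envelope begins at 0 gain, rises through the attack stage to full gain, and decays to the sustain level. It holds there until the release stage, then falls back to 0. The hold time is whatever remains of the input's duration after the attack, decay, and release stages. As a result, the release ends exactly at the last timestamp whenever the note is long enough. All time values are represented in seconds.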

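def adsr(attack, decay, sustain, release):
    def gen(time):
        # hold at the sustain level for whatever time the other stages leave over
        hold = max(time[-1] - (attack + decay + release), 0)
        # stage durations accumulate into the control point times
        xs = np.cumsum([0, attack, decay, hold, release])
        # gain at each control point: silent, peak, sustain, sustain, silent
        ys = [0, 1, sustain, sustain, 0]
        return tf_interp(time, xs, ys)
    return gen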

envelope = adsr(attack=0.1, decay=0.05, sustain=0.6, release=0.15)
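To make the mapping concrete, the sketch below rebuilds the same control points for a single 0.5 second note. It evaluates them with np.interp as a stand-in for tf_interp. The adsr_points helper is hypothetical; it simply repeats the arithmetic inside gen().

```python
import numpy as np


def adsr_points(attack, decay, sustain, release, duration):
    # hypothetical helper mirroring the control points built inside adsr()
    hold = max(duration - (attack + decay + release), 0)
    xs = np.cumsum([0, attack, decay, hold, release])
    ys = [0, 1, sustain, sustain, 0]
    return xs, ys


xs, ys = adsr_points(attack=0.1, decay=0.05, sustain=0.6, release=0.15, duration=0.5)
print(xs)  # approximately [0, 0.1, 0.15, 0.35, 0.5]

# one probe time in the middle of each stage
t = np.array([0.05, 0.125, 0.25, 0.425])
print(np.interp(t, xs, ys))  # approximately [0.5, 0.8, 0.6, 0.3]
```

With these settings, the attack ends at 0.1 s and the decay at 0.15 s. The sustain holds until 0.35 s, and the release reaches silence exactly at the end of the note.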

We can now use this envelope to shape some audio samples. To synthesize a note, we first generate a triangle wave (using scipy.signal.sawtooth with a width of 0.5) of the desired frequency and duration. Then we just multiply the signal by the envelope to modulate its amplitude.

# audio sample rate
FS = 44100

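def note(frequency, duration, envelope):
    # sample times for the note
    t = np.linspace(0, duration, int(duration*FS))
    # envelope gain at each sample, converted to a NumPy array
    e = np.asarray(envelope(t))
    # triangle wave: a sawtooth that rises and falls over equal halves of each period
    y = scipy.signal.sawtooth(2*π*frequency*t, width=0.5)
    return e * y

def rest(duration):
    # silence
    return np.zeros(int(duration*FS))

Audio(note(261.63, 0.5, envelope), rate=FS)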
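The width parameter of scipy.signal.sawtooth sets where in each period the peak occurs. With width=0.5, the wave rises and falls over equal halves of the period. A quick check at quarter-period phases, chosen for illustration:

```python
import numpy as np
import scipy.signal

# phases at 0, 1/4, 1/2, and 3/4 of one period
phase = np.array([0.0, 0.5, 1.0, 1.5]) * np.pi

# width=0.5: triangle wave, values -1, 0, 1, 0
print(scipy.signal.sawtooth(phase, width=0.5))
# default width=1: rising ramp, values -1, -0.5, 0, 0.5
print(scipy.signal.sawtooth(phase))
```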

Now we can play a simple melody by concatenating notes/rests into a vector of audio samples. In the non-enveloped clip below, you can hear the clicks between notes where the signal boundaries are discontinuous. The attack and release stages fade smoothly between notes in the enveloped clip.

# quarter note duration
QUARTER = 0.5  # seconds (illustrative); values >= 0.3 s let each release stage finish

# note frequencies
REST = 0
G3 = 196.0
A3 = 220.0
B3 = 246.94
C4 = 261.63
D4 = 293.66
E4 = 329.63
F4 = 349.23
G4 = 392.00
A4 = 440.00
B4 = 493.88
C5 = 523.25

# a simple tune, as (frequency, duration in quarter notes) pairs
NOTES = [
    (A3, 1), (C4, 1), (E4, 1), (REST, 1),
    (D4, 1), (C4, 1), (D4, 1), (E4, 1), (REST, 1),
    (E4, 1), (D4, 1), (C4, 1), (B3, 1), (C4, 2), (A3, 1), (REST, 1),
    (E4, 1), (D4, 1), (C4, 1), (B3, 1), (C4, 2), (A3, 1), (REST, 1),
    (G3, 1), (A3, 1), (C4, 1),
]

def melody(notes, envelope):
    # render each (frequency, beats) pair as a note or a rest, then concatenate
    ys = [note(f, d*QUARTER, envelope) if f != REST
          else rest(d*QUARTER)
          for f, d in notes]
    return np.hstack(ys)

# a no-op rectangular envelope
y = melody(NOTES, lambda _: 1)
display(Audio(y, rate=FS))

# ADSR envelope from above
y = melody(NOTES, envelope)
display(Audio(y, rate=FS))
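The clicks in the first clip come from amplitude jumps at note boundaries. The minimal NumPy check below assumes the same 0.5 second middle C note and the ADSR control points for that duration, with np.interp standing in for tf_interp. It compares the first and last samples with and without the envelope.

```python
import numpy as np
import scipy.signal

FS = 44100  # audio sample rate, as above

t = np.linspace(0, 0.5, int(0.5 * FS))
wave = scipy.signal.sawtooth(2 * np.pi * 261.63 * t, width=0.5)

# ADSR control points for a 0.5 s note: attack 0.1, decay 0.05, sustain 0.6, release 0.15
env = np.interp(t, [0.0, 0.1, 0.15, 0.35, 0.5], [0.0, 1.0, 0.6, 0.6, 0.0])
shaped = env * wave

# the raw note starts at -1 and stops mid-cycle, so a rest or the next note jumps
print("raw:", wave[0], wave[-1])
# the enveloped note starts and ends at zero gain
print("shaped:", shaped[0], shaped[-1])
```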

If you have any feedback or comments, please reply to this post on Twitter.