# Piecewise Interpolation in TensorFlow

January 29, 2022

NumPy’s `interp` is a handy function for generating an array from a piecewise linear
mapping defined by a set of control points. For example, here is how to linearly
interpolate today’s U.S. Treasury yield curve, with maturities in months and yields in percent:

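```
import numpy as np

# Treasury maturities (months) and yields (%)
xs = [1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360]
ys = [0.05, 0.09, 0.19, 0.39, 0.58, 0.99, 1.25, 1.53, 1.69, 1.75, 2.15, 2.10]
# 100 evenly spaced maturities between the shortest and the longest
x = np.linspace(min(xs), max(xs), 100)
y = np.interp(x, xs, ys)
```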
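For reference, here is `np.interp` on a few control points. Any vectorized replacement has to reproduce this behavior. The numbers are arbitrary and chosen only for illustration. Inputs between control points are interpolated linearly. Inputs outside the range are clamped to the first or last `ys` value rather than extrapolated.

```python
import numpy as np

# arbitrary illustrative control points
xs = [0.0, 1.0, 3.0]
ys = [1.0, 3.0, 2.0]

# one input left of the range, two inside it, one right of it
y = np.interp([-5.0, 0.5, 2.0, 10.0], xs, ys)
# -5.0 -> 1.0 (clamped to ys[0]); 0.5 -> 2.0; 2.0 -> 2.5; 10.0 -> 2.0 (clamped to ys[-1])
print(y)
```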

`np.interp` can be useful in ML applications for thresholding ROC curves,
decaying learning rates, and other piecewise linear calculations.
Unfortunately, there is currently no vectorized implementation of
`np.interp` in TensorFlow or PyTorch, but it is straightforward to implement
a piecewise interpolator using existing primitives.

### Vectorized Interpolation

For eager-mode TensorFlow/PyTorch applications, it is easy enough to call
`np.interp` directly, since NumPy converts tensor inputs to arrays (for PyTorch,
this requires CPU tensors). However, these calls cannot be made in graph contexts,
such as `tf.function` or TorchScript. It is also possible to use `tf.numpy_function`
to wrap `np.interp`, but doing so limits parallelism, since the TensorFlow runtime
must call back into Python. The wrapped function is also not serialized with the
graph, so it does not travel with a saved model.

Instead, the following function can be used as a fully vectorized analog of
`np.interp` for TensorFlow that is compatible with graph mode.

```
import typing

import tensorflow as tf

def tf_interp(x: typing.Any, xs: typing.Any, ys: typing.Any) -> tf.Tensor:
    # determine the output data type
    ys = tf.convert_to_tensor(ys)
    dtype = ys.dtype
    # normalize data types
    ys = tf.cast(ys, tf.float64)
    xs = tf.cast(xs, tf.float64)
    x = tf.cast(x, tf.float64)
    # pad control points for extrapolation
    xs = tf.concat([[xs.dtype.min], xs, [xs.dtype.max]], axis=0)
    ys = tf.concat([ys[:1], ys, ys[-1:]], axis=0)
    # compute slopes, pad at the edges to flatten
    ms = (ys[1:] - ys[:-1]) / (xs[1:] - xs[:-1])
    ms = tf.pad(ms[:-1], [(1, 1)])
    # solve for intercepts
    bs = ys - ms*xs
    # search for the line parameters at each input data point
    # create a grid of the inputs and piece breakpoints for thresholding
    # rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors
    i = tf.math.argmax(xs[..., tf.newaxis, :] > x[..., tf.newaxis], axis=-1)
    m = tf.gather(ms, i, axis=-1)
    b = tf.gather(bs, i, axis=-1)
    # apply the linear mapping at each input data point
    y = m*x + b
    return tf.cast(tf.reshape(y, tf.shape(x)), dtype)
```

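This function first converts all inputs to `float64` (like NumPy). It then
pads the control points with the most negative and most positive `float64` values,
repeating the first and last `ys`. This makes the mapping cover the whole domain and
stay flat outside the original control points, just as `np.interp` clamps. The
slope/intercept values of the linear parts are then solved directly, using the
familiar `y = mx + b` equation, with zero slopes on the outer pieces.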
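The padding and slope/intercept steps can be traced in plain NumPy on the Treasury control points from above. This minimal sketch mirrors the TensorFlow code line by line, with `np.finfo(np.float64).max` standing in for `xs.dtype.max`. The names `xp`, `yp`, and `big` are illustrative.

```python
import numpy as np

# Treasury control points from the example above
xs = np.array([1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360], dtype=np.float64)
ys = np.array([0.05, 0.09, 0.19, 0.39, 0.58, 0.99, 1.25, 1.53, 1.69, 1.75, 2.15, 2.10])

# pad with the extreme float64 values and repeat the end ys (flat extrapolation)
big = np.finfo(np.float64).max
xp = np.concatenate([[-big], xs, [big]])
yp = np.concatenate([ys[:1], ys, ys[-1:]])

# slope of each piece; drop the last one and zero-pad both ends, so that
# ms[j] is the slope of the piece ending at control point xp[j]
ms = np.pad((np.diff(yp) / np.diff(xp))[:-1], (1, 1))
# intercepts: each line passes through its right-hand control point (xp[j], yp[j])
bs = yp - ms * xp

print(len(xp), len(ms), len(bs))  # 14 14 14
```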

At this point, the implementation diverges from the NumPy version. For each
input `x`, NumPy locates the corresponding line equation using a binary
search. In the vectorized implementation, we instead compare every input `x`
against every control point in `xs`, producing a full Boolean grid. We then use
the `argmax` function to locate the index of the first value in `xs` that is
greater than `x`. After that, we simply apply the linear equation at that index
to compute the result.
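The grid lookup does not depend on TensorFlow itself, so the whole algorithm can be transliterated to NumPy and checked against `np.interp`. This is a minimal sketch. `np_grid_interp` is a hypothetical helper name, and NumPy fancy indexing stands in for `tf.gather`.

```python
import numpy as np

def np_grid_interp(x, xs, ys):
    """Hypothetical NumPy transliteration of tf_interp's grid lookup."""
    x = np.asarray(x, dtype=np.float64)
    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    # sentinel padding and per-piece slopes/intercepts, as in tf_interp
    big = np.finfo(np.float64).max
    xs = np.concatenate([[-big], xs, [big]])
    ys = np.concatenate([ys[:1], ys, ys[-1:]])
    ms = np.pad((np.diff(ys) / np.diff(xs))[:-1], (1, 1))
    bs = ys - ms * xs
    # Boolean grid of shape x.shape + (len(xs),): True where a control point exceeds x
    grid = xs[..., np.newaxis, :] > x[..., np.newaxis]
    # argmax returns the first True, i.e. the first control point above x
    i = np.argmax(grid, axis=-1)
    return ms[i] * x + bs[i]

xs = [1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360]
ys = [0.05, 0.09, 0.19, 0.39, 0.58, 0.99, 1.25, 1.53, 1.69, 1.75, 2.15, 2.10]
# include inputs on both sides of the control-point range
x = np.linspace(-10, 400, 1000)
assert np.allclose(np_grid_interp(x, xs, ys), np.interp(x, xs, ys))
```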

To see how this works, consider a lookup of the `x` vector `[5, 65]` in the
Treasury yield example from above. The `xs[..., tf.newaxis, :] > x[..., tf.newaxis]`
comparison broadcasts the vectors into the following Boolean matrix. There is one row
per input `x` and one column per control point in `xs` (before padding):

| `x` / `xs` | 1 | 2 | 3 | 6 | 12 | 24 | 36 | 60 | 84 | 120 | 240 | 360 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 0 | 0 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 65 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1 |

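By convention, the `argmax` function returns the index of the first match
when the maximum value is duplicated in the array. This gives an index of
3 for an input of 5 and 8 for 65. Inside `tf_interp`, the extra control point
padded onto the front of `xs` shifts these to 4 and 9. That lines up with the
equally padded slope and intercept vectors. These indexes are then used to select
the appropriate linear parameters for interpolation.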
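The matrix and the resulting indexes can be reproduced with a few lines of NumPy. The sketch also computes the indexes after padding `xs` with sentinel values at both ends. Those are the indexes that the padded slope and intercept vectors in `tf_interp` are looked up by.

```python
import numpy as np

# unpadded Treasury maturities (columns) and the two example inputs (rows)
xs = np.array([1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360])
x = np.array([5, 65])

grid = xs[np.newaxis, :] > x[:, np.newaxis]
print(grid.astype(int))           # the Boolean matrix from the table
index = np.argmax(grid, axis=-1)  # first True in each row: [3 8]

# tf_interp pads xs with sentinels at both ends, shifting each index by one
big = np.finfo(np.float64).max
padded = np.concatenate([[-big], xs, [big]])
padded_index = np.argmax(padded[np.newaxis, :] > x[:, np.newaxis], axis=-1)  # [4 9]
```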

The tradeoff in this approach is that it creates an intermediate tensor containing
`size(x) * size(xs)` elements to perform the lookup. NumPy’s binary search, by contrast,
needs only about `log2(size(xs))` comparisons per input. This intermediate tensor can
become large when there are many points to interpolate or many control points.
However, every comparison in the grid is independent, so a GPU can evaluate them in
parallel. This makes the lookup effectively constant-time for moderately sized inputs,
which can be much faster than NumPy when there are many points to interpolate.
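To put the memory cost in numbers, the sketch below counts the grid elements for an illustrative workload. A Boolean tensor uses one byte per element, and the sizes (one million inputs, 1,000 control points) are arbitrary. The sketch also checks that the first-`True` index from the grid equals `np.searchsorted(..., side="right")`, which is a binary search. TensorFlow has an analogous `tf.searchsorted` op with the same `side` argument. It could replace the grid when memory is the bottleneck. That substitution is not part of the original post and is only sketched here in NumPy.

```python
import numpy as np

def grid_bytes(num_inputs: int, num_control_points: int) -> int:
    # Boolean grid built by tf_interp, at one byte per element;
    # +2 for the two sentinel control points added by the padding
    return num_inputs * (num_control_points + 2)

# an arbitrary workload: 1e6 inputs against 1,000 control points
print(grid_bytes(1_000_000, 1_000) / 1e9, "GB")  # 1.002 GB

# padded Treasury control points
big = np.finfo(np.float64).max
xs = np.array([-big, 1, 2, 3, 6, 12, 24, 36, 60, 84, 120, 240, 360, big])
x = np.array([0.5, 5.0, 6.0, 65.0, 400.0])

# first control point strictly greater than each input, via the dense grid...
grid_index = np.argmax(xs[np.newaxis, :] > x[:, np.newaxis], axis=-1)
# ...and via binary search, without materializing the grid
search_index = np.searchsorted(xs, x, side="right")
print(grid_index, search_index)  # both give 1, 4, 5, 9, 13
```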

#### Testing

How do we know it works? Since we have a baseline implementation in NumPy, we can simply compute our outputs for the same inputs and compare. The Hypothesis property testing framework lets us do this without hard-coding any example arrays.

```
import hypothesis
import hypothesis.extra.numpy as hypnum
import numpy as np

# assumes tf_interp from above is in scope
@hypothesis.given(
    x=hypnum.arrays(
        dtype=hypnum.floating_dtypes(),
        shape=hypnum.array_shapes(),
        elements=dict(min_value=-1, max_value=1),
    ),
    xs=hypnum.arrays(
        dtype=hypnum.floating_dtypes(),
        shape=hypnum.array_shapes(max_dims=1),
        unique=True,
        elements=dict(min_value=-1, max_value=1),
    ),
    ys=hypnum.arrays(
        dtype=hypnum.floating_dtypes(),
        shape=hypnum.array_shapes(max_dims=1),
        elements=dict(min_value=-1, max_value=1),
    ),
)
def test_interp(x, xs, ys):
    # pair up the control points by truncating to the shorter array
    xs, ys = xs[:len(ys)], ys[:len(xs)]
    # every interp implementation assumes an increasing domain
    xs = np.sort(xs)
    expect = np.interp(x, xs, ys).astype(ys.dtype)
    actual = tf_interp(x, xs, ys).numpy()
    assert np.allclose(expect, actual), f"expect={expect} actual={actual}"

test_interp()
```

The test draws unique values for the interpolation domain (`xs`) and then sorts
it, since a sorted domain without duplicates is an assumption made by all of these
`interp` functions. We also restrict the ranges of all values to `[-1, 1]`, in order
to avoid rounding/overflow inconsistencies between the frameworks.

### Application: Piecewise Learning Rate Schedule

As an example, we’ll implement a piecewise linear learning rate decay
for TensorFlow. This learning rate schedule can be used with any `tf.keras`
optimizer. It is configured with a sorted list of training steps and
corresponding learning rates. The learning rate returned to the optimizer will
be interpolated based on this schedule.

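```
import typing

import tensorflow as tf

class PiecewiseLinearSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, steps: typing.List[int], rates: typing.List[float]):
        if steps != sorted(steps):
            raise ValueError("steps should be listed in ascending order")
        self._steps = steps
        self._rates = rates

    def __call__(self, step: tf.Tensor) -> tf.Tensor:
        # interpolate the learning rate at the current training step
        return tf_interp(step, self._steps, self._rates)

    def get_config(self) -> dict:
        # lets Keras serialize the schedule along with the optimizer
        return {"steps": self._steps, "rates": self._rates}

learning_rate = PiecewiseLinearSchedule([100, 500, 1000], [3e-4, 5e-5, 1e-5])
```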
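To see what this schedule hands the optimizer, the same control points can be evaluated with `np.interp`, which `tf_interp` is meant to match. The probe steps are arbitrary. The rate is held at 3e-4 for the first 100 steps, so there is no warmup ramp from zero. It then decays linearly to 5e-5 at step 500 and to 1e-5 at step 1000, and stays constant after that.

```python
import numpy as np

# control points of the example schedule above
steps = [100, 500, 1000]
rates = [3e-4, 5e-5, 1e-5]

probe_steps = [0, 100, 300, 750, 1000, 5000]
probe_rates = np.interp(probe_steps, steps, rates)
for step, rate in zip(probe_steps, probe_rates):
    print(f"step {step:>5}: learning rate {rate:.2e}")
```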

For an example usage, we’ll train a small neural network classifier (one hidden
layer, with a logistic output) on the XOR function, a classic “Hello World” problem
for neural networks. Note that the training step function `train` uses the
`tf.function` decorator for graph-mode training. The optimizer evaluates the learning
rate schedule inside that graph, which works because `tf_interp` is built entirely
from TensorFlow ops.

```
import tensorflow as tf

# assumes learning_rate (and tf_interp) from above are in scope
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(units=256, activation="relu"),
        tf.keras.layers.Dense(units=1),
    ]
)
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train(inputs: tf.Tensor, targets: tf.Tensor) -> None:
    with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        loss = criterion(targets, outputs)
    optimizer.minimize(loss, model.trainable_variables, tape=tape)

# fit on random batches of binary inputs and XOR targets
for _ in range(1000):
    x = tf.random.uniform([32, 2]) < 0.5
    y = x[:, :1] ^ x[:, 1:]
    train(x, y)
```

We can then evaluate the trained model on the truth table to verify it has learned XOR.

```
import numpy as np

# evaluate the model on all four rows of the truth table
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = tf.nn.sigmoid(model(x)).numpy()[..., 0]
for (a, b), c in zip(x, y):
    print(f"{a} ⊕ {b} = {int(c > 0.5)}")
```

```
0 ⊕ 0 = 0
0 ⊕ 1 = 1
1 ⊕ 0 = 1
1 ⊕ 1 = 0
```

### Application: ADSR Envelope Generator

Piecewise linear functions are also commonly used to model ADSR envelopes for music synthesis. For example, an envelope can be used to control the timbre of an instrument as keys are pressed, held, and released, by shaping the amplitude of the audio signal.

The **attack**, **decay**, and **release** parameters specify the amount of
time for the envelope to transition between stages. Their cumulative sums give
the `xs` control points of the interpolator. The **sustain** parameter indicates
the gain that the envelope should decay to and hold while the instrument’s
key remains pressed.

The envelope generator below demonstrates how the ADSR parameters correspond to piecewise interpolator control points. The envelope begins at 0 gain and rises through the attack stage to full gain. It then decays to the sustain level and holds there for whatever time remains in the note. Finally, it falls back to 0 in the release stage. All time values are represented in seconds.

```
import numpy as np

def adsr(attack, decay, sustain, release):
    def gen(time):
        # the hold stage absorbs whatever time the other stages leave over
        hold = max(time[-1] - (attack + decay + release), 0)
        # stage durations -> cumulative control-point times
        xs = np.cumsum([0, attack, decay, hold, release])
        ys = [0, 1, sustain, sustain, 0]
        return tf_interp(time, xs, ys)
    return gen

envelope = adsr(attack=0.1, decay=0.05, sustain=0.6, release=0.15)
```
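For a concrete picture of the control points `gen` builds, here is the same computation in NumPy for the envelope above. The sketch assumes a 0.5-second note and arbitrary probe times, both chosen for illustration, and uses `np.interp` in place of `tf_interp`.

```python
import numpy as np

attack, decay, sustain, release = 0.1, 0.05, 0.6, 0.15
duration = 0.5  # assumed note length in seconds

# the hold stage absorbs whatever time the other stages leave over
hold = max(duration - (attack + decay + release), 0)
xs = np.cumsum([0, attack, decay, hold, release])  # about [0, 0.1, 0.15, 0.35, 0.5]
ys = [0, 1, sustain, sustain, 0]

# mid-attack, mid-decay, sustain, and mid-release
t = np.array([0.05, 0.125, 0.25, 0.425])
gains = np.interp(t, xs, ys)  # about [0.5, 0.8, 0.6, 0.3]
print(xs, gains)
```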

We can now use this envelope to shape some audio samples. To synthesize a
note, we first generate a triangle wave (using `scipy.signal.sawtooth` with
a `width` of 0.5) of the desired frequency and duration. Then we just multiply
the signal by the envelope to modulate its amplitude.

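```
import numpy as np
import scipy.signal
from IPython.display import Audio

# audio sample rate
FS = 44100

def note(frequency, duration, envelope):
    t = np.linspace(0, duration, int(duration*FS))
    # envelope gains as a NumPy array (tf_interp returns a tensor)
    e = np.asarray(envelope(t))
    # triangle wave: a sawtooth with a symmetric rise and fall
    y = scipy.signal.sawtooth(2*np.pi*frequency*t, width=0.5)
    return e * y

def rest(duration):
    return np.zeros(int(duration*FS))

Audio(note(261.63, 0.5, envelope), rate=FS)
```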
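Because a triangle wave is itself piecewise linear, it can also be written in closed form with NumPy. This sketch assumes the `scipy.signal.sawtooth(phase, width=0.5)` convention: the wave starts at -1 and peaks at +1 halfway through each period. The helper name `triangle` is illustrative. The hard-coded envelope control points correspond to attack=0.1, decay=0.05, sustain=0.6, and release=0.15 on a 0.5-second note.

```python
import numpy as np

FS = 44100  # audio sample rate, as above

def triangle(phase):
    """Triangle wave: -1 at phase 0, +1 at phase pi, period 2*pi."""
    u = np.mod(phase / (2 * np.pi), 1.0)  # position within the cycle, in [0, 1)
    return 1 - 4 * np.abs(u - 0.5)

# a 0.5 s middle C shaped by ADSR control points for
# attack=0.1, decay=0.05, sustain=0.6, release=0.15
t = np.linspace(0, 0.5, int(0.5 * FS))
e = np.interp(t, [0, 0.1, 0.15, 0.35, 0.5], [0, 1, 0.6, 0.6, 0])
shaped = e * triangle(2 * np.pi * 261.63 * t)
```

Since every envelope gain lies in `[0, 1]`, multiplying by it never pushes the signal outside `[-1, 1]`.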

Now we can play a simple melody by concatenating notes/rests into a vector
of audio samples. In the non-enveloped clip below, you can hear the clicks
between notes where the signal boundaries are discontinuous. The **attack**
and **release** stages fade smoothly between notes in the enveloped clip.

```
import numpy as np
from IPython.display import Audio, display

# quarter note duration (seconds)
QUARTER = 0.3
# note frequencies (Hz)
REST = 0
G3 = 196.0
A3 = 220.0
B3 = 246.94
C4 = 261.63
D4 = 293.66
E4 = 329.63
F4 = 349.23
G4 = 392.00
A4 = 440.00
B4 = 493.88
C5 = 523.25
# a simple tune: (frequency, length in quarter notes)
NOTES = [
    (A3, 1), (C4, 1), (E4, 1), (REST, 1),
    (D4, 1), (C4, 1), (D4, 1), (E4, 1), (REST, 1),
    (E4, 1), (D4, 1), (C4, 1), (B3, 1), (C4, 2), (A3, 1), (REST, 1),
    (E4, 1), (D4, 1), (C4, 1), (B3, 1), (C4, 2), (A3, 1), (REST, 1),
    (G3, 1), (A3, 1), (C4, 1),
]

def melody(notes, envelope):
    # render each note or rest, then join them into one sample vector
    ys = [note(f, d*QUARTER, envelope) if f != REST
          else rest(d*QUARTER)
          for f, d in notes]
    return np.hstack(ys)

# a no-op rectangular envelope
y = melody(NOTES, lambda _: 1)
display(Audio(y, rate=FS))
# ADSR envelope from above
y = melody(NOTES, envelope)
display(Audio(y, rate=FS))
```
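The clicks come from jumps in sample value at note boundaries. The sketch below measures them for a single note followed by silence. It uses a NumPy closed-form triangle wave in place of `scipy.signal.sawtooth`, and `np.interp` in place of the ADSR envelope. The note, its length, and the `triangle` helper are illustrative. A triangle wave starts at -1, so without an envelope the first sample jumps a full unit away from the preceding silence. With the ADSR envelope, the note starts and ends at 0.

```python
import numpy as np

FS = 44100

def triangle(phase):
    # same waveform as scipy.signal.sawtooth(phase, width=0.5)
    u = np.mod(phase / (2 * np.pi), 1.0)
    return 1 - 4 * np.abs(u - 0.5)

# a 0.5 s middle C
t = np.linspace(0, 0.5, int(0.5 * FS))
wave = triangle(2 * np.pi * 261.63 * t)
# ADSR gains for attack=0.1, decay=0.05, sustain=0.6, release=0.15
env = np.interp(t, [0, 0.1, 0.15, 0.35, 0.5], [0, 1, 0.6, 0.6, 0])

rect_note = wave       # rectangular (no-op) envelope
adsr_note = env * wave

# size of the jump from/to silence (0.0) at each end of the note
print("rectangular:", abs(rect_note[0]), abs(rect_note[-1]))
print("ADSR:       ", abs(adsr_note[0]), abs(adsr_note[-1]))
```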

If you have any feedback or comments, please reply to this post on Twitter.