This holiday break gives me time to dig into the music transcription models that we trained and used this year, and to figure out what those ~6,000,000 parameters actually mean. In particular, I’ll be looking at the “version 2” model, which generated all of the 30,000 transcriptions compiled in ten volumes and helped compose a large amount of new music, including such greats as the four movements of Oded Ben-Tal’s fantastic “Bastard Tunes”:

and X:7153, harmonised by DeepBach and played on the St. Dunstan and All Saints church organ by Richard Salmon.

Last time we looked at the internal algorithms of this model (the three LSTM layers and the final softmax layer), and its training. Now it’s time to dig deeper.

The most immediate thing we can do is analyse the parameters of the layers closest to the token vocabulary. Let’s start with the parameters of the softmax output layer, i.e., the 137×512 matrix $\mathbf{V}$ and the 137-dimensional vector $\mathbf{v}$. A linear combination of the columns of $\mathbf{V}$ plus the bias $\mathbf{v}$, passed through the softmax, produces what we treat as a multinomial probability distribution $\mathbf{p}_t$ over the vocabulary. The last LSTM layer gives the weights for this linear combination, i.e., $\mathbf{p}_t = \sigma(\mathbf{V}\mathbf{h}_t + \mathbf{v})$, where $\mathbf{h}_t$ is the 512-dimensional output of the last LSTM layer at step $t$ and $\sigma$ is the softmax. Hence, we can interpret the columns of $\mathbf{V}$ and $\mathbf{v}$ as “shapes” contributing to the probability distribution generated by the model. (The softmax merely compresses this shape, since it preserves the ordering of the values.) Each dimension of a column corresponds to a particular token. Increase the value in that dimension and we increase that token’s probability; decrease the value and we decrease its probability. Below is a plot of all columns of $\mathbf{V}$ (black) and the bias $\mathbf{v}$ (red).
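The forward computation of the output layer can be sketched in a few lines of NumPy. This is a minimal illustration, not the folk-rnn code: `V`, `v` and `h` below are random stand-ins for the trained parameters and the LSTM output, and `softmax` is a hand-written helper. It shows that the logits are a weighted sum of the columns of $\mathbf{V}$ plus the bias, and that raising one dimension raises that token’s probability.

```python
import numpy as np

def softmax(z):
    # Subtracting the max improves numerical stability without changing the result.
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(0)
n_tokens, n_units = 137, 512
# Hypothetical stand-ins for the trained parameters V (137x512) and v (137,).
V = rng.normal(scale=0.1, size=(n_tokens, n_units))
v = rng.normal(scale=0.1, size=n_tokens)
# Hypothetical output of the last LSTM layer at one step (LSTM outputs lie in (-1, 1)).
h = rng.uniform(-1.0, 1.0, size=n_units)

z = V @ h + v   # weighted sum of the columns of V, plus the bias
p = softmax(z)  # multinomial distribution over the 137 tokens

# Raising one dimension of the pre-softmax vector raises that token's probability.
k = 10
z_up = z.copy()
z_up[k] += 1.0
p_up = softmax(z_up)
```

Because the softmax preserves ordering, sorting the tokens by `z` or by `p` gives the same ranking.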

We can see some interesting things. If the input to the softmax layer were all zeros, the resulting shape would look like the red line. We see low values for the tokens “<s>”, “=f'”, “16”, “(7”, “^f'” and “=C,”. These tokens are rarely generated by the model, and in fact are rare in the training data, with the exception of “<s>”, which is always the first token input to the model (the v2 model never generates this token). In the training data of 4,032,490 tokens, we find “(7” appears only once, “=f'” and “^f'” appear only twice each, “16” appears only five times, and “=C,” appears only six times. The largest value of the bias corresponds to the token “C”, which is sensible since we have transposed all 23,000 transcriptions in the training data to modes with the root C. “C” appears 191,424 times in the training data.
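Token frequencies like those quoted above come from counting tokens in the training data. A minimal sketch of such a count, assuming each transcription is stored as a string of whitespace-separated tokens (an assumption about the file format); `count_tokens` is a hypothetical helper and the two-tune corpus is made up for illustration:

```python
from collections import Counter

def count_tokens(transcriptions):
    """Count token occurrences, assuming whitespace-separated tokens."""
    counts = Counter()
    for t in transcriptions:
        counts.update(t.split())
    return counts

# A tiny made-up corpus, not the real training data.
corpus = [
    "M:6/8 K:Cmaj |: G E C C 2 E | G A G c 3 :|",
    "M:4/4 K:Cmin |: c 2 c B G 2 F G | C 4 =C, 2 z 2 :|",
]
counts = count_tokens(corpus)
# The rarest tokens, ties broken alphabetically.
rarest = sorted(counts.items(), key=lambda kv: (kv[1], kv[0]))[:5]
```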

Another interesting thing to observe is how the columns of $\mathbf{V}$ together resemble a landscape viewed over water with a mean of 0. In fact, the wide matrix $\mathbf{V}$ has full rank (137); since it has 512 columns, its null space has dimension $512 - 137 = 375$, and so there are ways to combine the columns to create a zero vector, or even to completely cancel the bias $\mathbf{v}$. With full rank, we know that our model can potentially reach (or get arbitrarily close to) any point on the positive face of the $\ell_1$-unit ball of $\mathbb{R}^{137}$. This would not be possible if the number of output units of the last LSTM layer was smaller than 137.
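The rank and null-space argument can be checked numerically. In this sketch `V` and `v` are random Gaussian stand-ins (such a matrix has full rank with probability one), not the trained parameters. Whether the real LSTM can actually output the bias-cancelling vector is a separate question, since its outputs are confined to $(-1, 1)$.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical stand-ins for V (137x512) and v (137,).
V = rng.normal(size=(137, 512))
v = rng.normal(size=137)

rank = np.linalg.matrix_rank(V)
nullity = V.shape[1] - rank  # rank-nullity theorem: 512 - 137 = 375

# The right singular vectors beyond the rank span the null space of V.
_, s, Vt = np.linalg.svd(V)  # Vt is 512x512 (full_matrices=True by default)
null_basis = Vt[rank:].T     # 512 x 375
w = null_basis @ rng.normal(size=nullity)  # nonzero, yet V @ w = 0

# A minimum-norm h with V @ h = -v completely cancels the bias.
h_cancel = np.linalg.pinv(V) @ (-v)
```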

How do all the columns of $\mathbf{V}$ relate to one another? We can get an idea by visualising the Gramian of the matrix with normalised columns, i.e., $\mathbf{G} = \tilde{\mathbf{V}}^T\tilde{\mathbf{V}}$, where $[\tilde{\mathbf{V}}]_i = [\mathbf{V}]_i / \|[\mathbf{V}]_i\|_2$. (The notation $[\mathbf{V}]_i$ means the $i$th column of $\mathbf{V}$.) Below we see the Gramian of $\tilde{\mathbf{V}}$, plotted with reference to the units of the last LSTM layer. We see a lot of dark red and dark blue, which means the units in the last LSTM layer resemble the unfortunately high partisanship in the USA. No, really, it means that many of the columns in $\mathbf{V}$ point in nearly the same directions (up to a change of sign), but there are some that point in more distinct directions. Here’s a scatter plot of all the values in the image above, sorted according to the variance along each row (or, equivalently, each column, since $\mathbf{G}$ is symmetric).
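The Gramian of the normalised columns is a matrix of cosine similarities, and the ordering used for the scatter plot is a sort by row variance. A minimal sketch, with a random stand-in for $\mathbf{V}$ and a hypothetical helper `column_gramian`:

```python
import numpy as np

def column_gramian(V):
    """Gramian of the unit-normalised columns of V: entry (i, j) is the
    cosine of the angle between columns i and j."""
    Vn = V / np.linalg.norm(V, axis=0, keepdims=True)
    return Vn.T @ Vn

rng = np.random.default_rng(2)
V = rng.normal(size=(137, 512))  # hypothetical stand-in for the trained V
G = column_gramian(V)

# Order units by the variance of their row in G; G is symmetric,
# so sorting by column variance gives the same order.
order = np.argsort(G.var(axis=1))
```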

That looks exactly like an X-ray of an empty burrito. Let’s zoom in on the units weighting columns that are quite different from nearly everything else in $\mathbf{V}$:

The output units 175, 216 and 497 of the last LSTM layer are contributing in unique ways to the distribution $\mathbf{p}_t$. Let’s look at the corresponding columns and see if we can’t divine some interpretation.

It looks like column 175 points most strongly in the direction of tokens “]” and “5”, and in the opposite direction of “9” and “(2”. Column 216 points most strongly in the direction of tokens “=c” and “/2>”, and in the opposite direction of tokens “^g” and “_B”. And column 497 points most strongly in the direction of tokens “=B,”, “_C” and “_c”, and in the opposite direction of tokens “_A,” and “^g”. Taken together with the above, this shows that the folk-rnn v2 model can make “]” and “5” more probable and “9” and “(2” less probable by increasing the output of unit 175 of the last LSTM layer. That’s not so meaningful. However, increasing the output of unit 497 will increase the probability of “=B,”, “_C” and “_c” and decrease the probability of “_A,” and “^g”. What’s very interesting here is that “=B,”, “_C” and “_c” are the same pitch class, as are “_A,” and “^g”. Unit 497 is treating the probabilities of the tokens in each of those two groups in similar ways. Has this unit learned about the relationship between these tokens?
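Reading off the extremes of a column, and checking the pitch-class observation, can be sketched as follows. The helpers `extreme_tokens` and `pitch_class` are hypothetical, the column values are made up rather than taken from the model, and `pitch_class` assumes the tokens follow ABC notation (`^` sharp, `_` flat, `=` natural, `,` and `'` as octave marks), handling only single accidentals.

```python
import numpy as np

def extreme_tokens(column, vocab, k=3):
    """Return the k tokens a column pushes up most and the k it pushes down most."""
    order = np.argsort(column)
    up = [vocab[i] for i in order[::-1][:k]]
    down = [vocab[i] for i in order[:k]]
    return up, down

def pitch_class(token):
    """Pitch class (0 = C ... 11 = B) of an ABC-style pitch token, ignoring
    octave marks; returns None for non-pitch tokens."""
    base = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
    shift = {"^": 1, "_": -1, "=": 0}
    has_acc = token[:1] in shift
    acc = shift[token[0]] if has_acc else 0
    letter = token[1:2] if has_acc else token[:1]
    rest = token[2:] if has_acc else token[1:]
    if letter.upper() not in base or rest.strip(",'"):
        return None
    return (base[letter.upper()] + acc) % 12

# Toy vocabulary and a made-up column (not values from the model).
vocab = ["=B,", "_C", "_c", "_A,", "^g", "C"]
col = np.array([1.5, 1.2, 1.1, -1.3, -1.4, 0.1])
up, down = extreme_tokens(col, vocab, k=2)
```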

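Since these values are additive inside the softmax, we can derive what they mean in terms of the change in probability. Let $\mathbf{z} = \mathbf{V}\mathbf{h}_t + \mathbf{v}$ be the pre-softmax values and $p_i = [\sigma(\mathbf{z})]_i$ the probability of the $i$th token. For the $i$th token with pre-softmax value $z_i$ displaced by some $\delta \in \mathbb{R}$, we want to find $\alpha$ in the following:

$$[\sigma(\mathbf{z} + \delta\mathbf{e}_i)]_i = \alpha [\sigma(\mathbf{z})]_i = \alpha p_i$$

where $[\sigma(\mathbf{z})]_i = e^{z_i} / \sum_j e^{z_j}$, and $\mathbf{e}_i$ is the $i$th standard basis vector of $\mathbb{R}^{137}$. For the above to satisfy the axioms of probability, we must restrict the value of $\alpha$: $0 \le \alpha \le 1/p_i$.

After a bit of algebra (only the $i$th term of the normalising sum changes, so the left-hand side equals $e^{z_i+\delta} / (\sum_j e^{z_j} - e^{z_i} + e^{z_i+\delta})$; dividing numerator and denominator by $\sum_j e^{z_j}$ gives $p_i e^{\delta} / (1 - p_i + p_i e^{\delta})$), we find the following nice relationship:

$$\alpha = \frac{e^{\delta}}{1 - p_i + p_i e^{\delta}}$$

or

$$\delta = \log \frac{\alpha (1 - p_i)}{1 - \alpha p_i}.$$

As $\delta \to 0$, $\alpha \to 1$: the probability of the token does not change. As $\delta \to -\infty$, $\alpha \to 0$: the probability of the token goes to zero. And as $\delta \to \infty$, $\alpha \to 1/p_i$: the probability of the token goes to one. Also, as $p_i \to 0$, then $\alpha \to e^{\delta}$.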
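The closed form for $\alpha$, and its inverse for $\delta$, can be checked against a direct softmax computation. A minimal sketch, where `alpha` and `delta_for` are hypothetical helper names and the pre-softmax vector `z` is random rather than taken from the model:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def alpha(delta, p_i):
    """Factor multiplying token i's probability when its pre-softmax value
    is displaced by delta, given its current probability p_i."""
    return np.exp(delta) / (1.0 - p_i + p_i * np.exp(delta))

def delta_for(a, p_i):
    """Inverse: displacement that multiplies p_i by a (needs 0 < a < 1/p_i)."""
    return np.log(a * (1.0 - p_i) / (1.0 - a * p_i))

rng = np.random.default_rng(3)
z = rng.normal(size=137)  # hypothetical pre-softmax values
i, delta = 5, 0.7
p = softmax(z)
e_i = np.zeros_like(z)
e_i[i] = 1.0
p_new = softmax(z + delta * e_i)  # displace only the i-th value
```

Computing the displaced probability directly and through `alpha` should agree to floating-point precision, and the limits above can be checked by plugging in large positive or negative `delta`.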

Applying this to the above, we can see that if the last LSTM layer unit weighting column 497 of $\mathbf{V}$ were to spit out a +1, then, all other things being equal, it would increase the probability of “=B,” by a factor of about 4 if it were initially about 5% probable. If it spit out a -1, then, all other things being equal, it would reduce the probability of “=B,” to roughly a quarter of its value if it were initially about 5% probable.
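The arithmetic can be reproduced by working backwards from the numbers quoted. This sketch assumes the entry of column 497 for “=B,” equals the displacement that quadruples a 5% probability; the actual entry is not given here, so the values are illustrative.

```python
import math

def alpha(delta, p):
    # Multiplicative change in probability for a pre-softmax displacement delta.
    return math.exp(delta) / (1.0 - p + p * math.exp(delta))

def delta_for(a, p):
    # Displacement that multiplies probability p by a (valid for 0 < a < 1/p).
    return math.log(a * (1.0 - p) / (1.0 - a * p))

p0 = 0.05
d = delta_for(4.0, p0)  # assumed column entry: the displacement quadrupling 5%
up = alpha(d, p0)       # factor when the unit outputs +1
down = alpha(-d, p0)    # factor when the unit outputs -1
```

Under this assumption `d` is $\log 4.75 \approx 1.56$ and `down` comes out at about 0.22, in line with the rough quarter above.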

Returning to the scatter plot of the Gramian values above, we see that some columns of $\mathbf{V}$ are almost collinear. The largest value we see is 0.9968 (columns 187 and 2) and the smallest is -0.9937 (187 and 361). Here are the six columns of $\mathbf{V}$ that point most nearly in the same directions (up to a scalar):
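Finding the most aligned and most opposed pairs of columns amounts to taking the largest and smallest off-diagonal entries of the Gramian. A sketch with synthetic columns, where columns 5 and 7 are deliberately built to be nearly parallel and nearly anti-parallel to column 2; `extreme_pairs` is a hypothetical helper:

```python
import numpy as np

def extreme_pairs(G):
    """Return the off-diagonal (i, j) pairs with the largest and smallest
    entries of a symmetric Gramian G, with their values."""
    rows, cols = np.triu_indices_from(G, k=1)  # each pair once, no diagonal
    vals = G[rows, cols]
    hi, lo = np.argmax(vals), np.argmin(vals)
    return (((int(rows[hi]), int(cols[hi])), vals[hi]),
            ((int(rows[lo]), int(cols[lo])), vals[lo]))

rng = np.random.default_rng(4)
V = rng.normal(size=(137, 8))                      # hypothetical columns
V[:, 5] = V[:, 2] + 0.05 * rng.normal(size=137)    # nearly parallel to column 2
V[:, 7] = -V[:, 2] + 0.05 * rng.normal(size=137)   # nearly anti-parallel
Vn = V / np.linalg.norm(V, axis=0)
G = Vn.T @ Vn
(best, best_val), (worst, worst_val) = extreme_pairs(G)
```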

We get the sense here that units 2, 187 and 361 represent nearly the same information, as do units 34, 156 and 203. Interestingly, columns 2, 187 and 361 don’t seem interested in three octaves of pitch tokens from “A,” to “a”. Furthermore, they don’t seem to greatly change the probability of any specific tokens. However, that impression is misleading. Here are the normalised vectors resulting from the differences between four of these normalised vectors.

Units 2, 187 and 361 may be weighting columns that point in nearly the same direction (up to sign), but this shows their sums and differences point strongly in the direction of increasing the probability of some duration tokens “/2”, “2” and “>”, and decreasing the probability of other duration tokens “4” and “6”. The sums and differences of the other three columns do not show such a strong tendency for particular tokens, but we do see their differences point in the direction of mode tokens “K:Cmaj”, “K:Cmin”, “K:Cdor” and “K:Cmix”.
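The effect described here, where two nearly parallel unit vectors have a difference concentrated on a few dimensions, can be illustrated with synthetic data. In this sketch the vectors are random and the perturbed dimensions (3 and 40) are arbitrary choices, not taken from the model:

```python
import numpy as np

def normalised(x):
    return x / np.linalg.norm(x)

rng = np.random.default_rng(5)
a = normalised(rng.normal(size=137))  # hypothetical column
# A second column that differs from a only slightly, mostly in dims 3 and 40.
b = a.copy()
b[3] += 0.08
b[40] -= 0.08
b = normalised(b)

cosine = a @ b               # close to 1: nearly the same direction
diff = normalised(b - a)     # normalised difference of the two columns
top = np.argsort(np.abs(diff))[::-1][:2]  # dimensions the difference points along
```

Although the cosine similarity is above 0.99, the normalised difference is dominated by the two perturbed dimensions.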

Anyhow, this shows how two units may weight columns that point in nearly the same direction in 137-dimensional space, yet their combination can produce marked effects in specific directions. It is important to interpret the columns of $\mathbf{V}$, and likewise the units of the last LSTM layer, as a group and not as individual embodiments of our vocabulary.

Stay tuned!
