{"id":2003,"date":"2024-04-01T03:51:03","date_gmt":"2024-04-01T03:51:03","guid":{"rendered":"https:\/\/timallanwheeler.com\/blog\/?p=2003"},"modified":"2024-04-01T03:51:03","modified_gmt":"2024-04-01T03:51:03","slug":"transformers-how-and-why-they-work","status":"publish","type":"post","link":"https:\/\/timallanwheeler.com\/blog\/2024\/04\/01\/transformers-how-and-why-they-work\/","title":{"rendered":"Transformers &#8211; How and Why They Work"},"content":{"rendered":"\n<p>This month I decided to take a break from my sidescroller project and instead properly attend to transformers (pun intended). Unless you&#8217;ve been living under a rock, you&#8217;ve noticed the rapid advance of AI in the last year and the advent of extremely large models like <a href=\"https:\/\/chat.openai.com\/\">ChatGPT 4<\/a> and <a href=\"https:\/\/www.googleadservices.com\/pagead\/aclk?sa=L&amp;ai=DChcSEwjx6rSio_eEAxVSzsIEHWc6D30YABAAGgJwdg&amp;ase=2&amp;gclid=CjwKCAjw48-vBhBbEiwAzqrZVCqklDZR3iBj4UJ5aYF1-FtZdqTaMX0iAr2fs7bqxymlmnQSXJLXExoChL0QAvD_BwE&amp;ohost=www.google.com&amp;cid=CAESVuD2_YjX32mQRalqZoz4AM90aJSbf_kZNJEUBaZVl6gllSCu4ucc50iRqAjQAQYUPZM05Ioqc-kevCZjAhptzzcDJ5cDfKYRcysmeu5CvGHMr2vAeOad&amp;sig=AOD64_1u7RZKlQ0OdT6sy6sdqc9vg1T73A&amp;q&amp;nis=4&amp;adurl&amp;ved=2ahUKEwi35qyio_eEAxX6ITQIHfQZC7wQ0Qx6BAgGEAE\">Google Gemini<\/a>. These models, and pretty much every other large and serious application of AI nowadays, are based on the transformer architecture first introduced in <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\">Attention is All You Need<\/a>. 
So this post is about transformers, how and roughly why they work, and how to write your own.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">The Architecture<\/h2>\n\n\n\n<p>The standard depiction of the transformer architecture comes from Figure 1 of <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\">Attention is All You Need<\/a>:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"382\" height=\"485\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-07.png\" alt=\"\" class=\"wp-image-2005\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-07.png 382w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-07-236x300.png 236w\" sizes=\"auto, (max-width: 382px) 100vw, 382px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>I initially found this image difficult to understand. Now that I&#8217;ve implemented my own transformer from scratch, I&#8217;m actually somewhat impressed by its concise expressiveness. That&#8217;s appropriate for a research paper where information density is key, but I think we can do better if we&#8217;re explaining the architecture to someone.<\/p>\n\n\n\n<p>Transformers at base operate on <em>tokens<\/em>, which are just discrete items. In effect, if your vocabulary has 5 tokens, then each token is just one of the labels 1, 2, 3, 4, or 5. Transformers first became popular when used on text, where tokens represent words like &#8220;potato&#8221; or fragments like &#8220;-ly&#8221;, but they have since been applied to <a href=\"https:\/\/arxiv.org\/abs\/2010.11929\">chunks of images<\/a>, <a href=\"https:\/\/arxiv.org\/abs\/2105.00335\">chunks of sound<\/a>, <a href=\"https:\/\/arxiv.org\/pdf\/2311.13502.pdf\">simple bits<\/a>, and even <a href=\"https:\/\/arxiv.org\/abs\/2106.01345\">discrete states, actions, and rewards<\/a>. 
I think words are the most intuitive to work with, so let&#8217;s go with that. <\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"630\" height=\"78\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-16.png\" alt=\"\" class=\"wp-image-2006\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-16.png 630w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-15_15-16-300x37.png 300w\" sizes=\"auto, (max-width: 630px) 100vw, 630px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>Transformers predict the next token given all of the tokens that came before:<\/p>\n\n\n\n<p>\[P(x_{t} \mid x_{t-1}, x_{t-2}, \ldots)\]\n\n\n\n<p>For example, if we started a sentence with &#8220;the cat sat&#8221;, we might want it to produce a higher likelihood for &#8220;on&#8221; than &#8220;potato&#8221;. Conceptually, this looks like:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"322\" height=\"301\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image.png\" alt=\"\" class=\"wp-image-2011\" style=\"width:242px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image.png 322w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-300x280.png 300w\" sizes=\"auto, (max-width: 322px) 100vw, 322px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>The output probability distribution is just a Categorical distribution over the \(n\) possible tokens, which we can easily produce with a softmax layer. <\/p>\n\n\n\n<p>You&#8217;ll notice that my model goes top to bottom, whereas academia, for some unfathomable reason, usually depicts deep neural nets bottom to top. 
We read top to bottom, so I&#8217;m sticking with that.<\/p>\n\n\n\n<p>We could use a model like this to repeatedly generate a full sentence:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"657\" height=\"359\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-1.png\" alt=\"\" class=\"wp-image-2012\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-1.png 657w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-1-300x164.png 300w\" sizes=\"auto, (max-width: 657px) 100vw, 657px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>So far we&#8217;ve come up with a standard autoregressive model. Those have been around forever. What makes transformers special is that they split the inputs from the outputs such that the inputs need only be encoded once, that they use attention to help make the model resilient to how the tokens are ordered, and that they solve the problem of vanishing gradients.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">Separate Inputs and Outputs<\/h3>\n\n\n\n<p>The first weird thing about the transformer architecture figure is that it has a left side and a right side. The left side receives &#8220;inputs&#8221; and the right side receives &#8220;outputs&#8221;. What is going on here?<\/p>\n\n\n\n<p>If we are trying to generate a sentence that starts with &#8220;the cat sat&#8221;, then &#8220;the cat sat&#8221; is the input and the subsequent tokens are the outputs. 
During training we&#8217;d know what the subsequent tokens are (our training set would split sentences into (input, output) pairs, such as randomly in the middle or via question\/answer), and during inference we&#8217;d sample the outputs sequentially.<\/p>\n\n\n\n<p>Conceptually, we&#8217;re now thinking about an architecture along these lines:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"643\" height=\"384\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-2.png\" alt=\"\" class=\"wp-image-2017\" style=\"width:533px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-2.png 643w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-2-300x179.png 300w\" sizes=\"auto, (max-width: 643px) 100vw, 643px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>Is this weird? Yes. Why do they do it? Because it&#8217;s more efficient during training, and during inference you only need to run the input head once.<\/p>\n\n\n\n<p>During training, we know all of the tokens and so we just stick a loss function on this to maximize the likelihood of the correct output token:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"621\" height=\"417\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-3.png\" alt=\"\" class=\"wp-image-2020\" style=\"width:531px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-3.png 621w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-3-300x201.png 300w\" sizes=\"auto, (max-width: 621px) 100vw, 621px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>During inference, we run the encoder once and start by running the model with an empty output:<\/p>\n\n\n<div 
class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"627\" height=\"381\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-5.png\" alt=\"\" class=\"wp-image-2022\" style=\"width:484px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-5.png 627w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-5-300x182.png 300w\" sizes=\"auto, (max-width: 627px) 100vw, 627px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>We only use the probability distribution over the first token, sample from it, and append the sampled token to our output. If we haven&#8217;t finished our sentence yet, we continue on:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"638\" height=\"380\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-7.png\" alt=\"\" class=\"wp-image-2024\" style=\"width:516px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-7.png 638w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-7-300x179.png 300w\" sizes=\"auto, (max-width: 638px) 100vw, 638px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>We only use the probability distribution over the next token, sample from it, append the sample, and so on. This iterative process is autoregressive, just like we were talking about before. 
<span id=\"su_tooltip_69e9c29573f70_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c29573f70\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">(aside 1)<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c29573f70\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">This sampling typically ends when we reach a special token dedicated to stopping, called the <em>end token<\/em>. It's just one of the n tokens, like any other, except that it has this special significance and doesn't get printed out when you show the sentence to the user.<\/span><\/span><span id=\"su_tooltip_69e9c29573f70_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span> <span id=\"su_tooltip_69e9c295740fc_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c295740fc\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">(aside 2)<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c295740fc\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">Sometimes we can literally pass nothing into the 
output head. However, transformers often have a fixed size, and we have to fill in those output values with something. In that case, we can use special padding tokens and masking to prevent the model from using that information. More on that later.<\/span><\/span><span id=\"su_tooltip_69e9c295740fc_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span>\n\n\n\n<p>The reason I think this approach is awkward is that you&#8217;d typically want the representation of \(P(x_{t} \mid x_{t-1}, x_{t-2}, \ldots)\) to not care about whether the previous tokens were original inputs or not. That is, we want \(P(\text{the} \mid \text{the cat sat on})\) to be the same irrespective of whether &#8220;the cat sat on&#8221; are all inputs or whether &#8220;the cat sat&#8221; is the input and we just sampled &#8220;on&#8221;. In an encoder-decoder transformer like this one, the two generally differ.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">Robust to Token Ordering<\/h3>\n\n\n\n<p>The real problem that the transformer architecture solves is a form of robustness to token ordering. Let&#8217;s consider the following two input \/ output pairs:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>the cat sat | on the mat<\/li>\n\n\n\n<li>the yellow cat sat | on the mat<\/li>\n<\/ul>\n\n\n\n<p>They are the same except for the extra &#8220;yellow&#8221; token in the second sentence, which shifts &#8220;cat&#8221; and every token after it over by one position. 
<\/p>\n\n\n\n<p>If our model were made up of <span id=\"su_tooltip_69e9c29574175_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c29574175\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">simple feed-forward layers<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c29574175\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">By which I mean an affine transform followed by your activation function of choice.<\/span><\/span><span id=\"su_tooltip_69e9c29574175_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span>, that would present a problem:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"632\" height=\"402\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-8.png\" alt=\"\" class=\"wp-image-2032\" style=\"width:572px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-8.png 632w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-8-300x191.png 300w\" sizes=\"auto, (max-width: 632px) 100vw, 632px\" \/><\/figure>\n<\/div>\n\n\n<p> <\/p>\n\n\n\n<p>A feedforward layer contains an affine transform \(\boldsymbol{x}' \gets \boldsymbol{A} \boldsymbol{x} + \boldsymbol{b}\) that learns separate weights for every input position. 
We can even write it out for three inputs:<\/p>\n\n\n\n<p>\[\begin{matrix}x'_1 &amp;\gets A_{11}x_1 + A_{12}x_2 + A_{13}x_3 + b_1 \\ x'_2 &amp;\gets A_{21}x_1 + A_{22}x_2 + A_{23}x_3 + b_2 \\ x'_3 &amp;\gets A_{31}x_1 + A_{32}x_2 + A_{33}x_3 + b_3\end{matrix}\]\n\n\n\n<p>If we learn something about &#8220;cat&#8221; being in the second position, we&#8217;d have to learn it all over again to handle the case where &#8220;cat&#8221; is in the third position.<\/p>\n\n\n\n<p>Transformers are robust to this issue because of their use of <em>attention<\/em>. Put very simply, attention allows the neural network to learn when particular tokens are important in a position-independent way, such that they can be focused on when needed.<\/p>\n\n\n\n<p>Transformers use <em>scaled dot product attention<\/em>. Here, we input three things:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>a query \(\boldsymbol{q}\)<\/li>\n\n\n\n<li>a key \(\boldsymbol{k}\)<\/li>\n\n\n\n<li>a value \(\boldsymbol{v}\)<\/li>\n<\/ul>\n\n\n\n<p>Each of these is a vector embedding, but they can be thought of as:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>query &#8211; a representation of what we are asking for<\/li>\n\n\n\n<li>key &#8211; how well the current token we&#8217;re looking at reflects what we&#8217;re asking for<\/li>\n\n\n\n<li>value &#8211; how important it is that we get something that matches what we&#8217;re asking for<\/li>\n<\/ul>\n\n\n\n<p>For example, if we have &#8220;the cat sat on the ____&#8221;, and we&#8217;re looking to fill in that last blank, it might be useful for the model to have learned a query for representing things that sit, and to value the result of that query a lot when the word we need to fill in depends on what is doing the sitting.<\/p>\n\n\n\n<p>We take the dot product of the query and the key to measure how well they match: \(\boldsymbol{q}^T \boldsymbol{k}\). 
Each token has a key, so we end up with a measure of how well each of them matches:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"677\" height=\"320\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-03.png\" alt=\"\" class=\"wp-image-2110\" style=\"width:577px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-03.png 677w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-03-300x142.png 300w\" sizes=\"auto, (max-width: 677px) 100vw, 677px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>Those measures can take on any real value. Taking the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Softmax_function\">softmax<\/a> turns them into weights that lie in [0, 1] and sum to 1, while preserving their relative order: <span id=\"su_tooltip_69e9c295741eb_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c295741eb\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">(aside)<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c295741eb\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">The transformer version scales these scores down before passing them through softmax, but that isn't important for the intuition.<\/span><\/span><span id=\"su_tooltip_69e9c295741eb_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span>\n\n\n<div class=\"wp-block-image\">\n<figure 
class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"632\" height=\"390\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-11.png\" alt=\"\" class=\"wp-image-2065\" style=\"width:526px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-11.png 632w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-11-300x185.png 300w\" sizes=\"auto, (max-width: 632px) 100vw, 632px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>In this example, our activation vector has a large value (closer to 1) for &#8220;cat&#8221;.<\/p>\n\n\n\n<p>Finally, we can take the dot product of our activation vector \(\boldsymbol{\alpha}\) with our value vector to get the overall attention value: \(\boldsymbol{\alpha} \cdot \boldsymbol{v}\).<\/p>\n\n\n\n<p>Notice that where &#8220;cat&#8221; was in the list of tokens didn&#8217;t matter all that much. If we shift it around, but give it the same key, it will continue to produce the same activation value. Then, as long as the value is high wherever the activation is large, we&#8217;ll get a large output.<\/p>\n\n\n\n<p>Putting this together, our attention function for a single query \(\boldsymbol{q}\) is:<\/p>\n\n\n\n<p>\[ \texttt{attention}(\boldsymbol{q}, \boldsymbol{k}^{(1)}, \ldots, \boldsymbol{k}^{(n)}, \boldsymbol{v}) = \texttt{softmax}\left(\boldsymbol{q}^T [\boldsymbol{k}^{(1)}, \ldots, \boldsymbol{k}^{(n)}]\right) \cdot \boldsymbol{v} \]\n\n\n\n<p>We can combine the keys together into a single matrix \(\boldsymbol{K}\), which simplifies things to:<\/p>\n\n\n\n<p>\[ \texttt{attention}(\boldsymbol{q}, \boldsymbol{K}, \boldsymbol{v}) = \texttt{softmax}\left(\boldsymbol{q}^T \boldsymbol{K}\right) \cdot \boldsymbol{v} \]\n\n\n\n<p>We&#8217;re going to want a bunch of queries, not just one. 
That&#8217;s equivalent to stacking our queries as the columns of a matrix \(\boldsymbol{Q}\) and our values as the rows of a matrix \(\boldsymbol{V}\), with the softmax applied separately to each query&#8217;s row of scores:<\/p>\n\n\n\n<p>\[ \texttt{attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \texttt{softmax}\left(\boldsymbol{Q}^T \boldsymbol{K}\right) \boldsymbol{V} \]\n\n\n\n<p>We&#8217;ve basically recovered the attention function given in the paper. The paper stacks its queries and keys as rows instead of columns, so \(\boldsymbol{Q}^T \boldsymbol{K}\) becomes \(\boldsymbol{Q}\boldsymbol{K}^T\), and it adds a scaling factor that helps keep the logits passed into softmax smaller in magnitude:<\/p>\n\n\n\n<p>\[\texttt{attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \texttt{softmax}\left(\frac{\boldsymbol{Q}\boldsymbol{K}^T}{\sqrt{d_k}}\right) \boldsymbol{V}\]\n\n\n\n<p>where \(d_k\) is the dimension of the keys, i.e., how many features long the key embeddings are.<\/p>\n\n\n\n<p>The output is a matrix that has larger entries where our queries matched tokens and our value was large. That is, the transformer learns to ask for something and extract it if it exists.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">Robust to Vanishing Gradients<\/h3>\n\n\n\n<p>The other important problem that transformers solve is the <em>vanishing gradient problem<\/em>. 
The previous alternative to the transformer, the <a href=\"https:\/\/en.wikipedia.org\/wiki\/Recurrent_neural_network\">recurrent neural network<\/a> (RNN), tends to suffer from this issue.<\/p>\n\n\n\n<p>A recurrent neural network processes a sequence one token at a time, taking as input the current token and a latent state and producing an updated state:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"816\" height=\"328\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-30.png\" alt=\"\" class=\"wp-image-2184\" style=\"width:688px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-30.png 816w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-30-300x121.png 300w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-30-768x309.png 768w\" sizes=\"auto, (max-width: 816px) 100vw, 816px\" \/><\/figure>\n<\/div>\n\n\n<p>This state is referred to as the RNN&#8217;s <em>memory<\/em>. <\/p>\n\n\n\n<p>The vanishing gradient problem arises when we try to propagate gradient information far back in time (where &#8220;far&#8221; can be just a few tokens). 
If we want the network to learn to associate &#8220;mat&#8221; with &#8220;cat&#8221;, then we need to propagate through 4 instances of the RNN:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"819\" height=\"352\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-34.png\" alt=\"\" class=\"wp-image-2185\" style=\"width:689px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-34.png 819w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-34-300x129.png 300w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-34-768x330.png 768w\" sizes=\"auto, (max-width: 819px) 100vw, 819px\" \/><\/figure>\n<\/div>\n\n\n<p>Simply put, in this example we&#8217;re taking the gradient of RNN(RNN(RNN(RNN(&#8220;cat&#8221;, state)))) with respect to &#8220;cat&#8221;. The chain rule tells us that the derivative of \(f(f(f(f(x))))\) is:<\/p>\n\n\n\n<p>\[f'(x) \> f'(f(x)) \> f'(f(f(x))) \> f'(f(f(f(x))))\]\n\n\n\n<p>If \(|f'|\) is consistently <span id=\"su_tooltip_69e9c29574253_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c29574253\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">smaller than 1<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c29574253\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">If it's larger than 1, then we have a worse problem: exploding gradients. 
Our gradient will get ridiculously large and cause our optimization steps to get too big.<\/span><\/span><span id=\"su_tooltip_69e9c29574253_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span>, then we very quickly drive that gradient signal toward zero as we try to propagate it back further. The gradient vanishes!<\/p>\n\n\n\n<p>Transformers avoid this problem because attention connects every token directly to every other token, so within each layer the gradient can flow straight between any two positions instead of through a long chain of recurrent steps. <\/p>\n\n\n\n<p>Transformers also mitigate the problem using <em><a href=\"https:\/\/en.wikipedia.org\/wiki\/Residual_neural_network\">residual connections<\/a><\/em>. These are the &#8220;skip connections&#8221; that you&#8217;ll see below. They provide a way for the gradient to flow back through the network unimpeded, so it doesn&#8217;t shrink as it travels through a stack of transformer layers.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Building a Transformer<\/h2>\n\n\n\n<p>Let&#8217;s use what we&#8217;ve learned and build a transformer, setting aside the fact that we haven&#8217;t covered multiheaded attention just yet. We&#8217;re going to construct the following input head:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"279\" height=\"415\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-12.png\" alt=\"\" class=\"wp-image-2095\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-12.png 279w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/image-12-202x300.png 202w\" sizes=\"auto, (max-width: 279px) 100vw, 279px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>This is analogous to a single &#8220;input encoder layer&#8221; in the transformer architecture. 
Note that the real deal uses multiheaded attention and applies dropout to the output of each sub-layer before it is added back in and normalized.<\/p>\n\n\n\n<p>The trainable parameters in our input head are the three projection matrices \(\boldsymbol{W}^Q\), \(\boldsymbol{W}^K\), and \(\boldsymbol{W}^V\), the learnable parameters in the feedforward layer, and the scale and shift parameters of the two layer norms. <a href=\"https:\/\/arxiv.org\/abs\/1607.06450\">Layer normalization<\/a> normalizes each token&#8217;s feature vector on its own rather than normalizing across the batch dimension, which keeps things nicely scaled. <\/p>\n\n\n\n<p>I&#8217;m going to implement this input head in <a href=\"https:\/\/julialang.org\/\">Julia<\/a>, using <a href=\"https:\/\/github.com\/FluxML\/Flux.jl\">Flux.jl<\/a>. The code is remarkably straightforward. Because Flux layers treat the first dimension as features, our arrays store one token per column, so the attention scores are \(\boldsymbol{K}^T\boldsymbol{Q}\) rather than the paper&#8217;s \(\boldsymbol{Q}\boldsymbol{K}^T\), and the products are batched over the third dimension with NNlib&#8217;s <code>batched_mul<\/code>:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia line-numbers\">struct InputHead\n    q_proj::Dense\n    k_proj::Dense\n    v_proj::Dense\n    norm_1::LayerNorm\n    affine1::Dense\n    affine2::Dense\n    norm_2::LayerNorm\nend\n\nFlux.@functor InputHead\n\nfunction (m::InputHead)(\n    X::Array{Float32, 3}) # [dim \u00d7 ntokens \u00d7 batch_size]\n\n    dim = size(X, 1)\n\n    # scaled dot product attention\n    Q = m.q_proj(X) # [dim \u00d7 ntokens \u00d7 batch_size]\n    K = m.k_proj(X) # [dim \u00d7 ntokens \u00d7 batch_size]\n    V = m.v_proj(X) # [dim \u00d7 ntokens \u00d7 batch_size]\n\n    # scores[i, j] = dot(k_i, q_j); softmax over dims=1 normalizes over the keys for each query\n    \u03b1 = softmax(batched_mul(batched_transpose(K), Q) .\/ \u221aFloat32(dim), dims=1) # [ntokens \u00d7 ntokens \u00d7 batch_size]\n    X\u2032 = batched_mul(V, \u03b1) # [dim \u00d7 ntokens \u00d7 batch_size]\n\n    # add norm\n    X = m.norm_1(X + X\u2032)\n\n    # feedforward (relu is broadcast elementwise)\n    X\u2032 = m.affine2(relu.(m.affine1(X)))\n\n    # add norm, using the second layer norm\n    X = m.norm_2(X + X\u2032)\n\n    return X\nend<\/code><\/pre>\n\n\n\n<p>Once we have this, building the rest of the simplified transformer is pretty easy. 
Let&#8217;s similarly define an output head:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"376\" height=\"628\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-16_18-37.png\" alt=\"\" class=\"wp-image-2106\" style=\"width:304px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-16_18-37.png 376w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-16_18-37-180x300.png 180w\" sizes=\"auto, (max-width: 376px) 100vw, 376px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>The output head is basically the same, except it has an additional attention + layer normalization section that receives its queries from within the head but its keys and values from the input head. This is sometimes called <em>cross attention<\/em>, and it allows the outputs to receive information from the inputs. In the full architecture, the output head&#8217;s first attention block is also masked, so that each output position can only attend to itself and earlier positions. Implementing this in Julia would be pretty straightforward, so I&#8217;m not covering it here.<\/p>\n\n\n\n<p>In order to fully flesh out a basic transformer, we just have to define what happens at the outputs and inputs. The output is pretty simple. 
It&#8217;s just a feedforward layer followed by a softmax:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"672\" height=\"393\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-23.png\" alt=\"\" class=\"wp-image-2122\" style=\"width:572px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-23.png 672w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-23-300x175.png 300w\" sizes=\"auto, (max-width: 672px) 100vw, 672px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>The feedforward layer gives each token&#8217;s features a chance to get properly mixed (it is fully connected along the embedding dimension), but more importantly, it changes the tensor dimension from the embedding size to the number of tokens in the vocabulary. The softmax then ensures that the output represents probabilities.<\/p>\n\n\n\n<p>The inputs, by which I mean all of the tokens, have to be turned into embedding matrices. Remember that there&#8217;s only a finite set of tokens. We associate a vector embedding with each token, which we re-use every time the token shows up. This embedding can be learned, in which case it&#8217;s the same as representing our token as a one-hot vector \(x_\text{one-hot}\) and learning an \(m\times n\) matrix \(E\), where \(m\) is the embedding dimension and \(n\) is the number of distinct tokens:<\/p>\n\n\n\n<p>\[z = E x_\text{one-hot}\]\n\n\n\n<p>We then add a <em>positional encoding<\/em> to each embedding. Wait, I thought we wanted the network to be robust to position? Well yeah, it is. But knowing where something shows up, and whether it shows up before or after something else, is also important. So to give the model a way to determine that, we introduce a sort of signature or pattern to each embedding. 
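<\/p>\n\n\n\n<p>To see why such a signature is needed, here is a minimal sketch in plain Julia (no Flux; <code>softmax_cols<\/code> and <code>attention<\/code> are illustrative names of my own, not library functions) of scaled dot product attention for a single sequence, using the same one-token-per-column layout as the input head above. Reordering the tokens that supply the keys and values, with the same permutation for both, leaves every query&#8217;s output unchanged, so without positional information a query has no way to tell what order those tokens came in:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\"># column-wise softmax: each column of the result sums to 1\nfunction softmax_cols(S::AbstractMatrix)\n    E = exp.(S .- maximum(S, dims=1)) # subtract each column's max for numerical stability\n    return E .\/ sum(E, dims=1)\nend\n\n# scaled dot product attention for one sequence, one token per column\nfunction attention(Q::AbstractMatrix, K::AbstractMatrix, V::AbstractMatrix)\n    d_k = size(K, 1)\n    \u03b1 = softmax_cols(transpose(K) * Q .\/ sqrt(Float32(d_k))) # [n_keys \u00d7 n_queries]\n    return V * \u03b1 # [d_v \u00d7 n_queries]\nend\n\ndim, n = 4, 5\nQ, K, V = randn(Float32, dim, n), randn(Float32, dim, n), randn(Float32, dim, n)\nout = attention(Q, K, V)\n\n# shuffle the tokens that supply keys and values with one permutation; the outputs do not change\np = [3, 1, 5, 2, 4]\n@assert isapprox(attention(Q, K[:, p], V[:, p]), out; rtol = 1e-4)<\/code><\/pre>\n\n\n\n<p>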
The original paper uses a sinusoidal pattern:<\/p>\n\n\n\n<p>\\[z_\\text{pos encode}(i)_j = \\begin{cases} \\sin\\left(i \/ 10000^{2\\lfloor j\/2 \\rfloor \/ j_\\text{max}}\\right) &amp; \\text{if } j \\text{ is even} \\\\ \\cos\\left(i \/ 10000^{2\\lfloor j\/2 \\rfloor \/ j_\\text{max}}\\right) &amp; \\text{otherwise} \\end{cases}\\]\n\n\n\n<p>for the \\(j\\)th entry (counting from zero) of the position encoding at position \\(i\\), where \\(j_\\text{max}\\) is the embedding dimension. Each even entry and the odd entry after it share a frequency, so the encoding is a set of sine\/cosine pairs whose wavelengths form a geometric progression from \\(2\\pi\\) to \\(10000 \\cdot 2\\pi\\).<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"685\" height=\"317\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-34.png\" alt=\"\" class=\"wp-image-2137\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-34.png 685w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/03\/2024-03-30_14-34-300x139.png 300w\" sizes=\"auto, (max-width: 685px) 100vw, 685px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>What I&#8217;m not showing here is that the original paper again applies dropout after the position encoding is added in.<\/p>\n\n\n\n<p>Flux <a href=\"https:\/\/fluxml.ai\/Flux.jl\/stable\/models\/layers\/#Flux.Embedding\">already supports token embeddings<\/a>, so let&#8217;s just use that. 
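<\/p>\n\n\n\n<p>To make the one-hot view concrete, here is a minimal sketch. It assumes Flux 0.13 or later, where Flux.Embedding(in => out) keeps the matrix \\(E\\) in its weight field. The vocabulary size, embedding dimension, and token ids are made up for illustration. The sketch checks that looking up a token&#8217;s embedding gives the same vector as multiplying \\(E\\) by a one-hot vector. It also checks that a [n_tokens \u00d7 batch_size] matrix of token ids turns into a [dim \u00d7 n_tokens \u00d7 batch_size] array of embeddings:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">using Flux\n\nvocab_size, dim = 10, 4\nemb = Flux.Embedding(vocab_size => dim) # E = emb.weight, size [dim \u00d7 vocab_size]\n\ntoken = 3\nx_onehot = Flux.onehot(token, 1:vocab_size) # one-hot vector of length vocab_size\n\nz_lookup = emb(token)            # column 3 of E\nz_matmul = emb.weight * x_onehot # E times the one-hot vector\n@assert isapprox(z_lookup, z_matmul)\n\n# A [n_tokens \u00d7 batch_size] matrix of ids becomes [dim \u00d7 n_tokens \u00d7 batch_size]\ntokens = [1 2; 3 4; 5 6] # 3 tokens, batch of 2\n@assert size(emb(tokens)) == (dim, 3, 2)<\/code><\/pre>\n\n\n\n<p>The lookup just indexes a column of \\(E\\). This gives the same result as the matrix product without multiplying through all of the zeros.<\/p>\n\n\n\n<p>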
We can just generate our position encodings in advance:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function generate_position_encoding(dim::Int, max_sequence_length::Int)\n    pos_enc = zeros(Float32, dim, max_sequence_length)\n\n    # Even (0-based) entries j get a sine and odd entries a cosine;\n    # entries 2k and 2k+1 share the same frequency, 1 \/ 10000^(2k\/dim).\n    for j in 0:2:(dim-1)\n        denominator::Float32 = 10000.0^(2.0*(j \u00f7 2)\/dim)\n        for i in 1:max_sequence_length\n            pos_enc[j+1, i] = sin(i \/ denominator)\n        end\n    end\n\n    for j in 1:2:(dim-1)\n        denominator::Float32 = 10000.0^(2.0*(j \u00f7 2)\/dim)\n        for i in 1:max_sequence_length\n            pos_enc[j+1, i] = cos(i \/ denominator)\n        end\n    end\n\n    return pos_enc\nend<\/code><\/pre>\n\n\n\n<p>The overall input encoder is then simply:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">struct InputEncoder\n    embedding::Embedding # [vocab_size => dim]\n    position_encoding::Matrix{Float32} # [dim \u00d7 max_sequence_length]\n    dropout::Dropout\nend\n\nFlux.@functor InputEncoder\n\n# Only the embedding is learned; the sinusoidal position encoding stays fixed.\nFlux.trainable(m::InputEncoder) = (; embedding = m.embedding)\n\nfunction (m::InputEncoder)(tokens::Matrix{Int}) # [n_tokens \u00d7 batch_size]\n    n_tokens = size(tokens, 1)\n    X = m.embedding(tokens) # [dim \u00d7 n_tokens \u00d7 batch_size]\n    X = X .+ m.position_encoding[:, 1:n_tokens] # one encoding column per position\n    return m.dropout(X)\nend<\/code><\/pre>\n\n\n\n<p>When it gets down to it, coding this stuff up really is just like <a href=\"http:\/\/cs231n.stanford.edu\/slides\/2016\/winter1516_lecture5.pdf#page=16\">stacking a bunch of lego bricks<\/a>.<\/p>\n\n\n\n<p>We can put it all together to get a super-simple transformer:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">struct Transformer\n    input_encoder::InputEncoder\n    trans_enc::TransformerEncoderLayer\n    trans_dec::TransformerDecoderLayer\n    linear::Dense # [vocab_size \u00d7 dim]\nend\n\nFlux.@functor Transformer\n\nfunction Transformer(vocab_size::Int, dim::Int, n_tokens::Int;\n                   hidden_dim_scale::Int = 4,\n                   init = Flux.glorot_uniform,\n                   dropout_prob = 
0.0)\n    input_encoder = InputEncoder(\n        Flux.Embedding(vocab_size => dim),\n        generate_position_encoding(dim, n_tokens),\n        Dropout(dropout_prob))\n    trans_enc = TransformerEncoderLayer(dim,\n        hidden_dim_scale=hidden_dim_scale,\n        bias=true, init=init, dropout_prob=dropout_prob)\n    trans_dec = TransformerDecoderLayer(dim,\n        hidden_dim_scale=hidden_dim_scale,\n        bias=true, init=init, dropout_prob=dropout_prob)\n    linear = Dense(dim => vocab_size, bias=true, init=init)\n    return Transformer(input_encoder, trans_enc, trans_dec, linear)\nend\n\nfunction (m::Transformer)(\n    input_tokens::Matrix{Int},\n    output_tokens::Matrix{Int}) # each [n_tokens \u00d7 batch_size]\n\n    X_in = m.input_encoder(input_tokens) # [dim \u00d7 n_tokens \u00d7 batch_size]\n    X_out = m.input_encoder(output_tokens) # [dim \u00d7 n_tokens \u00d7 batch_size]\n    E = m.trans_enc(X_in) # input head (encoder) output\n    X = m.trans_dec(X_out, E) # output head (decoder), cross-attending to E\n    logits = m.linear(X) # [vocab_size \u00d7 n_tokens \u00d7 batch_size]\n    return logits\nend<\/code><\/pre>\n\n\n\n<p>Note that this code doesn&#8217;t run a softmax, because we can apply the <a href=\"https:\/\/fluxml.ai\/Flux.jl\/stable\/models\/losses\/#Flux.Losses.logitcrossentropy\">crossentropy loss directly to the logits<\/a>. This is more numerically stable than taking a softmax first.<\/p>\n\n\n\n<p>Running this model requires generating datasets of tokens, where each token is an integer index into the vocabulary. We pass in our batch of token inputs and our batch of token outputs, and see what logits we get.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Masking<\/h2>\n\n\n\n<p>We&#8217;ve got a basic transformer that will take input tokens and produce output logits. 
We can train it to <span id=\"su_tooltip_69e9c295742d1_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c295742d1\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">maximize the likelihood of the expected future tokens.<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c295742d1\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">Which is the same as minimizing the negative log likelihood.<\/span><\/span><span id=\"su_tooltip_69e9c295742d1_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span> We&#8217;ll probably be able to perform pretty well on the training set, but if we go and try to actually generate sentences, it won&#8217;t do all that well.<\/p>\n\n\n\n<p>Why? 
Because the current setup allows the neural network to attend to future information. When it makes the prediction for one position, the output head can also look at the later output tokens, including the very token it is supposed to predict. That lets it get away with copying rather than predicting.<\/p>\n\n\n\n<p>Recall that we&#8217;re trying to predict the next token given the tokens that came before:<\/p>\n\n\n\n<p>\\[P(x_{t} \\mid x_{t-1}, x_{t-2}, \\ldots)\\]\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"706\" height=\"253\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-08.png\" alt=\"\" class=\"wp-image-2147\" style=\"width:576px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-08.png 706w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-08-300x108.png 300w\" sizes=\"auto, (max-width: 706px) 100vw, 706px\" \/><\/figure>\n<\/div>\n\n\n<p>The way our output head is currently constructed, we&#8217;re allowing each token there to attend to all of the output tokens, including future ones.<\/p>\n\n\n\n<p>Let&#8217;s revisit our attention function for a single query \\(\\boldsymbol{q}\\):<\/p>\n\n\n\n<p>\\[ \\texttt{attention}(\\boldsymbol{q}, \\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{k}^{(n)}, \\boldsymbol{v}) =  \\texttt{softmax}\\left(\\boldsymbol{q}^T [\\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{k}^{(n)}]\\right) \\cdot  \\boldsymbol{v} \\]\n\n\n\n<p>We want to modify the attention function such that we do not consider the keys for tokens beyond a certain index, say index \\(t\\). That means we want <a href=\"https:\/\/en.wikipedia.org\/wiki\/Softmax_function\">softmax<\/a> to disregard those tokens. 
To achieve that, we need the softmax inputs for those tokens to be \\(-\\infty\\), since \\(e^{-\\infty} = 0\\) gives them zero weight:<\/p>\n\n\n\n<p>\\[ \\underset{\\leq t}{\\texttt{attention}}(\\boldsymbol{q}, \\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{k}^{(n)}, \\boldsymbol{v}) =  \\texttt{softmax}\\left([\\boldsymbol{q}^T \\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{q}^T \\boldsymbol{k}^{(t)},  -\\infty, \\ldots, -\\infty]\\right) \\cdot  \\boldsymbol{v} \\]\n\n\n\n<p>One easy way to do this is to pass in an additional <em>mask<\/em> vector whose \\(i\\)th entry is zero for \\(i \\leq t\\) and negative infinity otherwise:<\/p>\n\n\n\n<p>\\[ \\begin{aligned}\\underset{\\leq t}{\\texttt{attention}}(\\boldsymbol{q}, \\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{k}^{(n)}, \\boldsymbol{v}) =  \\texttt{softmax}( &amp; [\\boldsymbol{q}^T \\boldsymbol{k}^{(1)}, \\ldots, \\boldsymbol{q}^T \\boldsymbol{k}^{(t)},  \\boldsymbol{q}^T \\boldsymbol{k}^{(t+1)}, \\ldots, \\boldsymbol{q}^T \\boldsymbol{k}^{(n)}] + \\\\ &amp; [0, \\ldots, 0, -\\infty, \\ldots, -\\infty]) \\cdot  \\boldsymbol{v}\\end{aligned} \\]\n\n\n\n<p>or<\/p>\n\n\n\n<p>\\[ \\texttt{attention}(\\boldsymbol{q}, \\boldsymbol{K}, \\boldsymbol{v}, \\boldsymbol{m}) =  \\texttt{softmax}( \\boldsymbol{q}^T \\boldsymbol{K} +  \\boldsymbol{m}) \\cdot  \\boldsymbol{v} \\]\n\n\n\n<p>We don&#8217;t use just one query, though; we use a bunch. 
If you&#8217;re <span id=\"su_tooltip_69e9c29574340_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9c29574340\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">paying attention<\/mark><\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9c29574340\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">Pun intended.<\/span><\/span><span id=\"su_tooltip_69e9c29574340_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span>, you may have noticed that we actually get one query per output-head token, because each row of our query matrix comes from one row of \\(\\boldsymbol{X}\\):<\/p>\n\n\n\n<p>\\[\\boldsymbol{Q} = \\boldsymbol{X}\\cdot \\boldsymbol{W}^Q\\]\n\n\n\n<p>So we want \\(\\boldsymbol{q}^{(1)}\\) to use \\(t = 1\\), \\(\\boldsymbol{q}^{(2)}\\) to use \\(t = 2\\), etc. 
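<\/p>\n\n\n\n<p>Stacking those per-query masks side by side gives a mask matrix. Here is a minimal sketch of building one. Note that lookahead_mask is a hypothetical helper, not a Flux function. The sketch assumes the [key \u00d7 query] layout that batched_mul(Kp, Qp) produces in the MultiHeadAttention code further down, so the \\(-\\infty\\) entries land below the diagonal. That is the transpose of the query-by-key layout of \\(\\boldsymbol{Q}\\boldsymbol{K}^T\\):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">using Flux # for softmax\n\n# Hypothetical helper: a lookahead mask in [key \u00d7 query] layout.\n# Entry (i, j) is 0 when key i comes at or before query j, and -Inf otherwise.\nfunction lookahead_mask(ntokens::Int)\n    mask = zeros(Float32, ntokens, ntokens)\n    for j in 1:ntokens, i in 1:ntokens\n        if i > j\n            mask[i, j] = -Inf32\n        end\n    end\n    return mask\nend\n\nM = lookahead_mask(4)\nlogits = randn(Float32, 4, 4)         # [key \u00d7 query] scores for one head\nalpha = softmax(logits .+ M, dims=1) # normalize over keys for each query\n\n@assert isapprox(alpha[1, 1], 1)              # query 1 attends only to key 1\n@assert all(alpha[2:4, 1] .== 0)\n@assert all(alpha[4, 1:3] .== 0)              # key 4 is hidden from queries 1 to 3\n@assert all(isapprox.(sum(alpha, dims=1), 1)) # each column still sums to 1<\/code><\/pre>\n\n\n\n<p>Column \\(j\\) of alpha puts all of its weight on keys \\(1\\) through \\(j\\), which is exactly the per-query rule above. 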
This means that when we move to attention with multiple queries, we get:<\/p>\n\n\n\n<p>\\[ \\texttt{attention}(\\boldsymbol{Q}, \\boldsymbol{K}, \\boldsymbol{V}, \\boldsymbol{M}) =  \\texttt{softmax}\\left( \\frac{\\boldsymbol{Q} \\boldsymbol{K}^T}{\\sqrt{d_k}} +  \\boldsymbol{M}\\right)  \\boldsymbol{V} \\]\n\n\n\n<p>where \\(d_k\\) is the dimension of the keys and \\(\\boldsymbol{M}\\) is a lookahead mask with its upper-right triangle (every entry above the diagonal) set to negative infinity:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"182\" height=\"162\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-41.png\" alt=\"\" class=\"wp-image-2170\"\/><\/figure>\n<\/div>\n\n\n<p>Check out <a href=\"https:\/\/gmongaras.medium.com\/how-do-self-attention-masks-work-72ed9382510f\">this blog post for some additional nice coverage of self-attention masks.<\/a><\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Are We Done Yet?<\/h2>\n\n\n\n<p>Alright, we&#8217;ve got our baby transformer and we&#8217;ve added a lookahead mask. Are we done?<\/p>\n\n\n\n<p>Well, sort of. This is about the minimum needed for something that can call itself a transformer, and it will probably sort of work. All of the concepts are there. A real transformer works the same way, just bigger.<\/p>\n\n\n\n<p>First, it will use multi-headed rather than single-headed attention. 
This just means that it does single-headed attention multiple times in parallel:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"716\" height=\"301\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-58.png\" alt=\"\" class=\"wp-image-2176\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-58.png 716w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_19-58-300x126.png 300w\" sizes=\"auto, (max-width: 716px) 100vw, 716px\" \/><\/figure>\n<\/div>\n\n\n<p>We could do that by running a \\(\\texttt{for}\\) loop over multiple attention heads, but in practice it can be implemented by splitting our tensor into a new dimension according to the number of heads:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function (mha::MultiHeadAttention)(\n    Q_in::Array{Float32, 3}, # [dim \u00d7 ntokens_q \u00d7 batch_size]\n    K_in::Array{Float32, 3}, # [dim \u00d7 ntokens_k \u00d7 batch_size]\n    V_in::Array{Float32, 3}, # [dim \u00d7 ntokens_k \u00d7 batch_size]\n    mask::Array{Float32, 2}) # [ntokens_k \u00d7 ntokens_q]\n\n    dims = size(Q_in)\n    dim = dims[1]\n\n    # Project the inputs; each keeps its [dim \u00d7 ntokens \u00d7 batch_size] shape\n    Q = mha.q_proj(Q_in) # [dim \u00d7 ntokens_q \u00d7 batch_size]\n    K = mha.k_proj(K_in) # [dim \u00d7 ntokens_k \u00d7 batch_size]\n    V = mha.v_proj(V_in) # [dim \u00d7 ntokens_k \u00d7 batch_size]\n\n    # Reshape to [dim\u00f7nheads \u00d7 nheads \u00d7 ntokens \u00d7 batch_size]. K and V use their\n    # own sizes, since in cross attention they can have a different token count.\n    Q = reshape(Q, dim \u00f7 mha.nheads, mha.nheads, dims[2:end]...)\n    K = reshape(K, dim \u00f7 mha.nheads, mha.nheads, size(K)[2:end]...)\n    V = reshape(V, dim \u00f7 mha.nheads, mha.nheads, size(V)[2:end]...)\n\n    # We're going to use batched_mul, which operates on the first 2 dimensions.\n    # We want Q, K, and V to act as `nheads` separate attention heads, so we\n    # need to move the 'nheads' dimension out of the first 2 dimensions.\n    Kp = permutedims(K, (3, 1, 2, 4)) # This effectively takes care of the transpose too\n    Qp = permutedims(Q, (1, 3, 2, 4)) .\/ \u221aFloat32(dim \u00f7 mha.nheads) # scale by \u221ad_k, the per-head size\n    logits = batched_mul(Kp, Qp) # [ntokens_k \u00d7 ntokens_q \u00d7 nheads \u00d7 batch_size]\n\n    # Apply the mask\n    logits = logits .+ mask # [ntokens_k \u00d7 ntokens_q \u00d7 nheads \u00d7 batch_size]\n\n    # Compute the attention weights, normalizing over the keys (dims=1) for each query\n    \u03b1 = softmax(logits, dims=1) # [ntokens_k \u00d7 ntokens_q \u00d7 nheads \u00d7 batch_size]\n\n    # Run dropout on the attention weights\n    \u03b1 = mha.dropout(\u03b1) # [ntokens_k \u00d7 ntokens_q \u00d7 nheads \u00d7 batch_size]\n\n    # Multiply by V, again with a batched_mul\n    Vp = permutedims(V, (1, 3, 2, 4)) # [dim\u00f7nheads \u00d7 ntokens_k \u00d7 nheads \u00d7 batch_size]\n    X = batched_mul(Vp, \u03b1) # [dim\u00f7nheads \u00d7 ntokens_q \u00d7 nheads \u00d7 batch_size]\n    X = permutedims(X, (1, 3, 2, 4)) # [dim\u00f7nheads \u00d7 nheads \u00d7 ntokens_q \u00d7 batch_size]\n    X = reshape(X, :, size(X)[3:end]...) # [dim \u00d7 ntokens_q \u00d7 batch_size]\n\n    # Compute the outward projection\n    return mha.out_proj(X) # [dim \u00d7 ntokens_q \u00d7 batch_size]\nend<\/code><\/pre>\n\n\n\n<p>In the first multi-headed attention layer of the output head, all three inputs are X. In the second one (the cross attention layer), the queries Q_in are X, while the keys and values are the output of the input head.<\/p>\n\n\n\n<p>The other thing that a real transformer has is multiple transformer layers. 
That is, we repeat what we currently have as our input head and output head multiple times:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"639\" height=\"392\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-11.png\" alt=\"\" class=\"wp-image-2181\" style=\"width:547px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-11.png 639w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/04\/2024-03-31_20-11-300x184.png 300w\" sizes=\"auto, (max-width: 639px) 100vw, 639px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>The <a href=\"https:\/\/arxiv.org\/pdf\/1706.03762.pdf\">Attention is All You Need<\/a> paper used 8 parallel attention layers (attention heads) and 6 identical layers each in the encoder (input head) and decoder (output head).<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Conclusion<\/h2>\n\n\n\n<p>I hope this post helps demystify the transformer architecture. Transformers are the workhorse of modern large-scale deep learning, and it&#8217;s worth familiarizing yourself with them if you&#8217;re going to be working in that area. While complicated at first glance, they actually arise from fairly straightforward principles and solve some fairly practical and understandable problems. The information here should be enough to get started using them, or even roll your own from scratch.<\/p>\n\n\n\n<p>Happy coding!<\/p>\n","protected":false},"excerpt":{"rendered":"<p>This month I decided to take a break from my sidescroller project and instead properly attend to transformers (pun intended). Unless you&#8217;ve been living under a rock, you&#8217;ve noticed the rapid advanced of AI in the last year and the advent of extremely large models like ChatGPT 4 and Google Gemini. 
These models, and pretty [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"closed","ping_status":"","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[11],"tags":[],"class_list":["post-2003","post","type-post","status-publish","format-standard","hentry","category-deep-learning"],"_links":{"self":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2003","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/comments?post=2003"}],"version-history":[{"count":161,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2003\/revisions"}],"predecessor-version":[{"id":2191,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2003\/revisions\/2191"}],"wp:attachment":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/media?parent=2003"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/categories?post=2003"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/tags?post=2003"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}