{"id":2922,"date":"2024-12-01T18:35:22","date_gmt":"2024-12-01T18:35:22","guid":{"rendered":"https:\/\/timallanwheeler.com\/blog\/?p=2922"},"modified":"2024-12-01T18:35:22","modified_gmt":"2024-12-01T18:35:22","slug":"rollouts-on-the-gpu","status":"publish","type":"post","link":"https:\/\/timallanwheeler.com\/blog\/2024\/12\/01\/rollouts-on-the-gpu\/","title":{"rendered":"Rollouts on the GPU"},"content":{"rendered":"\n<p>Last month <a href=\"https:\/\/timallanwheeler.com\/blog\/2024\/11\/05\/tuning-a-sokoban-policy-net\/\">I wrote about<\/a> moving the Sokoban policy training code from CPU to GPU, yielding <span id=\"su_tooltip_69e9e430d98b9_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9e430d98b9\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">massive speedups<\/mark>.<\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9e430d98b9\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">Reduced training time by about 32x for the base case and training time scales much better with model size.<\/span><\/span><span id=\"su_tooltip_69e9e430d98b9_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span> That significantly shortened both training time and the time it takes to compute basic validation metrics. 
It has not, unfortunately, significantly changed how long it takes to run rollouts, and relatedly, how long it takes to run beam search.<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">The Bottleneck<\/h1>\n\n\n\n<p>The training that I&#8217;ve done so far has all been with <a href=\"https:\/\/en.wikipedia.org\/wiki\/Teacher_forcing#:~:text=Teacher%20forcing%20is%20an%20algorithm,to%20the%20ground%2Dtruth%20sequence.\">teacher forcing<\/a>, which allows all inputs to be passed to the net at once:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"960\" height=\"540\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-1.jpg\" alt=\"\" class=\"wp-image-2929\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-1.jpg 960w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-1-300x169.jpg 300w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-1-768x432.jpg 768w\" sizes=\"auto, (max-width: 960px) 100vw, 960px\" \/><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p>When we do a rollout, we can&#8217;t pass everything in at once. 
We start with our initial state and use the policy to discover where we end up:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"960\" height=\"540\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-2.jpg\" alt=\"\" class=\"wp-image-2930\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-2.jpg 960w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-2-300x169.jpg 300w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/Untitled-presentation-2-768x432.jpg 768w\" sizes=\"auto, (max-width: 960px) 100vw, 960px\" \/><\/figure>\n\n\n\n<p>The problem is that the left side of that image, the policy call, is happening on the GPU, but the right side, the state advancement, is happening on the CPU. If a rollout involves 62 player steps, then instead of the single data transfer we have for training, we&#8217;re doing 61 transfers! 
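<\/p>\n\n\n\n<p>In sketch form, the rollout loop was doing something like this each step (the names here are illustrative, not the actual code):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\"># Hypothetical sketch of the CPU-bound rollout loop\nfor step in 1:n_steps\n    logits, _ = policy(gpu(inputs))        # upload, run the net on the GPU\n    actions = sample_actions(cpu(logits))  # download the logits to the CPU\n    advance_boards!(boards, actions)       # advance the states on the CPU\n    inputs = encode_boards(boards)         # re-encode for the next step\nend<\/code><\/pre>\n\n\n\n<p>Every iteration pays for a round trip between host and device memory.<\/p>\n\n\n\n<p>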
Our bottleneck is all that back-and-forth communication:<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"962\" height=\"174\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_08-45.png\" alt=\"\" class=\"wp-image-2931\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_08-45.png 962w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_08-45-300x54.png 300w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_08-45-768x139.png 768w\" sizes=\"auto, (max-width: 962px) 100vw, 962px\" \/><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p>Let&#8217;s move everything to the GPU.<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">CPU Code<\/h1>\n\n\n\n<p>So what is currently happening on the CPU?<\/p>\n\n\n\n<p>At every state, we are:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Sampling an action for each board from the action logits<\/li>\n\n\n\n<li>Applying that action to each board to advance the state<\/li>\n<\/ol>\n\n\n\n<p>Sampling from the actions is pretty straightforward to run on the GPU. 
That&#8217;s the bread and butter of transformers and RL in general.<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\"># policy_logits are [a\u00d7s\u00d7b] (a=actions, s=sequence length, b = batch size)\npolicy_logits, nsteps_logits = policy(inputs)\n\n# Sample from the logits using the Gumbel-max trick\nsampled_actions = argmax(policy_logits .+ gumbel_noise, dims=1)<\/code><\/pre>\n\n\n\n<p>where we use <a href=\"https:\/\/timvieira.github.io\/blog\/post\/2014\/07\/31\/gumbel-max-trick\/\">the Gumbel-max trick<\/a> and the Gumbel noise is sampled in advance and passed to the GPU like the other inputs:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">using Distributions\ngumbel_noise = rand(Gumbel(0, 1), (a, s, b))<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>Advancing the board states is more complicated. Here is the CPU method for a single state:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function maybe_move!(board::Board, dir::Direction)::Bool\n    \u25a1_player::TileValue = find_player_tile(board)\n    step_fore = get_step_fore(board, dir)\n\n    \u25a1 = \u25a1_player # where the player starts\n    \u25a9 = \u25a1 + step_fore # where the player potentially ends up\n\n    if is_set(board[\u25a9], WALL)\n        return false # We would be walking into a wall\n    end\n\n    if is_set(board[\u25a9], BOX)\n        # We would be walking into a box.\n        # This is only a legal move if we can push the box.\n        \u25f0 = \u25a9 + step_fore # where box ends up\n        if is_set(board[\u25f0], WALL + BOX)\n            return false # We would be pushing the box into a box or wall\n        end\n\n        # Move the box\n        board[\u25a9] &amp;= ~BOX # Clear the box\n        board[\u25f0] |= BOX # Add the box\n    end\n\n    # At this point we have established this as a legal move.\n    # Finish by moving the player\n    board[\u25a1] 
&amp;= ~PLAYER # Clear the player\n    board[\u25a9] |= PLAYER # Add the player\n\n    return true\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>There are many ways to represent board states. This representation is a simple <code>Matrix{UInt8}<\/code>, so an 8&#215;8 board is just an 8&#215;8 matrix. Each tile is a bitfield with components that can be set for whether that tile has\/is a wall, box, floor, or player.<\/p>\n\n\n\n<p>Moving the player has 3 possible paths:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>successful step: the destination tile is empty and we just move the player to it<\/li>\n\n\n\n<li>successful push: the destination tile has a box, and the next one over is empty, so we move both the player and the box<\/li>\n\n\n\n<li>failed move: otherwise, this is an illegal move and the player stays where they are<\/li>\n<\/ul>\n\n\n\n<p>Moving this logic to the GPU has to preserve this flow, use the GPU&#8217;s representation of the board state, and handle a tensor&#8217;s worth of board states at a time.<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">GPU Representation<\/h1>\n\n\n\n<p>The input to the policy is a tensor of size \\([h \\times w \\times f \\times s \\times b]\\), where 1 board is encoded as a sparse \\((h = \\text{height}) \\times (w=\\text{width}) \\times (f = \\text{num features} = 5)\\) <span id=\"su_tooltip_69e9e430d99ea_button\" class=\"su-tooltip-button su-tooltip-button-outline-yes\" aria-describedby=\"su_tooltip_69e9e430d99ea\" data-settings='{\"position\":\"top\",\"behavior\":\"hover\",\"hideDelay\":0}' tabindex=\"0\"><mark style=\"background-color:rgba(0, 0, 0, 0)\" class=\"has-inline-color has-vivid-cyan-blue-color\">tensor<\/mark>:<\/span><span style=\"display:none;z-index:100\" id=\"su_tooltip_69e9e430d99ea\" class=\"su-tooltip\" role=\"tooltip\"><span class=\"su-tooltip-inner su-tooltip-shadow-no\" 
style=\"z-index:100;background:#222222;color:#FFFFFF;font-size:16px;border-radius:5px;text-align:left;max-width:300px;line-height:1.25\"><span class=\"su-tooltip-title\"><\/span><span class=\"su-tooltip-content su-u-trim\">We don't really need floors as a feature, since it is just not-walls. I originally thought it would be useful for the model, and can try ablating it in the future to see.<\/span><\/span><span id=\"su_tooltip_69e9e430d99ea_arrow\" class=\"su-tooltip-arrow\" style=\"z-index:100;background:#222222\" data-popper-arrow><\/span><\/span> <\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"599\" height=\"394\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-33.png\" alt=\"\" class=\"wp-image-2942\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-33.png 599w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-33-300x197.png 300w\" sizes=\"auto, (max-width: 599px) 100vw, 599px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>and we have board sequences of length \\(s\\) and \\(b\\) sequences per batch of them:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"523\" height=\"302\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-27.png\" alt=\"\" class=\"wp-image-2940\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-27.png 523w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/11\/2024-11-24_09-27-300x173.png 300w\" sizes=\"auto, (max-width: 523px) 100vw, 523px\" \/><\/figure>\n<\/div>\n\n\n<p class=\"has-small-font-size\">I purposely chose 4-step boards here, but sequences can generally be much longer and of different lengths, and the first state in each sequence is the goal 
state.<\/p>\n\n\n\n<p>Our actions will be the \\([4\\times s \\times b]\\) actions tensor &#8212; one up\/down\/left\/right action per board state. <\/p>\n\n\n\n<h1 class=\"wp-block-heading\">Shifting Tensors<\/h1>\n\n\n\n<p>The first fundamental operation we&#8217;re going to need is to be able to check tile neighbors. That is, instead of doing this:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">\u25a1 = \u25a1_player # where the player starts\n\u25a9 = \u25a1 + step_fore # where the player potentially ends up<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>we&#8217;ll be <em>shifting<\/em> all tiles over and checking that instead:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"574\" height=\"166\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-35.png\" alt=\"\" class=\"wp-image-2955\" style=\"width:548px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-35.png 574w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-35-300x87.png 300w\" sizes=\"auto, (max-width: 574px) 100vw, 574px\" \/><\/figure>\n<\/div>\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">is_player_dests = shift_tensor(is_players, 0, 1, false)<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>The <code>shift_tensor<\/code> method takes in a tensor and shifts it by the given number of rows and columns, padding in new values:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"615\" height=\"162\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-39.png\" alt=\"\" class=\"wp-image-2956\" style=\"width:575px;height:auto\" 
srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-39.png 615w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_08-39-300x79.png 300w\" sizes=\"auto, (max-width: 615px) 100vw, 615px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>We pass in the number of rows or columns to shift, figure out what that means in terms of padding, and  then leverage NNlib&#8217;s <a href=\"https:\/\/fluxml.ai\/NNlib.jl\/stable\/reference\/#NNlib.pad_constant\">pad_constant<\/a> method to give us a new tensor that we clamp to a new range:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function shift_tensor(\n    tensor::AbstractArray,\n    d_row::Integer,\n    d_col::Integer,\n    pad_value)\n\n    pad_up    = max( d_row, 0)\n    pad_down  = max(-d_row, 0)\n    pad_left  = max( d_col, 0)\n    pad_right = max(-d_col, 0)\n\n    tensor_padded = NNlib.pad_constant(\n        tensor,\n        (pad_up, pad_down, pad_left, pad_right, \n            (0 for i in 1:2*(ndims(tensor)-2))...),\n        pad_value)\n\n    dims = size(tensor_padded)\n    row_lo = 1 + pad_down\n    row_hi = dims[1] - pad_up\n    col_lo = 1 + pad_right\n    col_hi = dims[2] - pad_left\n\n    return tensor_padded[row_lo:row_hi, col_lo:col_hi,\n                         (Colon() for d in dims[3:end])...]\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>This method works on tensors with varying numbers of dimensions, and always operates on the first two dimensions as the row and column dimensions.<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">Taking Actions<\/h1>\n\n\n\n<p>If we know the player move, we can use the appropriate shift direction to get the &#8220;next tile over&#8221;. 
Our player moves can be reflected by the following row and column shift values:<\/p>\n\n\n\n<p class=\"has-text-align-center\">UP = (d_row=-1, d_col= 0)<br>LEFT = (d_row= 0, d_col=-1)<br>DOWN = (d_row=+1, d_col= 0)<br>RIGHT = (d_row= 0, d_col=+1)<\/p>\n\n\n\n<p>This lets us convert the CPU-movement code into a bunch of Boolean tensor operations:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function advance_boards(\n    inputs::AbstractArray{Bool}, # [h,w,f,s,b]\n    d_row::Integer,\n    d_col::Integer)\n\n    boxes  = inputs[:,:,DIM_BOX,   :,:]\n    player = inputs[:,:,DIM_PLAYER,:,:]\n    walls  = inputs[:,:,DIM_WALL,  :,:]\n\n    player_shifted = shift_tensor(player, d_row, d_col, false)\n    player_2_shift = shift_tensor(player_shifted, d_row, d_col, false)\n\n    # A move is valid if the player destination is empty\n    # or if its a box and the next space over is empty\n    not_box_or_wall = .!(boxes .| walls)\n\n    # 1 if it is a valid player destination tile for a basic player move\n    move_space_empty = player_shifted .&amp; not_box_or_wall\n\n    # 1 if the tile is a player destination tile containing a box\n    move_space_isbox = player_shifted .&amp; boxes\n\n    # 1 if the tile is a player destination tile whose next one over\n    # is a valid box push receptor\n    push_space_empty = player_shifted .&amp; shift_tensor(not_box_or_wall, -d_row, -d_col, false)\n\n    # 1 if it is a valid player move destination\n    move_mask = move_space_empty\n\n    # 1 if it is a valid player push destination\n    # (which also means it currently has a box)\n    push_mask = move_space_isbox .&amp; push_space_empty\n\n    # new player location\n    mask = move_mask .| push_mask\n    player_new = mask .| (player .* shift_tensor(.!mask, -d_row, -d_col, false))\n\n    # new box location\n    box_destinations = shift_tensor(boxes .* push_mask, d_row, d_col, false)\n    boxes_new = (boxes .* (.!push_mask)) .| box_destinations\n\n   
 return player_new, boxes_new\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>The method appropriately moves any player tile that has an open space in the neighboring tile, or any player tile that has a neighboring pushable box. We create both a new player tensor and a new box tensor.<\/p>\n\n\n\n<p>This may seem extremely computationally expensive &#8212; we&#8217;re operating on all tiles rather than on just the ones we care about. But GPUs are really good at exactly this, and it is much cheaper to let the GPU churn through that than wait for the transfer to\/from the CPU.<\/p>\n\n\n\n<p>The main complication here is that we&#8217;re using the same action across all boards. In a given instance, there are \\(s\\times b\\) boards in our tensor. We don&#8217;t want to be using the same action in all of them.<\/p>\n\n\n\n<p>Instead of sharding different actions to different boards, we&#8217;ll compute the results of all 4 actions and then index into the resulting state that we need:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"560\" height=\"321\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-01.png\" alt=\"\" class=\"wp-image-2958\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-01.png 560w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-01-300x172.png 300w\" sizes=\"auto, (max-width: 560px) 100vw, 560px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>Working with GPUs sure makes you think differently about things.<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function advance_boards(\n    inputs::AbstractArray{Bool}, # [h,w,f,s,b]\n    actions::AbstractArray{Int}) #       [s,b]\n\n    succ_u = advance_boards(inputs, -1,  0) # [h,w,s,d], [h,w,s,d]\n    succ_l = advance_boards(inputs,  0, -1)\n    succ_d = 
advance_boards(inputs,  1,  0)\n    succ_r = advance_boards(inputs,  0,  1)\n\n    size_u = size(succ_u[1])\n    target_dims = (size_u[1], size_u[2], 1, size_u[3:end]...)\n    player_news = cat(\n        reshape(succ_u[1], target_dims),\n        reshape(succ_l[1], target_dims),\n        reshape(succ_d[1], target_dims),\n        reshape(succ_r[1], target_dims), dims=3) # [h,w,a,s,b]\n    box_news = cat(\n        reshape(succ_u[2], target_dims),\n        reshape(succ_l[2], target_dims),\n        reshape(succ_d[2], target_dims),\n        reshape(succ_r[2], target_dims), dims=3) # [h,w,a,s,b]\n\n    actions_onehot = onehotbatch(actions, 1:4) # [a,s,b]\n    actions_onehot = reshape(actions_onehot, (1,1,size(actions_onehot)...)) # [1,1,a,s,b]\n\n    boxes_new = any(actions_onehot .&amp; box_news, dims=3)\n    player_new = any(actions_onehot .&amp; player_news, dims=3)\n\n    return cat(inputs[:,:,1:3,:,:], boxes_new, player_new, dims=3)\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>We&#8217;re almost there. This computes each successor board at the same sequence index as its parent. To get the new inputs tensor, we want to shift our boards in the sequence dimension, propagating successor boards to the next sequence index. However, we can&#8217;t just shift the entire tensor. 
We want to keep the goals and the initial states:<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" width=\"492\" height=\"377\" src=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-11.png\" alt=\"\" class=\"wp-image-2959\" style=\"width:438px;height:auto\" srcset=\"https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-11.png 492w, https:\/\/timallanwheeler.com\/blog\/wp-content\/uploads\/2024\/12\/2024-12-01_09-11-300x230.png 300w\" sizes=\"auto, (max-width: 492px) 100vw, 492px\" \/><\/figure>\n<\/div>\n\n\n<p><\/p>\n\n\n\n<p>The code for this amounts to a <code>cat<\/code> operation and some indexing:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function advance_board_inputs(\n    inputs::AbstractArray{Bool}, # [h,w,f,s,b]\n    actions::AbstractArray{Int}) #       [s,b]\n\n    inputs_new = advance_boards(inputs, actions)\n\n    # Right shift and keep the goal and starting state\n    return cat(inputs[:, :, :, 1:2, :],\n               inputs_new[:, :, :, 2:end-1, :], dims=4) # [h,w,f,s,b]\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>And with that, we&#8217;re processing actions across entire batches!<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Rollouts on the GPU<\/h2>\n\n\n\n<p>We can leverage this new propagation code to advance our inputs tensor during a rollout. The policy and the inputs have to be on the GPU, which in Flux.jl can be done with <code>gpu(policy)<\/code>. 
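<\/p>\n\n\n\n<p>A minimal sketch of that setup (assuming Flux and CUDA are loaded, and using the same names as the rollout code):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">using Flux, CUDA\n\npolicy0 = gpu(policy0)               # move the model weights to the GPU\ninputs_gpu = gpu(inputs)             # move the input tensor to the GPU\ngumbel_noise_gpu = gpu(gumbel_noise) # noise was pre-sampled on the CPU<\/code><\/pre>\n\n\n\n<p>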
Note that this requires a CUDA-compatible GPU.<\/p>\n\n\n\n<p>A single iteration is then:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\"># Run the model\n# policy_logits are [4 \u00d7 s \u00d7 b]\n# nsteps_logits are [7 \u00d7 s \u00d7 b]\npolicy_logits_gpu, nsteps_logits_gpu = policy0(inputs_gpu)\n\n# Sample from the action logits using the Gumbel-max trick\nactions_gpu = argmax(policy_logits_gpu .+ gumbel_noise_gpu, dims=1)\nactions_gpu = getindex.(actions_gpu, 1) # Int64[1 \u00d7 s \u00d7 b]\nactions_gpu = dropdims(actions_gpu, dims=1) # Int64[s \u00d7 b]\n\n# Apply the actions\ninputs_gpu = advance_board_inputs(inputs_gpu, actions_gpu)<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>The overall rollout code just throws this into a loop and does some setup:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">function rollouts!(\n    inputs::Array{Bool, 5},      # [h\u00d7w\u00d7f\u00d7s\u00d7b]\n    gumbel_noise::Array{Float32, 3}, # [4\u00d7s\u00d7b]\n    policy0::SokobanPolicyLevel0,\n    s_starts::Vector{Board}, # [b]\n    s_goals::Vector{Board}) # [b]\n\n    policy0 = gpu(policy0)\n\n    h, w, f, s, b = size(inputs)\n\n    @assert length(s_starts) == b\n    @assert length(s_goals) == b\n\n    # Fill the goals into the first sequence channel\n    for (bi, s_goal) in enumerate(s_goals)\n        set_board_input!(inputs, s_goal, 1, bi)\n    end\n\n    # Fill the start states in the second sequence channel\n    for (bi, s_start) in enumerate(s_starts)\n        set_board_input!(inputs, s_start, 2, bi)\n    end\n\n    inputs_gpu = gpu(inputs)\n    gumbel_noise_gpu = gpu(gumbel_noise)\n\n    for si in 2:s-1\n\n        # Run the model\n        # policy_logits are [4 \u00d7 s \u00d7 b]\n        # nsteps_logits are [7 \u00d7 s \u00d7 b]\n        policy_logits_gpu, nsteps_logits_gpu = policy0(inputs_gpu)\n\n        # Sample from the action logits using the Gumbel-max trick\n        actions_gpu = 
argmax(policy_logits_gpu .+ gumbel_noise_gpu, dims=1)\n        actions_gpu = getindex.(actions_gpu, 1) # Int64[1 \u00d7 s \u00d7 b]\n        actions_gpu = dropdims(actions_gpu, dims=1) # Int64[s \u00d7 b]\n\n        # Apply the actions\n        inputs_gpu = advance_board_inputs(inputs_gpu, actions_gpu)\n    end\n\n    return cpu(inputs_gpu)\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>There are several differences:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>The code is simpler. We only have a single loop, over the sequence length (number of steps to take). The content of that loop is pretty compact.<\/li>\n\n\n\n<li>The code does more work. We&#8217;re processing more stuff, but because it happens in parallel on the GPU, it&#8217;s okay. We&#8217;re also propagating all the way to the end of the sequence whether we need to or not. (The CPU code would check whether all boards had finished already.)<\/li>\n<\/ol>\n\n\n\n<p>If we time how long it takes to do a batch&#8217;s worth of rollouts before and after moving to the GPU, we get about a \\(60\\times\\) speedup. Our efforts have been worth it!<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">Beam Search on the GPU<\/h1>\n\n\n\n<p>Rollouts aren&#8217;t the only thing we want to speed up. I want to use beam search to explore the space using the policy and try to find solutions. 
Rollouts might happen to find solutions, but beam search should be a lot better.<\/p>\n\n\n\n<p>The code ends up being basically the same, except a single goal and board is used to seed the entire batch (giving us a number of beams equal to the batch size), and we have to do some work to score the beams and then select which ones to keep:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code lang=\"julia\" class=\"language-julia\">unction beam_search!(\n    inputs::Array{Bool, 5},      # [h\u00d7w\u00d7f\u00d7s\u00d7b]\n    policy0::SokobanPolicyLevel0,\n    s_start::Board,\n    s_goal::Board)\n\n    policy0 = gpu(policy0)\n\n    h, w, f, s, b = size(inputs)\n\n    # Fill the goals and starting states into the first sequence channel\n    for bi in 1:b\n        set_board_input!(inputs, s_goal, 1, bi)\n        set_board_input!(inputs, s_start, 2, bi)\n    end\n\n    # The scores all start at zero\n    beam_scores = zeros(Float32, 1, b) |&gt; gpu # [1, b]\n\n    # Keep track of the actual actions\n    actions = ones(Int, s, b) |&gt; gpu # [s, b]\n\n    inputs_gpu = gpu(inputs)\n\n    # Advance the games in parallel\n    for si in 2:s-1\n\n        # Run the model\n        # policy_logits are [4 \u00d7 s \u00d7 b]\n        # nsteps_logits are [7 \u00d7 s \u00d7 b]\n        policy_logits, nsteps_logits = policy0(inputs_gpu)\n\n        # Compute the probabilities\n        action_probs = softmax(policy_logits, dims=1) # [4 \u00d7 s \u00d7 b]\n        action_logls = log.(action_probs) # [4 \u00d7 s \u00d7 b]\n\n        # The beam scores are the running log likelihoods\n        action_logls_si = action_logls[:, si, :]  # [4, b]\n        candidate_beam_scores = action_logls_si .+ beam_scores # [4, b]\n        candidate_beam_scores_flat = vec(candidate_beam_scores) # [4b]\n\n        # Get the top 'b' beams\n        topk_indices = partialsortperm(candidate_beam_scores_flat, 1:b; rev=true)\n\n        # Convert flat indices back to action and beam indices\n        selected_actions = 
(topk_indices .- 1) .\u00f7 b .+ 1  # [b] action indices (1 to 4)\n        selected_beams   = (topk_indices .- 1) .% b .+ 1  # [b] beam indices (1 to b)\n        selected_scores  = candidate_beam_scores_flat[topk_indices]  # [b]\n        inputs_gpu = inputs_gpu[:,:,:,:,selected_beams]\n\n        actions[si,:] = selected_actions\n\n        # Apply the actions to the selected beams\n        inputs = advance_board_inputs(inputs_gpu, actions)\n    end\n\n    return (cpu(inputs_gpu), cpu(actions))\nend<\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<p>This again results in what looks like way simpler code. The beam scoring and such is all done on tensors, rather than a bunch of additional for loops. It all happens on the GPU, and it is way faster (\\(23\\times\\)).<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Conclusion<\/h2>\n\n\n\n<p>The previous blog post was about leveraging the GPU during training. This blog post was about leveraging the GPU during inference. We had to avoid expensive data transfers between the CPU and the GPU, and to achieve that had to convert non-trivial player movement code to computations amenable to the GPU. Going about that meant thinking about and structuring our code very differently, working across tensors and creating <em>more work<\/em> that the GPU could nonetheless complete faster.<\/p>\n\n\n\n<p>This post was a great example of how code changes based on the scale you&#8217;re operating at. Peter van Hardenberg gives a great talk about similar concepts in <a href=\"https:\/\/www.youtube.com\/watch?v=czzAVuVz7u4\">Why Can&#8217;t We Make Simple Software?<\/a>. How you think about a problem changes a lot based on problem scale and hardware. Now that we&#8217;re graduating from the CPU to processing many many boards, we have to think about the problem differently.<\/p>\n\n\n\n<p>Our inference code has been GPU-ized, so we can leverage it to speed up validation and solution search. 
It was taking me 20 min to train a network but 30 min to run beam search on all boards in my validation set. This change avoids that sorry state of affairs.<\/p>\n\n\n\n<p><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Last month I wrote about moving the Sokoban policy training code from CPU to GPU, yielding That significantly shortened both training time and the time it takes to compute basic validation metrics. It has not, unfortunately, significantly changed how long it takes to run rollouts, and relatedly, how long it takes to run beam search. [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"closed","ping_status":"","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[11,12,1],"tags":[],"class_list":["post-2922","post","type-post","status-publish","format-standard","hentry","category-deep-learning","category-sokoban","category-uncategorized"],"_links":{"self":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2922","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/comments?post=2922"}],"version-history":[{"count":50,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2922\/revisions"}],"predecessor-version":[{"id":2982,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/posts\/2922\/revisions\/2982"}],"wp:attachment":[{"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/media?parent=2922"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/timallanwheeler.com\/blog\/wp-json\/wp\/v2\/categories?post=2922"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/t
imallanwheeler.com\/blog\/wp-json\/wp\/v2\/tags?post=2922"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}