I’ve posted about Sokoban before. It is a problem I keep coming back to.
Sokoban is very simple to specify, yet its difficulty can be readily scaled up. It is easy to accidentally push blocks into an inescapable kerfuffle, and, more interestingly, it isn’t always obvious when you’ve done it.
Moreover, I think people solving Sokoban often beautifully exhibit hierarchical reasoning. The base game is played with up/down/left/right actions, but players very quickly start thinking about sequences of block pushes. We learn to “see” deadlock configurations and reason about rooms and packing orders. This ability to naturally develop higher-order thinking and execute on it is something I want to understand better. So I keep coming back to it.
In this blog post, I’ll recount training a Sokoban policy using deep learning, and more specifically, using the transformer architecture. I started pretty small, developing something that I hope to flesh out more in the future.
The Problem
This blog post is about policy optimization. A policy \(\pi(s)\) maps states to actions. These states and actions must exist in some sort of problem — i.e. a search problem or Markov decision process. It is often a good idea to define the problem before talking about how to structure the policy.
Our problem, of course, is Sokoban. The state space is the space of Sokoban boards. Each board is a grid, and each tile in the grid is either a wall or a floor; each floor tile may be a goal and may contain either a box or the player. For this post I’m working with 8×8 boards.
\[\text{state space}\enspace\mathcal{S} = \text{Sokoban boards}\]
Normal Sokoban has an action space of up/down/left/right. At least, that is what we’re told. I think that isn’t quite true. In normal Sokoban, the player can get stuck in a deadlock. When that happens, they typically reset the level or hit undo. So I’m going to use an action space of up/down/left/right/undo. Trajectories with undos are of course never optimal, but you nevertheless might need them when you’re exploring.
\[\text{action space}\enspace\mathcal{A} = \left\{\text{up}, \text{down}, \text{left}, \text{right}, \text{undo}\right\}\]
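To make the undo semantics concrete, here is a minimal Julia sketch (my own illustration, not necessarily how my environment is implemented) that treats undo as just another action by keeping a stack of visited states:

```julia
# Minimal sketch: undo pops a history stack; every other action pushes the
# successor state. `apply_move` is passed in, so this runs with any toy
# transition function.
function step!(history::Vector, action, apply_move)
    if action == :undo
        length(history) > 1 && pop!(history)              # revert to the previous state
    else
        push!(history, apply_move(history[end], action))  # :up, :down, :left, :right
    end
    return history[end]
end

# Toy usage with integer "states" and a dummy transition function:
hist = [0]
step!(hist, :up,   (s, a) -> s + 1)   # hist == [0, 1]
step!(hist, :undo, (s, a) -> s + 1)   # hist == [0]
```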
Notice that we’re working directly with player movements, not with box pushes. Pretty much no-one ever works directly with player movements, because the search space gets so large so quickly. That being said, hierarchical reasoning is the point of my investigations, so I’m starting at the bottom of the reasoning hierarchy.
Our goal is to solve Sokoban problems. For now, let’s just say that we get a reward of 1 whenever we solve a board (all boxes are on goals), and zero reward otherwise. Given two action sequences that solve a board, we prefer the shorter one.
\[R(s,a) = \begin{cases}1 & \text{if } \texttt{issolved}(s) \\ 0 & \text{otherwise}\end{cases}\]
We won’t directly make use of this reward function in this blog post, but it will come into play later.
The Policy
A typical policy would take the current state and try to find the best action that leads toward having all boxes on goals. Its architecture would broadly look like:
The output is a probability distribution over the actions. There are 5 actions (up, down, left, right, and undo), so this is just a categorical distribution over 5 values. That’s the bread and butter of ML and we can easily produce that with a softmax output layer.
The input is a board state. I decided to represent boards as \(8\times 8\times 5\) tensors, where the first two dimensions are the board length and width, and then we have 5 channels that indicate whether a given tile is a wall space, a floor space, has a goal, has a player, and has a box. Altogether that is 320 scalars, which is somewhat long and sparse. However, it contains a lot of structured spatial information that we need to make use of.
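As a concrete (if simplified) sketch of that encoding, assuming the ASCII conventions from the training examples shown later in this post (`#` wall, space for floor, `.` goal, `b` box, `p` player, `B` box-on-goal; other combined tiles are omitted here):

```julia
# Encode an 8-row ASCII board into the 8×8×5 tensor described above.
# The channel order is a choice made for this sketch.
function encode_board(rows::Vector{String})
    x = zeros(Float32, 8, 8, 5)          # rows × cols × channels
    for i in 1:8, j in 1:8
        c = rows[i][j]
        x[i, j, 1] = c == '#'            # wall
        x[i, j, 2] = c != '#'            # floor
        x[i, j, 3] = c in ('.', 'B')     # goal
        x[i, j, 4] = c == 'p'            # player
        x[i, j, 5] = c in ('b', 'B')     # box
    end
    return x
end
```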
The first architectural decision is to run the board state through an encoder head to produce a latent vector. It’s basically a 5-channel image, so I’m using convolutional layers, just like you would for an image input. These can learn consistent features that get applied in parallel across the input.
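A small Flux sketch of what such an encoder can look like (filter counts and latent size are illustrative placeholders, not my exact configuration):

```julia
using Flux

# Illustrative encoder: a couple of convolutions over the 5 board channels,
# flattened into a latent vector.
encoder = Chain(
    Conv((3, 3), 5 => 16, relu; pad=SamePad()),
    Conv((3, 3), 16 => 16, relu; pad=SamePad()),
    Flux.flatten,
    Dense(16 * 8 * 8 => 64, relu),
)

x = rand(Float32, 8, 8, 5, 1)   # one board, batch dimension last
z = encoder(x)                  # 64×1 latent vector
```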
This could work. We’d just stick some feedforward layers with maybe some residual connections into that trunk, and we’d be able to train a policy to place all boxes on goals. I didn’t quite want to go that route though, because I wanted the network to be slightly more configurable.
My eventual goal is to do hierarchical reasoning, so I want to be able to have another, higher-level policy use this lower-level policy as a tool. As such, I want an input that it can provide in order to steer the policy. To this end, I’m providing an additional goal input that tells the neural network what we want to happen:
In Sokoban it’s pretty important for the player to know when a board just isn’t solvable. It’s really easy to end up in a deadlock state:
We want the model to be able to learn when this happens, and to be able to report that it thinks this is the case.
I thus include an additional output, which I’m calling the nsteps output. It predicts the number of steps left in the solution of the puzzle. If the network thinks there are an infinite number of steps left, then the board is not solvable.
In Stop Regressing, the authors suggest avoiding (continuous) regression and training over discrete targets instead. As such, my nsteps output will actually output a discrete value:
nsteps = Int(clamp(round(log(nsteps+1)), 0, 5)) + 1
which can take on up to 6 values:
and a 7th bin for unsolvable.
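Putting the clamp together with the unsolvable case, the bucketing looks roughly like this (representing unsolvable boards with `nothing` is just a convention for this sketch):

```julia
# Map the number of remaining moves to one of 7 buckets:
# bins 1–6 are log-spaced step counts, bin 7 is reserved for unsolvable boards.
function nsteps_bin(nsteps)
    nsteps === nothing && return 7                       # unsolvable
    return Int(clamp(round(log(nsteps + 1)), 0, 5)) + 1  # bins 1–6
end

nsteps_bin(0)        # 1  (already solved)
nsteps_bin(12)       # 4
nsteps_bin(nothing)  # 7  (unsolvable)
```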
That leaves our network as (with the softmaxes implied):
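For concreteness, here is a rough Flux sketch of the combined network. Layer sizes are illustrative, and concatenating the goal board with the current board along the channel dimension is an assumption about the wiring, not a detail lifted from my actual code:

```julia
using Flux

# Rough sketch of the non-transformer policy with both output heads.
trunk = Chain(
    Conv((3, 3), 10 => 32, relu; pad=SamePad()),   # 5 state + 5 goal channels
    Conv((3, 3), 32 => 32, relu; pad=SamePad()),
    Flux.flatten,
    Dense(32 * 8 * 8 => 128, relu),
)
policy_head = Dense(128 => 5)    # up / down / left / right / undo
nsteps_head = Dense(128 => 7)    # 6 step buckets + 1 unsolvable bucket

function baseline_policy(state, goal)              # each is 8×8×5×batch
    h = trunk(cat(state, goal; dims=3))            # → 8×8×10×batch
    return policy_head(h), nsteps_head(h)          # logits; softmaxes left implied
end
```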
This architecture is a perfectly reasonable policy network. In fact, we’ll compare against this network later on. However, I was also very interested in working with transformers, so I built a transformer version of this network:
This architecture is roughly the same, except you have the massively parallel transformer that:
- can train with supervised learning more efficiently, by receiving an entire sequence at once
- can use past information to inform future decisions
- can easily be scaled up or down based on just a few parameters
- lends itself nicely to beam search for exploration
I am initially working with relatively small models consisting of 3-layer transformers with 16-dimensional embedding spaces, a maximum sequence length of 32, and 8-headed attention.
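For illustration, a single causal block at that scale might be wired up as follows in Flux. The tokenization of boards and actions into 16-dimensional embeddings is omitted, and this is a generic pre-norm block rather than my exact architecture:

```julia
using Flux

# One pre-norm transformer block: 16-dim embeddings, 8 attention heads, and a
# causal mask so each token only attends to earlier tokens.
struct Block
    attn::Flux.MultiHeadAttention
    norm1::LayerNorm
    mlp::Chain
    norm2::LayerNorm
end
Flux.@functor Block

Block(d::Int; nheads=8) = Block(
    Flux.MultiHeadAttention(d; nheads=nheads),
    LayerNorm(d),
    Chain(Dense(d => 4d, gelu), Dense(4d => d)),
    LayerNorm(d),
)

function (b::Block)(x)                          # x is d × seqlen × batch
    h = b.norm1(x)
    a, _ = b.attn(h, h, h; mask=Flux.NNlib.make_causal_mask(x))
    x = x .+ a
    return x .+ b.mlp(b.norm2(x))
end

model  = Chain(Block(16), Block(16), Block(16))  # 3 layers
tokens = rand(Float32, 16, 32, 1)                # 16-dim embeddings, 32 tokens
h = model(tokens)                                # 16 × 32 × 1
```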
Training
This initial model is trained with supervised learning. In order to do that, we need supervised data.
Typical supervised learning would provide a starting state, the goal state, and an optimal set of actions to reach it. My training examples are basically that, but they also include examples that cannot be solved, as well as examples containing bad moves together with a training mask that lets me exclude those moves during training. (This theoretically lets the network learn undos and how to recover from bad moves.)
start:
########
# #
# #
# . bp #
# #
# #
# #
########
goal:
########
# #
# #
# B #
# #
# #
# #
########
solved
moves: LL
isbad: --
solved
moves: xLL
isbad: x--
solved
moves: UxLL
isbad: xx--
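A small sketch of how such an entry can be parsed: each character in the moves string is one action (`U`/`D`/`L`/`R`, with `x` for undo), and the matching character in the isbad string marks moves to mask out of the loss. The action indices here are arbitrary:

```julia
# Map move characters to action indices; 'x' is the undo action.
const ACTION_INDEX = Dict('U' => 1, 'D' => 2, 'L' => 3, 'R' => 4, 'x' => 5)

function parse_moves(moves::AbstractString, isbad::AbstractString)
    actions = [ACTION_INDEX[c] for c in moves]
    mask    = [c == '-' for c in isbad]     # true ⇒ include in the loss
    return actions, mask
end

parse_moves("UxLL", "xx--")   # ([1, 5, 3, 3], Bool[0, 0, 1, 1])
```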
I have a few handcrafted training entries, but most of them are randomly generated problems solved with A*. I generate 1-box problems by starting with an 8×8 grid of all-walls, producing two random rectangular floor sections (rooms), randomly spawning the goal, box, and the player, and then speckling in some random walls. That produces a nice diversity that has structure, but not too much structure.
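A rough sketch of that generation recipe (room-size ranges and the wall-speckling rate below are placeholders, not my actual values):

```julia
using Random

# Start from all walls, carve two random rectangular rooms, place the goal,
# box, and player on distinct floor tiles, then speckle a few walls back in.
function random_board(; speckle=0.08)
    walls = trues(8, 8)
    for _ in 1:2                                     # two random rooms
        r, c = rand(2:5), rand(2:5)
        h, w = rand(2:4), rand(2:4)
        walls[r:min(r + h, 7), c:min(c + w, 7)] .= false
    end
    floor_tiles = [(i, j) for i in 1:8, j in 1:8 if !walls[i, j]]
    goal, box, player = shuffle(floor_tiles)[1:3]    # three distinct floor tiles
    for t in floor_tiles                             # speckle in random walls
        rand() < speckle && !(t in (goal, box, player)) && (walls[t...] = true)
    end
    return walls, goal, box, player
end
```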
Most of these boards aren’t solvable, so when I generate training examples, I set a target ratio of solvable to unsolvable boards and keep generating until I get enough of each type. I have 4000 generated problems with solutions and 4000 generated problems without solutions, for 8,000 generated training examples overall.
I was able to cheaply boost the size of the training set by recognizing that a problem can be rotated 4 ways and can be transposed. Incorporating these transforms (and applying the necessary adjustments to the solution paths) automatically increases the number of training examples by 8x.
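As a sketch of one of those eight symmetries: rotate the board 90° and relabel the directional moves to match. The exact permutation depends on your coordinate convention; this one assumes row 1 is the top of the board and that left decreases the column index:

```julia
# Rotate the 8×8×5 board tensor 90° counterclockwise and remap the moves.
function rotate_example(board::Array{Float32,3}, moves::Vector{Char})
    rotated = cat((rotl90(board[:, :, c]) for c in 1:size(board, 3))...; dims=3)
    remap = Dict('U' => 'L', 'L' => 'D', 'D' => 'R', 'R' => 'U', 'x' => 'x')
    return rotated, [remap[m] for m in moves]
end

board, moves = rand(Float32, 8, 8, 5), collect("UxLL")
rotate_example(board, moves)   # rotated board; moves become ['L', 'x', 'D', 'D']
```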
I train with a standard Flux.jl loop using the default Adam optimizer. I shuffle the training set with each epoch, and train with a batch size of 32. I’m using 1% dropout.
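The loop itself is nothing fancy; it has roughly this shape, with `model` and `data` standing in for the real objects, labels assumed to be one-hot encoded, and the bad-move mask omitted for brevity:

```julia
using Flux

# Shuffled batches of 32, default Adam, and a loss that sums the policy and
# nsteps cross-entropies (the model here is assumed to return raw logits).
opt_state = Flux.setup(Adam(), model)

function loss(model, batch)
    policy_logits, nsteps_logits = model(batch.states, batch.goals)
    return Flux.logitcrossentropy(policy_logits, batch.actions) +
           Flux.logitcrossentropy(nsteps_logits, batch.nsteps)
end

for epoch in 1:100
    for batch in Flux.DataLoader(data; batchsize=32, shuffle=true)
        grads = Flux.gradient(m -> loss(m, batch), model)
        Flux.update!(opt_state, model, grads[1])
    end
end
```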
The training loop is set up to create a new timestamped directory with each run and save a .json file with basic metrics. I compute both the policy and nsteps loss on each batch and log these for later printing. This gives me nice training curves, effectively using the code from the Flux quick start guide:
This plot uses a log-scale x-axis, which gives it some distortion. Training curves are often plotted this way because a lot of stuff happens early on and many more updates have to be made later on to get similar movement.
I’ve been consistently seeing the plots play out like they do above. At first the loss drops very quickly, only to plateau between 100 and 1000 batches, and then to suddenly start dropping again. The model seems to suddenly figure something out around the 1000-batch mark, and judging by how flat the policy loss is until then, whatever it learns seems related to the policy. Maybe it learns some reliable heuristics around spatial reasoning and walking to boxes.
The curves suggest that I could continue to train for longer and maybe get more improvement. Maybe that’s true, and maybe I’ll look into that in the future.
Validation
I have two types of validation – direct supervised metrics and validation metrics based on using the policy to attempt to solve problems. Both of these are run on a separate validation dataset of 4000 randomly generated problems (with a 50-50 split of solvable and not solvable).
The supervised metrics are:
- accuracy at predicting whether the root position is solvable
- top-1 policy accuracy – whether the highest-likelihood action matches the supervised example
- top-2 policy accuracy – whether the supervised example action is in the top two choices
- top-1 nsteps accuracy – whether the highest-likelihood nsteps bucket matches the example
- top-2 nsteps accuracy – ditto, but top two choices
Note that a Sokoban problem can have multiple solutions, so we shouldn’t expect a perfect policy accuracy. (e.g. going up and then left often reaches the same state as going left and then up).
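For reference, the top-k accuracies amount to a check like this (a sketch, not my exact evaluation code):

```julia
# A prediction counts as a hit when the labelled class index is among the
# k highest-probability outputs.
function topk_accuracy(probs::AbstractMatrix, labels::AbstractVector{<:Integer}, k::Integer)
    hits = 0
    for (j, label) in enumerate(labels)
        topk = partialsortperm(probs[:, j], 1:k; rev=true)   # indices of the k largest
        hits += label in topk
    end
    return hits / length(labels)
end

probs = [0.1 0.7; 0.6 0.2; 0.3 0.1]   # 3 classes × 2 examples (columns)
topk_accuracy(probs, [2, 1], 1)        # 1.0 — both top-1 predictions match
```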
I then use the model to solve each problem in the validation dataset using beam search. This lets me actually test efficacy. Beam search was run with a beam width of 32, which means I’m always going to find the solution if it takes 2 or fewer steps (all 5² = 25 two-step action sequences fit in the beam, but the 5³ = 125 three-step sequences do not). The search keeps the top candidates based on the running sum of the policy likelihood. (I can try other formulations, like scoring candidates based on the predicted nsteps values).
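A sketch of that beam search, ranking partial trajectories by accumulated policy log-likelihood (`policy_logprobs`, `successor`, and `issolved` stand in for the real model and Sokoban environment):

```julia
# Expand every beam candidate by all 5 actions, stop as soon as a solved board
# appears, otherwise keep the `width` highest-scoring candidates.
function beam_search(start; width=32, max_steps=32)
    beam = [(board=start, actions=Int[], score=0.0)]
    for _ in 1:max_steps
        candidates = vec([
            (board = successor(c.board, a),
             actions = vcat(c.actions, a),
             score = c.score + policy_logprobs(c.board)[a])
            for c in beam, a in 1:5
        ])
        solved = findfirst(c -> issolved(c.board), candidates)
        solved !== nothing && return candidates[solved].actions
        sort!(candidates; by=c -> c.score, rev=true)
        beam = candidates[1:min(width, end)]
    end
    return nothing   # no solution found within the step budget
end
```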
Results
I trained two models – the transformer model and a baseline model that does not have access to past data. The baseline is also a transformer, but with an attention mask that hides everything except the current state and the goal state.
The metrics are:
| Metric | Transformer | Baseline |
|---|---|---|
| Solvability Accuracy | 0.977 | 0.977 |
| Top-1 Policy Accuracy | 0.867 | 0.831 |
| Top-2 Policy Accuracy | 0.982 | 0.969 |
| Top-1 Nsteps Accuracy | 0.900 | 0.898 |
| Top-2 Nsteps Accuracy | 0.990 | 0.990 |
| Solve Rate | 0.921 | 0.879 |
| Mean Solution Length | 6.983 | 6.597 |
Overall, that’s pretty even. The baseline policy that does not have access to past information is ever so slightly worse with respect to top-1 accuracy. Our takeaway is that for this problem, so far, the past doesn’t matter all that much and we can be pretty Markovian.
A 97.7% solvability accuracy is pretty darn good. I am happy with that.
The top-k policy and nsteps accuracies are pretty good too. The full transformer outperforms the baseline a bit here, but only by a bit.
The primary difference comes out when we look at the solve rate, which we obtain by actually using the policies to solve problems. The full transformer succeeds 92.1% of the time, whereas the history-free baseline only solves 87.9%. That’s only a 4.2-point spread, but it is noticeable. I can only hypothesize, but I suspect that past information is useful for longer traversals, and that simply being able to pull in more information gives the model more space to compute with.