Pwnable.kr - Sudoku

Alex Sieusahai · April 4, 2022

The challenge can be found here.

This writeup assumes you have at least attempted the challenge; without that, it probably won't make much sense. That is deliberate, in keeping with the spirit of pwnable.kr writeups.

The problem specification

This isn’t really a pwn problem so much as it’s an algorithm problem.

We have to solve a Sudoku with an additional constraint: the sum of some given cells must be strictly larger or strictly smaller than a given value.

Sudoku solving is a classic backtracking example, and the additional constraint is probably there just to stop us from lifting a Sudoku solver straight off LeetCode or the like.

The puzzle is given to us in a format like this:

Stage 1

[0, 7, 5, 0, 3, 0, 9, 4, 6]
[0, 8, 4, 0, 0, 5, 1, 0, 2]
[0, 0, 0, 6, 0, 1, 0, 8, 0]
[3, 0, 0, 0, 1, 0, 0, 6, 0]
[4, 0, 8, 5, 0, 3, 2, 7, 1]
[0, 1, 0, 4, 0, 0, 3, 5, 8]
[0, 0, 0, 1, 0, 0, 6, 9, 4]
[0, 0, 1, 0, 6, 0, 8, 2, 0]
[2, 0, 6, 3, 0, 9, 7, 1, 5]

- additional rule -
sum of the following numbers (at row,col) should be bigger than 20
(row,col) : (1,6)
(row,col) : (6,6)
(row,col) : (2,5)
(row,col) : (6,1)

This is no problem; we parse it into a board, a list of constraint cells, a threshold and a direction (bigger or smaller) as soon as we receive it.

Backtracking?

Backtracking is just a depth-first search over the problem space, where a node's neighbors are the boards obtained by filling in one more cell with a value that keeps the Sudoku “valid”.
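
To make the shape of the search concrete, here is a minimal sketch of a generic backtracking routine on a toy problem (filling three slots with distinct values). The names backtrack, slots, candidates and is_valid are made up for illustration and are not part of the solver further down.

```python
def backtrack(state, slots, candidates, is_valid, index=0):
    """Depth-first search: fill slots[index:] one at a time, undoing on failure."""
    if index == len(slots):
        return True  # every slot is filled and no check failed
    slot = slots[index]
    for value in candidates:
        state[slot] = value  # step to a neighboring state
        if is_valid(state, slot) and backtrack(state, slots, candidates, is_valid, index + 1):
            return True
    state[slot] = None  # undo our change before returning to the parent
    return False


def all_distinct(state, slot):
    # Toy rule: the value we just placed must not appear anywhere else.
    return list(state.values()).count(state[slot]) == 1


state = {'a': None, 'b': None, 'c': None}
found = backtrack(state, ['a', 'b', 'c'], [1, 2, 3], all_distinct)
# found is True and state is {'a': 1, 'b': 2, 'c': 3}
```

The Sudoku solver has the same structure: the slots are the empty cells, the candidates are 1 through 9, and the validity check is the row, column and square scan described next.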

Verifying partially completed boards

There are three main things we have to check (we’ll defer the sum constraint check to a filled board):

  1. Vertical Scanning
  2. Horizontal Scanning
  3. Square Scanning

Importantly (and probably somewhat obviously), we only have to check the row, column and square of the cell we just changed to see whether it broke a rule. We start from a partially filled board that we know is valid, and every earlier placement was checked the same way, so re-verifying the rest of the board is wasted work.

Vertical and horizontal scanning are probably self-explanatory, but the square scan might be a little awkward.

Consider an arbitrary (i, j), 0-indexed on a 9x9 board (so i and j are each integers between 0 and 8 inclusive). Take the top-left 3x3 square, for example: every cell in it satisfies 0 <= i, j <= 2. Similarly, every cell in the middle 3x3 square satisfies 3 <= i, j <= 5.

Notice that i // 3 tells us which band of 3x3 squares i belongs to, and similarly for j. This lets us enumerate rows from i // 3 * 3 up to (but not including) i // 3 * 3 + 3, and the same for columns with j, in order to verify that no value repeats within the square.
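
As a quick illustration of that index arithmetic, the sketch below lists the nine cells of the square containing a given (i, j). The helper name square_cells is hypothetical; the solver below inlines the same two loops instead.

```python
def square_cells(i, j):
    """Return the 0-indexed (row, col) pairs of the 3x3 square containing (i, j)."""
    top, left = (i // 3) * 3, (j // 3) * 3  # top-left corner of the square
    return [(bi, bj)
            for bi in range(top, top + 3)
            for bj in range(left, left + 3)]


cells = square_cells(4, 7)
# cells starts at (3, 6) and ends at (5, 8): the middle-right square
```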

Further, note that when doing these scans, if we encounter a 0, that number could in the future be anything; we’ll ignore those as we scan through.
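
A scan that skips zeros might look like the sketch below. It uses a set rather than the counting dict in the solver further down, and the name has_duplicates is an assumption made for this example only.

```python
def has_duplicates(values):
    """True if any non-zero value appears more than once; zeros are unfilled cells."""
    seen = set()
    for v in values:
        if v == 0:
            continue  # an empty cell can still become anything later
        if v in seen:
            return True
        seen.add(v)
    return False


ok_row = has_duplicates([0, 7, 5, 0, 3, 0, 9, 4, 6])   # False: first row of the sample board
bad_row = has_duplicates([7, 7, 0, 0, 0, 0, 0, 0, 0])  # True: 7 appears twice
```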

Neighbor function

When ingesting the board, we record the position of every 0 cell and later fill them in that order: left to right within a row, top to bottom across rows.

So, when we find a candidate value for a given cell, we explore the subtree of boards with that value in place, moving on to the next 0 cell in the list.
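
Collecting the empty cells in that order can be sketched as a single comprehension. The function name empty_cells is hypothetical (the solver below builds the same list in its constructor), and the input here is just the first two rows of the sample board.

```python
def empty_cells(board):
    """Return the (row, col) of every 0 cell, row by row, left to right."""
    return [(i, j)
            for i, row in enumerate(board)
            for j, value in enumerate(row)
            if value == 0]


first_two_rows = [
    [0, 7, 5, 0, 3, 0, 9, 4, 6],
    [0, 8, 4, 0, 0, 5, 1, 0, 2],
]
to_visit = empty_cells(first_two_rows)
# to_visit is [(0, 0), (0, 3), (0, 5), (1, 0), (1, 3), (1, 4), (1, 7)]
```

The position in this list is what the solver's recursion index refers to: depth k of the search is always deciding the k-th empty cell.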

Implementation

Note that this is certainly not the most efficient way of doing things, but 5 seconds is insanely long; we can certainly get away with this suboptimal code.

In particular, we don’t use the additional constraint until we have a completed board; we could probably trim the tree of boards further if we were more clever about when we check it.
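
One way such pruning could work, shown purely as a sketch and not used by the code below: on a partial board, bound the best sum the constraint cells can still reach, treating every unfilled constraint cell as 9 (for "bigger") or 1 (for "smaller"). The function name and its parameters are assumptions for this example, and cells are taken to be 0-indexed already.

```python
def constraint_still_possible(board, cells, target, larger):
    """Optimistic bound: can the sum constraint still be met on this partial board?"""
    filled = [board[i][j] for i, j in cells if board[i][j] != 0]
    unfilled = len(cells) - len(filled)
    partial = sum(filled)
    if larger:
        # Best case: every unfilled constraint cell ends up as 9.
        return partial + 9 * unfilled > target
    # Best case: every unfilled constraint cell ends up as 1.
    return partial + 1 * unfilled < target


board = [[9, 0], [0, 1]]            # tiny stand-in for a partial board
cells = [(0, 0), (0, 1), (1, 1)]    # partial sum 10, one cell still empty
can_beat_20 = constraint_still_possible(board, cells, 20, larger=True)   # False: at most 19
can_beat_18 = constraint_still_possible(board, cells, 18, larger=True)   # True
```

The bound ignores the ordinary Sudoku rules, so it never rejects a board that could still succeed; it only cuts branches that cannot satisfy the sum no matter what.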

## solver.py
import json

class Solver:
    def __init__(self,
                 board,
                 constraint_cells,
                 constraint_val,
                 cumsum_larger,
                 logging=False
                 ):
        """
        constraint_cells are given to me as 1 indexed, just convert to 0 indexed so it's easier
        """
        self.constraint_cells = [(i-1, j-1) for (i, j) in constraint_cells]
        self.constraint_val = constraint_val
        self.board = board
        self.cumsum_larger = cumsum_larger
        self.to_visit = []
        self.logging = logging
        for i in range(len(board)):
            for j in range(len(board[0])):
                if board[i][j] == 0:
                    self.to_visit.append((i, j))

    @staticmethod
    def gen_val_dict():
        return {i : 0 for i in range(10)}

    def print(self, *argv):
        if self.logging:
            print(*argv)

    def verify(self, i, j):
        """
        Check to make sure all traditional rules pass for i, j
        """
        # column check
        val = self.gen_val_dict()
        for i_alt in range(9):
            val[self.board[i_alt][j]] += 1
        # self.print('col check results', val)
        del val[0]
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed i check')
            return False

        # row check
        val = self.gen_val_dict()
        # self.print('row check results', val)
        for j_alt in range(9):
            val[self.board[i][j_alt]] += 1
        del val[0]
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed j check')
            return False

        # square check
        val = self.gen_val_dict()
        bucket_i, bucket_j = (i // 3) * 3, (j // 3) * 3
        # self.print('bucket', bucket_i, bucket_j)
        for bi in range(bucket_i, bucket_i + 3):
            for bj in range(bucket_j, bucket_j + 3):
                # self.print('checking', bi, bj)
                val[self.board[bi][bj]] += 1
        del val[0]
        # self.print('val', val)
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed square check')
            return False

        return True

    def verify_constraint(self):
        """
        finally, after finding a board which satisfies the solution,
        let's see if the constraint is verified

        probably an inefficient way to do it but it might suffice
        """
        cumsum = 0
        for i, j in self.constraint_cells:
            cumsum += self.board[i][j]
        self.print('cumsum is', cumsum, 'against constraint val of', self.constraint_val, 'with cumsum_larger', self.cumsum_larger)
        if self.cumsum_larger:
            return cumsum > self.constraint_val
        else:
            return cumsum < self.constraint_val

    def print_board(self):
        self.print('current board state:')
        for row in self.board:
            self.print(row)

    def json_board(self):
        return str(json.dumps(self.board)).replace(' ', '')

    def solve(self, which=0):
        """
        Simple backtrack
        """
        self.print('solving', which, 'of', len(self.to_visit))
        if which == len(self.to_visit):
            self.print('attempting verify_constraint...')
            return self.verify_constraint()
        self.print(self.to_visit[which])

        i, j = self.to_visit[which]
        for x in range(1, 10):
            self.board[i][j] = x
            self.print('set', i, j, 'to', x, 'for which with which value', which, self.to_visit[which])
            self.print_board()
            if self.verify(i, j) and self.solve(which+1):
                return True
        # no value worked for this cell: reset it and report failure to the caller
        self.board[i][j] = 0
        return False

## get_flag.py
import time
from pwn import *
from solver import Solver

def parse_output(out):
    board = []
    constraints = []
    constraint_val = None
    cumsum_larger = None
    for line in filter(len, out.split(b'\n')):
        if line[0] == ord('['):
            board.append(eval(line))
        elif line[0] == ord('('):
            constraints.append(eval(line.split(b':')[-1]))
        elif line.startswith(b'sum of the following'):
            constraint_val = eval(line.split()[-1])
            if b'bigger' in line:
                cumsum_larger = True
            elif b'smaller' in line:
                cumsum_larger = False
            else:
                print(line.decode())
                assert 0

    return board, constraints, constraint_val, cumsum_larger

def setup_io():
    """
    Startup the process and skip the tutorial
    """
    io = process('nc pwnable.kr 9016'.split())
    print(io.recv().decode())
    io.sendline()
    io.sendline()
    return io

def go_through_stage(io):
    print(io.recv().decode())
    out = io.recvuntil(b'solution?')
    print('###output is\n', out.decode())
    board, constraints, constraint_val, cumsum_larger = parse_output(out)
    initial_board = board.copy()
    solver = Solver(board, constraints, constraint_val, cumsum_larger)
    solver.solve()
    solver.print_board()
    soln = solver.json_board().encode() + b'\n'
    print('###sending')
    print(soln)
    io.send(soln)
    print('###sent solution!')
    print(io.recvuntil(b'cheking your solution...'))

io = setup_io()
for _ in range(100):
    go_through_stage(io)
io.interactive()
