Pwnable.kr - Sudoku
The challenge can be found here.
Without at least an attempt at the challenge, this writeup will probably not make much sense. This is done to keep at least somewhat in the spirit of pwnable.kr style writeups.
The problem specification
This isn’t really a pwn problem so much as it’s an algorithm problem.
We have to solve a Sudoku with one additional constraint: the sum of a given set of cells must be either strictly larger or strictly smaller than a given value.
Sudoku solving is an extremely classic example of backtracking, and the additional constraints are probably just to prevent literally ripping a sudoku solving algorithm off of leetcode or the like.
Each stage is given to us in a format like:
Stage 1
[0, 7, 5, 0, 3, 0, 9, 4, 6]
[0, 8, 4, 0, 0, 5, 1, 0, 2]
[0, 0, 0, 6, 0, 1, 0, 8, 0]
[3, 0, 0, 0, 1, 0, 0, 6, 0]
[4, 0, 8, 5, 0, 3, 2, 7, 1]
[0, 1, 0, 4, 0, 0, 3, 5, 8]
[0, 0, 0, 1, 0, 0, 6, 9, 4]
[0, 0, 1, 0, 6, 0, 8, 2, 0]
[2, 0, 6, 3, 0, 9, 7, 1, 5]
- additional rule -
sum of the following numbers (at row,col) should be bigger than 20
(row,col) : (1,6)
(row,col) : (6,6)
(row,col) : (2,5)
(row,col) : (6,1)
This is no problem; we parse the board, the constraint cells (which are 1 indexed), the target value and the direction of the inequality as soon as we receive them.
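As a rough sketch of that parsing step, assuming every stage looks like the sample above: the function name `parse_stage` is made up for illustration, and it uses `ast.literal_eval` so that text from the server is never executed.

```python
import ast
import re


def parse_stage(text):
    """Parse one stage's text into (board, cells, target, larger)."""
    board, cells = [], []
    target, larger = None, None
    for line in text.splitlines():
        line = line.strip()
        if line.startswith('['):
            # A board row such as "[0, 7, 5, 0, 3, 0, 9, 4, 6]".
            board.append(ast.literal_eval(line))
        elif line.startswith('(row,col)'):
            # A constraint cell such as "(row,col) : (1,6)", still 1 indexed.
            cells.append(ast.literal_eval(line.split(':', 1)[1].strip()))
        elif line.startswith('sum of the following'):
            # The target value is the trailing integer on the line.
            target = int(re.search(r'(\d+)\s*$', line).group(1))
            if 'bigger' in line:
                larger = True
            elif 'smaller' in line:
                larger = False
    return board, cells, target, larger
```

On the sample stage this should give a 9x9 list of lists, the four constraint cells as tuples, a target of 20 and `larger=True`.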
Backtracking?
Backtracking is literally just depth first searching over the problem space, where neighbor nodes are just another “valid” state of Sudoku.
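To make the "depth first search over states" idea concrete, here is a toy sketch that has nothing to do with Sudoku yet. The names `dfs`, `is_goal` and `neighbors` are invented for this example; the toy problem is filling three slots with distinct digits from 1 to 3.

```python
def dfs(state, is_goal, neighbors):
    """Return the first goal state reachable from `state`, or None."""
    if is_goal(state):
        return state
    for nxt in neighbors(state):
        found = dfs(nxt, is_goal, neighbors)
        if found is not None:
            return found
    # Every neighbor failed, so the caller has to try something else.
    return None


def is_goal(state):
    # A state is a tuple of the digits chosen so far.
    return len(state) == 3


def neighbors(state):
    # Only extensions that keep the state valid count as neighbors.
    return [state + (d,) for d in (1, 2, 3) if d not in state]


print(dfs((), is_goal, neighbors))  # (1, 2, 3)
```

The Sudoku solver has the same shape: a state is a partially filled board, and a neighbor is the same board with one more empty cell filled in without breaking any rule.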
Verifying partially completed boards
There are three main things we have to check (we’ll defer the sum constraint check to a filled board):
- Vertical Scanning
- Horizontal Scanning
- Square Scanning
Importantly (and probably somewhat obviously), we only have to check the cell that we just changed to see if it broke the vertical, horizontal or square rules. This is because we start with a partially solved board which we know is correct, and every earlier step was verified the same way, so re-verifying the rest of the board would be wasted work.
Vertical and horizontal scanning are probably self explanatory, but square might be a little awkward.
Consider an arbitrary (i, j) that’s 0 indexed for a 9x9 board (so i is an integer that’s at least 0 and at most 8, and the same for j).
Consider the top left 3x3 square, for example.
We know that any cell in that square satisfies 0 <= i, j <= 2.
Similarly, for the middle 3x3 square, we know that 3 <= i, j <= 5.
Notice that i // 3 tells us which band of three rows i belongs to (0, 1 or 2), and similarly j // 3 tells us which band of three columns j belongs to.
This lets us enumerate rows from i // 3 * 3 up to, but not including, i // 3 * 3 + 3, and similarly for j, in order to verify that all filled cells in the square are unique.
Further, note that when doing these scans, if we encounter a 0, that number could in the future be anything; we’ll ignore those as we scan through.
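A compact sketch of those three scans for a single cell might look like the following. The helper name `cell_is_consistent` is made up for this example, and it compares only the changed value against its peers, which is a slightly different formulation from counting every digit in the row, column and square.

```python
def cell_is_consistent(board, i, j):
    """Check the row, column and 3x3 square of the 0 indexed cell (i, j)."""
    value = board[i][j]
    if value == 0:
        return True  # an empty cell cannot conflict with anything yet
    # Top left corner of the 3x3 square containing (i, j).
    box_i, box_j = (i // 3) * 3, (j // 3) * 3
    peers = (
        [(i, c) for c in range(9)]      # horizontal scan
        + [(r, j) for r in range(9)]    # vertical scan
        + [(r, c)
           for r in range(box_i, box_i + 3)
           for c in range(box_j, box_j + 3)]  # square scan
    )
    # Zeros never equal a non-zero value, so empty peers are ignored for free.
    return all(board[r][c] != value for r, c in peers if (r, c) != (i, j))
```

For instance, on the sample board, putting a 7 into the top left cell fails the horizontal scan, while a 1 there passes all three.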
Neighbor function
When ingesting the board, we can record the position of every 0 cell and then work through that list in order (row by row, left to right within each row).
So, when we find a candidate value for a given cell, we explore the new tree of boards with that given cell, expanding on the next 0 cell in the list.
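Roughly, the bookkeeping could be sketched like this. Both helper names (`empty_cells`, `candidates`) are invented for illustration, and `candidates` filters the digits up front, which is a small variation on trying every digit from 1 to 9 and verifying afterwards.

```python
def empty_cells(board):
    """Return the (row, col) of every 0 cell, scanning row by row."""
    return [(i, j)
            for i, row in enumerate(board)
            for j, value in enumerate(row)
            if value == 0]


def candidates(board, i, j):
    """Digits 1..9 not already used in the row, column or square of (i, j)."""
    used = set(board[i]) | {board[r][j] for r in range(9)}
    box_i, box_j = (i // 3) * 3, (j // 3) * 3
    used |= {board[r][c]
             for r in range(box_i, box_i + 3)
             for c in range(box_j, box_j + 3)}
    return [d for d in range(1, 10) if d not in used]
```

Each value returned by `candidates` for the current cell corresponds to one neighbor board; the search then moves on to the next entry of `empty_cells`.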
Implementation
Note that this is certainly not the most efficient way of doing things, but 5 seconds is insanely long; we can certainly get away with this suboptimal code.
In particular, we don’t use the additional constraints until we have a completed board; we could probably trim the tree of boards further if we are more clever with our use of the additional constraints.
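As one possible way to use the additional constraint earlier (this is not what the code below does, and the helper name `constraint_still_possible` is hypothetical), we could bound the best achievable sum on a partial board and abandon a branch once even that bound cannot satisfy the inequality. The bound deliberately ignores the Sudoku rules, so it only ever prunes branches that are certainly dead.

```python
def constraint_still_possible(board, cells, target, larger):
    """Return False only when no completion can satisfy the sum constraint.

    `cells` are 0 indexed (row, col) pairs; empty cells hold 0.
    """
    filled = sum(board[i][j] for i, j in cells)
    empty = sum(1 for i, j in cells if board[i][j] == 0)
    if larger:
        # Best case: every empty constraint cell becomes a 9.
        return filled + 9 * empty > target
    # Best case: every empty constraint cell becomes a 1.
    return filled + 1 * empty < target
```

A solver could call this right after placing a digit and backtrack immediately when it returns False, instead of waiting for a completely filled board.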
## solver.py
import json


class Solver:
    def __init__(self,
                 board,
                 constraint_cells,
                 constraint_val,
                 cumsum_larger,
                 logging=False
                 ):
        """
        constraint_cells are given to me as 1 indexed, just convert to 0 indexed so it's easier
        """
        self.constraint_cells = [(i-1, j-1) for (i, j) in constraint_cells]
        self.constraint_val = constraint_val
        self.board = board
        self.cumsum_larger = cumsum_larger
        self.to_visit = []
        self.logging = logging
        # remember every empty cell, row by row, so solve() can walk them in order
        for i in range(len(board)):
            for j in range(len(board[0])):
                if board[i][j] == 0:
                    self.to_visit.append((i, j))

    @staticmethod
    def gen_val_dict():
        # one counter per digit; key 0 stands for "empty" and is dropped before checking
        return {i: 0 for i in range(10)}

    def print(self, *argv):
        if self.logging:
            print(*argv)

    def verify(self, i, j):
        """
        Check to make sure all traditional rules pass for i, j
        """
        # column check
        val = self.gen_val_dict()
        for i_alt in range(9):
            val[self.board[i_alt][j]] += 1
        # self.print('col check results', val)
        del val[0]
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed i check')
            return False

        # row check
        val = self.gen_val_dict()
        # self.print('row check results', val)
        for j_alt in range(9):
            val[self.board[i][j_alt]] += 1
        del val[0]
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed j check')
            return False

        # square check
        val = self.gen_val_dict()
        bucket_i, bucket_j = (i // 3) * 3, (j // 3) * 3
        # self.print('bucket', bucket_i, bucket_j)
        for bi in range(bucket_i, bucket_i + 3):
            for bj in range(bucket_j, bucket_j + 3):
                # self.print('checking', bi, bj)
                val[self.board[bi][bj]] += 1
        del val[0]
        # self.print('val', val)
        if any(map(lambda x: x > 1, val.values())):
            self.print('failed square check')
            return False

        return True

    def verify_constraint(self):
        """
        finally, after finding a board which satisfies the solution,
        let's see if the constraint is verified
        probably an inefficient way to do it but it might suffice
        """
        cumsum = 0
        for i, j in self.constraint_cells:
            cumsum += self.board[i][j]
        self.print('cumsum is', cumsum, 'against constraint val of', self.constraint_val, 'with cumsum_larger', self.cumsum_larger)
        if self.cumsum_larger:
            return cumsum > self.constraint_val
        else:
            return cumsum < self.constraint_val

    def print_board(self):
        self.print('current board state:')
        for row in self.board:
            self.print(row)

    def json_board(self):
        # compact JSON with no spaces, e.g. [[1,2,...],[...]]
        return json.dumps(self.board, separators=(',', ':'))

    def solve(self, which=0):
        """
        Simple backtrack
        """
        self.print('solving', which, 'of', len(self.to_visit))
        if which == len(self.to_visit):
            self.print('attempting verify_constraint...')
            return self.verify_constraint()
        self.print(self.to_visit[which])
        i, j = self.to_visit[which]
        for x in range(1, 10):
            self.board[i][j] = x
            self.print('set', i, j, 'to', x, 'for which with which value', which, self.to_visit[which])
            self.print_board()
            if self.verify(i, j) and self.solve(which+1):
                return True
        # no digit worked here: reset the cell and tell the caller to backtrack
        self.board[i][j] = 0
        return False
## get_flag.py
from pwn import *
import ast
from solver import Solver


def parse_output(out):
    board = []
    constraints = []
    constraint_val = None
    cumsum_larger = None
    # literal_eval only accepts Python literals, so the server's text is never executed
    for line in filter(len, out.decode().split('\n')):
        if line[0] == '[':
            board.append(ast.literal_eval(line.strip()))
        elif line[0] == '(':
            constraints.append(ast.literal_eval(line.split(':')[-1].strip()))
        elif line.startswith('sum of the following'):
            constraint_val = int(line.split()[-1])
            if 'bigger' in line:
                cumsum_larger = True
            elif 'smaller' in line:
                cumsum_larger = False
            else:
                print(line)
                assert 0
    return board, constraints, constraint_val, cumsum_larger


def setup_io():
    """
    Startup the process and skip the tutorial
    """
    io = process('nc pwnable.kr 9016'.split())
    print(io.recv().decode())
    io.sendline()
    io.sendline()
    return io


def go_through_stage(io):
    print(io.recv().decode())
    out = io.recvuntil(b'solution?')
    print('###output is\n', out.decode())
    board, constraints, constraint_val, cumsum_larger = parse_output(out)
    solver = Solver(board, constraints, constraint_val, cumsum_larger)
    # solve() fills the board in place
    solver.solve()
    solver.print_board()
    soln = solver.json_board().encode() + b'\n'
    print('###sending')
    print(soln)
    io.send(soln)
    print('###sent solution!')
    print(io.recvuntil(b'cheking your solution...'))


io = setup_io()
for _ in range(100):
    go_through_stage(io)
io.interactive()