maybe-checker bwctf 2024
Python’s support for structural pattern matching is really nice. This challenge wasn’t that interesting, but I thought it would be cool to show how I used pattern matching to make recovering the constraints less of a pain.
from pwn import *
from z3 import *
from pprint import pprint
from capstone import *
cs = Cs(CS_ARCH_X86, CS_MODE_64)
exe = ELF("maybe_checker")

# Presumably the non-PIE image base; not referenced by the code below.
base_addr = 0x400000

# The flag is modelled as 48 symbolic bytes c0..c47.
flen = 48
flag = [BitVec(f"c{i}", 8) for i in range(flen)]
s = Solver()

# Check table at 0x402040: 0x3c packed 9-byte records. Each record is a
# 1-byte offset followed by the little-endian 8-byte address of a check
# function. The two fields are pulled out into parallel lists.
deffered_addrs = [
    int.from_bytes(exe.read(0x402040 + rec * 0x9 + 1, 8), byteorder="little")
    for rec in range(0x3c)
]
offsets = [
    int.from_bytes(exe.read(0x402040 + rec * 0x9, 1), byteorder="little")
    for rec in range(0x3c)
]
fstub = list(zip(offsets, deffered_addrs))
def _check_body(fn_addr):
    """Disassemble the check function at *fn_addr*, stopping before its first ``ret``.

    Only the first 0x30 bytes of the function are read. Returns a list of
    ``(mnemonic, op_str)`` pairs in program order.
    """
    pairs = []
    for _addr, _size, mnem, operands in cs.disasm_lite(exe.read(fn_addr, 0x30), fn_addr):
        if mnem == "ret":
            return pairs
        pairs.append((mnem, operands))
    return pairs


# One instruction list per entry of deffered_addrs, in the same order.
constraint_fns = [_check_body(fn_addr) for fn_addr in deffered_addrs]
def get_offset_from_op_str(instr):
    """Return the displacement of the memory operand in a capstone operand string.

    Example: ``"al, byte ptr [rdi + 0xc]"`` -> ``0xc``.

    Handled forms:
      * ``[reg + 0xc]`` / ``[reg + 5]``: capstone prints small values in
        decimal, so the base is auto-detected.
      * ``[reg - 0x4]``: a negative displacement.
      * ``[reg]``: zero displacement.

    If *instr* contains no ``[...]`` operand, a warning is printed and 0 is
    returned (the previous fallback behaviour).
    """
    _, opened, rest = instr.partition("[")
    mem, closed, _ = rest.partition("]")
    if not opened or not closed:
        print("Unknown Instruction:", instr)
        return 0
    for sign, sep in ((1, " + "), (-1, " - ")):
        if sep in mem:
            return sign * int(mem.rsplit(sep, 1)[1].strip(), 0)
    # Bare base register, e.g. "[rdi]": the displacement is zero.
    return 0
for (offset, addr, fn_instrs) in zip(offsets, deffered_addrs, constraint_fns):
# two manual casee
print(hex(addr))
if addr == 0x401210:
s.add(flag[0] == 0x62, flag[1] == 0x77, flag[2] == 0x63, flag[3] == 0x74, flag[4] == 0x66, flag[5] == 0x7b)
s.add(flag[0x2f] == 0x7f)
elif addr == 0x401240:
s.add(flag[3+8] == 0x2d , flag[9+8] == 0x2d , flag[0xf+8] == 0x2d , flag[0x15+8] == 0x2d , flag[0x1b+8] == 0x2d)
s.add(flag[0x21 + 8] == 0x2d)
else:
#print(hex(addr), fn_instrs)
match fn_instrs:
case [("mov", op1), ("cmp", op2), ("setl", op3)]:
#print("le case")
off1 = get_offset_from_op_str(op1)
off2 = get_offset_from_op_str(op2)
s.add(flag[offset + off1] < flag[offset + off2])
case [("mov", op1), ("xor", op2), ("cmp", op3), ("sete", _)]:
off1 = get_offset_from_op_str(op1)
off2 = get_offset_from_op_str(op2)
expected = int(op3.split("0x")[1],16)
s.add((flag[offset + off1] ^ flag[offset + off2]) == expected)
case [("movsx", op1), ("movsx", op2), ("imul", "ecx, eax"), ("cmp", eql), ("sete", _)]:
off1 = get_offset_from_op_str(op1)
A = ZeroExt(24, flag[offset + off1]) #extend to 32, or 4byte
off2 = get_offset_from_op_str(op2)
B = ZeroExt(24, flag[offset + off2]) #extend to 32, or 4byte
val = int(eql.split(", ")[1].removeprefix('0x'), 16)
s.add((A*B) == val)
case [("mov", op1), ("cmp", op2), ("setg", op3)]:
off1 = get_offset_from_op_str(op1)
off2 = get_offset_from_op_str(op2)
s.add(flag[offset + off1] > flag[offset + off2])
case [("mov", op1), ("cmp", op2), ("sete", op3)]:
off1 = get_offset_from_op_str(op1)
off2 = get_offset_from_op_str(op2)
s.add(flag[offset + off1] == flag[offset + off2])
case [("movsx", op1), ("movsx", op2), ("add", "ecx, eax"), ("cmp", eql), ("sete", _)]:
off1 = get_offset_from_op_str(op1)
A = ZeroExt(24, flag[offset + off1]) #extend to 32, or 4byte
off2 = get_offset_from_op_str(op2)
B = ZeroExt(24, flag[offset + off2]) #extend to 32, or 4byte
val = int(eql.split(", ")[1].removeprefix('0x'), 16)
s.add((A+B) == val)
case _:
print("Uncaught case:")
print(hex(addr), fn_instrs)
exit()
# A model is only available after a successful check(); calling s.model()
# before check() raises a Z3Exception.
result = s.check()
if result != sat:
    print("Solver returned", result)
    exit()
m = s.model()
# model_completion=True gives any unconstrained byte a concrete value
# instead of returning None.
flg = "".join(chr(m.eval(c, model_completion=True).as_long()) for c in flag)
print(flg)