あさっちの不定期日記

色々ごった煮。お勉強の話もあればテニスの話をするかもしれない。

Hack.lu CTF Crypto Writeup

はじめに

Hack.lu CTFにちょっとだけ参加しました。チームメイトが頑張ってくれたおかげで16位/171チームでした。

crypto3問(一番難しいrevとの複合問題は除く。偉い人が書いてくれたらそれを使って自分も復習します)のWriteupを例によってやや詳しめに書きます。

Silver Water Industries (92 solves, 1時間)

The local water supplier Silver Water Industries is planning their IPO. To appeal to current crypto investors, they even implemented a military grade token encryption.

package main

import (
    "bufio"
    "crypto/rand"
    "fmt"
    "math"
    "math/big"
    "os"
)

func genN() *big.Int {
    var p *big.Int
    var q *big.Int
    var err error

    for {
        p, err = rand.Prime(rand.Reader, 64)
        if err != nil {
            panic(err)
        }
        res := new(big.Int)
        if res.Mod(p, big.NewInt(4)); res.Cmp(big.NewInt(1)) == 0 {
            break
        }
    }

    for {
        q, err = rand.Prime(rand.Reader, 64)
        if err != nil {
            panic(err)
        }
        res := new(big.Int)
        if res.Mod(q, big.NewInt(4)); res.Cmp(big.NewInt(3)) == 0 {
            break
        }
    }

    N := new(big.Int)
    N.Mul(p, q)
    return N
}

func genX(N *big.Int) *big.Int {
    for {
        x, err := rand.Int(rand.Reader, N)
        if err != nil {
            panic(err)
        }
        g := new(big.Int)
        g.GCD(nil, nil, x, N)
        if g.Cmp(big.NewInt(1)) == 0 {
            return x
        }
    }
}

func encryptByte(b uint8, N *big.Int) []*big.Int {
    z := big.NewInt(-1)
    enc := make([]*big.Int, 8)
    for i := 0; i < 8; i++ {
        bit := b & uint8(math.Pow(2, float64(7-i)))
        x := genX(N)
        x.Exp(x, big.NewInt(2), N)
        if bit != 0 {
            x.Mul(x, z)
            x.Mod(x, N)
        }
        enc[i] = x
    }
    return enc
}

func generateRandomString(n int) string {
    const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-"
    ret := make([]byte, n)
    for i := 0; i < n; i++ {
        num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
        if err != nil {
            panic(err)
        }
        ret[i] = letters[num.Int64()]
    }

    return string(ret)
}

func main() {
    N := genN()

    token := []byte(generateRandomString(20))

    fmt.Println(N)
    for _, b := range token {
        fmt.Println(encryptByte(uint8(b), N))
    }
    fmt.Println("")

    reader := bufio.NewReader(os.Stdin)

    input, err := reader.ReadString('\n')
    if err != nil {
        panic(err)
    }
    input = input[:len(input)-1]

    if string(token) == input {
        fmt.Println("flag{<YOUR_FLAG_HERE>}")
    }
}

Go言語ド素人なので読むまで苦労しました。

なお、この問題はチームリーダーが解いてくれました。解法は特に見ずに自分も後で解きなおしてみた感じになります。

  • $ p \equiv 1 \ {\rm mod} \ 4$かつ$ q \equiv 3 \ {\rm mod} \ 4$なる64ビットの素数を生成し、$ N = pq$が与えられる。当然ながら$ N \equiv 3 \ {\rm mod} \ 4$

  • 20文字のランダム文字列$ r$の各ビット$ r_{i}$に対して、毎回ランダムな値$x_i$をとってきて

\begin{align} C_i \equiv (x_i) ^ 2 \ {\rm mod} \ N \ \ {\rm if } \ r_i = 0\\ C_i \equiv - (x_i) ^ 2 \ {\rm mod} \ N \ \ {\rm if } \ r_i = 1 \end{align}

と暗号化。

このもとで$ r_i$を求めよ($ r$を復元せよ)、という問題です。

素因数分解

$ p, q$は64ビットの素数なので、SageMathのfactorで十分高速に素因数分解が可能です。なので素数は実質既知であるとしてよいです。

N = ...
p_ls, q_ls = list(facotr(N))  # [(p, 1), (q, 1)]のような構造。つまりN = (p^1) * (q^1)
p, q = p_ls[0], q_ls[0]

平方剰余

こういうのは経験則になってしまうので試行錯誤うんぬんの話はあまりできなくて申し訳ないのですが、

ファーストインプレッションとしては「片方が平方数で、もう片方が平方数じゃないような形で場合分けしているので、おそらく平方剰余が絡んでいるんだろうな」と感じました。

結局のところ、これがもろ正解で、先に結論だけ言ってしまうとJacobi記号を用いて

 
\begin{pmatrix} 
C_i \\ 
N
\end{pmatrix} 
=
\begin{pmatrix} 
C_i \\ 
p
\end{pmatrix} 
\begin{pmatrix} 
C_i \\ 
q
\end{pmatrix}

の値(はてブで表記が面倒なので便宜上$ J(C_i)$と表記します)が$ 1$なら$ r_i=0$、$ J(C_i) = -1$なら$ r_i = 1$です。

これだけで$ r_i$の値はわかります。おしまい。

簡単な証明を最後に載せておきます。興味がある人は上のJacobi記号のリンク(Wikipedia)と一緒にどうぞ。

ちなみに、Jacobi記号は任意の非負整数$ N$を法としての平方剰余ですが、これが素数の場合はLegendre記号と同義になり、これはSageではkroneckerで計算可能です。

(予め素因数分解されたものを引数にとっていますが、内部でfactorを走らせても大丈夫です)

def jacobi(a, p, q):
    res = kronecker(a, p) * kronecker(a, q)
    return res

(追記:Jacobi記号もSageMathにしっかりありました。ふるつきさんありがとうございます。)

# どちらでも動作確認しました
jacobi_symbol(a, n)
a.jacobi(n)

Writeup

# solver.sage

from pwn import *
    
def jacobi(a, p, q):
    res = kronecker(a, p) * kronecker(a, q)
    if res == 1 or res == -1:
        return res
    # 例外処理
    else:
        print('jacobi error, {}'.format(res))
        assert 1==2
        return 0

def decrypt(ls, p, q):
    b = ''
    for l in ls:
        # 各ビットに対してJacobi記号で平方剰余かどうかの判定
        if jacobi(l, p, q) == 1:
            b += '0'
        elif jacobi(l, p, q) == -1:
            b += '1'
        # 例外処理
        else:
            print('decrypt error, {}'.format(jacobi(l, p, q)))
            assert 1==2
    return chr(int(b, 2))

# C_iのパース    
def parse_bytes(b:bytes):
    res = b.decode()[:-1]
    # assert res[0] == '[' and res[-1] == ']'
    res = res[1:-1]
    res = [int(_) for _ in res.split(' ')]
    return res

conn = remote('flu.xxx', 20060)
n = int(conn.recvline())
# 素因数分解
p, q = list(factor(n))
p, q = int(p[0]), int(q[0])
assert p * q == n

token = ''
# 復号
for i in range(20):
    b = conn.recvline()
    ls = parse_bytes(b)
    token += decrypt(ls, p, q)

conn.recvline()
conn.sendline(token.encode())
print(conn.recvline())

flag{Oh_NO_aT_LEast_mY_AlGORithM_is_ExpanDiNg}

証明

$ a, p$が互いに素であるというのは暗黙の了解とします。大事なのは

  1. $ J(a) = -1$なら$ a$は平方非剰余 ($ J(a) = 1$であっても$ a$が平方剰余かどうかは不明)

  2. $ N \equiv 3 \ {\rm mod} \ 4$から、$ J(-C_i) = J(-1) \cdot J(C_i) = -J(C_i)$

という2点です。

$ J(C_i) = 1$ならば$r_i = 0$の証明:

$ J(C_i) = 1$かつ$r_i = 1$で矛盾を示せばよいです。

2.から$ J(-C_i) = -J(C_i) = -1$となるので、1.とあわせて$ -C_i$は平方非剰余。

しかし、$r_i = 1$と仮定しているので$ C_i = - (x_i) ^ 2 \ \Leftrightarrow -C_i = (x_i) ^ 2$なる$ x_i$が存在し、矛盾。

$ J(C_i) = -1$ならば$r_i = 1$の証明:

「$ r_i = 0$ならば$ C_i$が平方剰余」なのはアルゴリズムから自明です。

また、1.の主張の対偶をとって「$ a$が平方剰余ならば$ J(a) = 1$」が成立します。

この2つから、「$ r_i = 0$ならば$ J(C_i) = 1$」が成立します。この対偶をとればよいです。

WhatTheHecc (45 solves, 1時間40分)

Go hecc it!

#!/usr/bin/env python3
import sys
import shlex
import subprocess
from Cryptodome.PublicKey import ECC
from Cryptodome.Hash import SHA3_256
from Cryptodome.Math.Numbers import Integer
import time 

# util

def run_cmd(cmd):
    try:
        args = shlex.split(cmd)
        return subprocess.check_output(args).decode('utf-8')
    except Exception as ex:
        return str(ex)

def read_message():
    return sys.stdin.readline()

def send_message(message):
    sys.stdout.write('### {0}\r\n>'.format(message))
    sys.stdout.flush()

# crypto stuff

def hash(msg):
    h_obj = SHA3_256.new()
    h_obj.update(msg.encode())
    return Integer.from_bytes(h_obj.digest())

def setup(curve):
    key = ECC.generate(curve=curve)
    return key

def blind(msg, pub):
    r = pub.pointQ * hash(msg)
    return r

def sign(r, key):
    r_prime = r * key.d.inverse(key._curve.order)

    date = int(time.time())
    nonce = Integer.random_range(min_inclusive=1,max_exclusive=key._curve.order)
    z = f'{nonce}||{date}'

    R = r_prime + (key._curve.G * hash(z))
    s = (key.d - hash(z)) % key._curve.order
    # return (R, s, z)
    # we can not give away z or this is unsafe: x = s+h(z)
    return R, s

def verify(msg, sig, pub):
    R, s = sig

    if s in [0,1,''] and s > 0:
        return False

    tmp1 = s * pub._curve.G
    tmp2 = - pub.pointQ 
    tmp3 = tmp2 + R

    return tmp1 + tmp3 == hash(msg) * pub._curve.G

## ok ok here we go

def main():
    while True:
        send_message('Enter your command:')
        cmd = read_message().strip()

        if cmd == 'sign':
            send_message('Send cmd to sign:')
            cmd = read_message().strip()

            if(cmd in ['id', 'uname', 'ls', 'date']):
                r = blind(cmd, pubkey)
                sig = sign(r, key)
                
                send_message(f'Here you go: {sig[0].x}|{sig[0].y}|{sig[1]}|{cmd}')
            else:
                send_message('Not allowed!')

        elif cmd == 'run':
            send_message('Send sig:')
            sig = read_message().strip()
            tmp = sig.split('|')
            if len(tmp) == 4:
                x = int(tmp[0])
                y = int(tmp[1])
                s = int(tmp[2])
                c = tmp[3]
                sig = (ECC.EccPoint(x, y, curve='P-256'), s)
                if(verify(c, sig, pubkey)):
                    out = run_cmd(c)
                    send_message(out)
                else:
                    send_message('Invalid sig!')
            else:
                send_message('Invalid amount of params!')

        elif cmd == 'show':
            send_message(pubkey)

        elif cmd == 'help':
            send_message('Commands: exit, help, show, run, sign')

        elif cmd == 'exit':
            send_message('Bye :) Have a nice day!')
            break

        else:
            send_message('Invalid command!')

if __name__ == '__main__':
    key = setup('P-256')
    pubkey = key.public_key()
    main()

長いですが、簡単にまとめると以下のようになります。

  • P-256という楕円曲線を使う

  • 秘密鍵$ d$に対応する公開鍵$Q = d \cdot G$は与えられる($ G$はP-256のベースポイント、上のリンク参照)

  • メッセージ Mの署名は、ランダム値$ k$を選んで

$$ R = (H(M) + k) \cdot G $$

$$ S = d - k $$

として$ R, S$を出力する。$ H(M)$はSHA3-256を使っていて、我々が Mとして入力可能なのは'id', 'uname', 'ls', 'date'の4種類のみ

  • 署名検証は

$$ S \cdot G - Q + R = H(M) \cdot G $$

が成立するかどうかをもって行う。

このもとで、メッセージ'cat flag'に対する署名$ (r ^ \ast, s ^ \ast)$を偽造する問題です。

大文字は既知で、小文字が未知数です(こうしてしまったせいで、どれが数字でどれが楕円曲線上の点なのか見にくくなりました。すみません)。

アプローチ1

ファーストインプレッションとしてはECDSAみたいなことをしているなという感じでした。

まずは秘密鍵$ d$を求める方法を考えてみます。

同じメッセージでも$ k$がランダム値ということに注目して

$$ R_1 = (H(M) + k_1) \cdot G, \ S_1 = d - k_1 $$

$$ R_2 = (H(M) + k_2) \cdot G, \ S_2 = d - k_2 $$

が成り立ちます。$ R_1 - R_2 = (k_1 - k_2) G$となり、一般に離散対数問題は解けないですが、今回は$ S_2 - S_1 = k_1 - k_2$とすることで実は求まります。

…?

…!

ここで天啓。気付きます。

「$ H(M)$を固定して$ k$の差分をとるのではなく、$k_1$を固定して$ H(M)$の差分を考えればいいのでは?」

アプローチ2

偽造したいメッセージ'cat flag' M ^ {\ast}とします。

$ k_1$の固定は難しいことではなく、$ S_1$をそのまま使いまわせばいいだけです。

なので、何とかして$ (r ^ \ast , S_1)$が M ^ \astの署名になるような$r ^ \ast$を偽造したいところです。

ここで$ \Delta = H(M ^ {\ast}) - H(M)$とおくと

$$ R_1 + \Delta \cdot G $$

$$ = (H(M) + k_1) \cdot G + \left(H(M ^ {\ast}) - H(M) \right) \cdot G $$

$$ = (H(M ^ {\ast}) + k_1) \cdot G $$

これを$ r ^ \ast$としたら検証が通るのでは…?

上で書いた検証の式$ S \cdot G - Q + R = H(M) \cdot G$に

$$ S \leftarrow S_1 = d - k_1 $$

$$ R \leftarrow r ^ \ast = (H(M ^ {\ast}) + k_1) \cdot G $$

$$ M \leftarrow M ^ \ast $$

を代入してみます。

…OK!

以上をまとめると、

  • (適当に'ls'などを入力して)既知のメッセージ Mに対する署名$ (R_1, S_1)$を受け取る

  • $ \Delta = H(M ^ {\ast}) - H(M)$を計算する。

  • $ r ^ \ast = R_1 + \Delta \cdot G$を計算すると、$ (r ^ \ast, S_1)$が偽造署名になっている

となります。Pythonで書くとこんな感じ。

R_x, R_y, s = ...
cmd = 'ls'
cmd_forge = 'cat flag'
R = ECC.EccPoint(R_x, R_y, curve='P-256')
delta = (hash(cmd_forge) - hash(cmd)) % key._curve.order
R_star = R + pubkey._curve.G * delta

Writeup

# solver.py

from pwn import *
from Cryptodome.PublicKey import ECC
from Cryptodome.Hash import SHA3_256
from Cryptodome.Math.Numbers import Integer

# ハッシュ計算。配布コードと同じ
def hash(msg):
    h_obj = SHA3_256.new()
    h_obj.update(msg.encode())
    return Integer.from_bytes(h_obj.digest())

conn = remote('flu.xxx', 20085)

# P-256指定
key = ECC.generate(curve='P-256')

# 公開鍵d * Pの取得
conn.sendlineafter(b'>', b'show')
pub = conn.recvline()[:-1].decode().split(', ')
pub_x, pub_y = pub[1], pub[2]
pub_x = int(pub_x.split('=')[1])
pub_y = int(pub_y.split('=')[1][:-2])
pubkey = ECC.EccPoint(pub_x, pub_y, curve='P-256')

# 'ls'に対する署名(R_1, S_1)の取得
conn.recvline()
conn.sendlineafter(b'>', b'sign')
conn.sendlineafter(b'>', b'ls')
res = conn.recvline()[:-2].decode()
res = res.split(': ')[1].split('|')
assert len(res) == 4 and res[-1] == 'ls'
R_x, R_y, s = res[0], res[1], res[2]

cmd = 'ls'
cmd_forge = 'cat flag'
R = ECC.EccPoint(R_x, R_y, curve='P-256')
# R部分の偽造 (Sは使いまわし)
# delta = hash(cmd_forge) - hash(cmd)) % key._curve.order
R_star = R + pubkey._curve.G * ((hash(cmd_forge) - hash(cmd)) % key._curve.order)
R_star_x, R_star_y = str(R_star.x), str(R_star.y)

payload = ('|'.join([R_star_x, R_star_y, s, cmd_forge])).encode()
conn.recvline()
conn.sendlineafter(b'>', b'run')
conn.sendlineafter(b'>', payload)
print(conn.recvline())

conn.close()

flag{d1d_you_f1nd_chakraborty_mehta}

lwsr (20 solves, 3時間)

Sometimes you learn with errors, but I recently decided to learn with shift registers. Or did I learn with errors over shift registers? Shift registers over errors? Anyway, you may try to shift upwards on the investors board with this.

#!/usr/bin/env sage
from os import urandom
from sage.crypto.lwe import Regev
import sys

flag = b"flag{this_may_look_like_a_real_flag_but_its_not}"

def lfsr(state):
    # x^384 + x^8 + x^7 + x^6 + x^4 + x^3 + x^2 + x + 1
    mask   = (1 << 384) - (1 << 377) + 1
    newbit = bin(state & mask).count('1') & 1
    return (state >> 1) | (newbit << 383)

# LFSR initalization
state = int.from_bytes(urandom(384 // 8), "little")
assert state != 0

# Regev KeyGen
n = 128
m = 384

lwe = Regev(n)
q   = lwe.K.order()
pk  = [list(lwe()) for _ in range(m)]
sk  = lwe._LWE__s

# publish public key
print(f"Public key (q = {q}):")
print(pk)

# encrypt flag
print("Encrypting flag:")
for byte in flag:
    for bit in map(int, format(byte, '#010b')[2:]):
        # encode message
        msg = (q >> 1) * bit
        assert msg == 0 or msg == (q >> 1)

        # encrypt
        c = [vector([0 for _ in range(n)]), 0]
        for i in range(m):
            if (state >> i) & 1 == 1:
                c[0] += vector(pk[i][0])
                c[1] += pk[i][1]

        # fix ciphertext
        c[1] += msg
        print(c)

        # advance LFSR
        state = lfsr(state)

# clear LFSR bits
for _ in range(384):
    state = lfsr(state)

while True:
    # now it's your turn :)
    print("Your message bit: ")
    msg = int(sys.stdin.readline())
    if msg == -1:
        break
    assert msg == 0 or msg == 1

    # encode message
    pk[0][1] += (q >> 1) * msg

    # encrypt
    c = [vector([0 for _ in range(n)]), 0]
    for i in range(m):
        if (state >> i) & 1 == 1:
            c[0] += vector(pk[i][0])
            c[1] += pk[i][1]

    # fix public key
    pk[0][1] -= (q >> 1) * msg

    # check correctness by decrypting
    decrypt = ZZ(c[0].dot_product(sk) - c[1])
    if decrypt >= (q >> 1):
        decrypt -= q
    decode = 0 if abs(decrypt) < (q >> 2) else 1
    if decode == msg:
        print("Success!")
    else:
        print("Oh no :(")

    # advance LFSR
    state = lfsr(state)

こちらもまあまあ長いですね。というか説明がまず大変。

LSFR

線形回帰シフトレジスタ(LFSR)は特に詳しく説明しません。「384ビットのstateが更新されるときは右に1ビットシフトして、先頭(MSB)に1ビット付与」くらいの認識で構いません。厳密な説明をするにはnewbit変数を見る必要がありますが、stateの更新前後で(位置は違うものの)383ビットは使いまわしなので、残り1ビットの総当たりでも大丈夫です。

当然、更新後のstateから更新前に巻き戻すのも難しくないです。

# 配布コードと同じ
def lfsr(state:int):
    # x^384 + x^8 + x^7 + x^6 + x^4 + x^3 + x^2 + x + 1
    mask   = (1 << 384) - (1 << 377) + 1
    newbit = bin(state & mask).count('1') & 1
    return (state >> 1) | (newbit << 383)

# 巻き戻し。面倒だったので入力、出力はstr型。更新前stateのLSBを総当たり。
def lfsr_rewind(state:str):
    assert len(state) == 384
    before_state_0 = state[1:] + '0'
    before_state_1 = state[1:] + '1'
    if lfsr(int(before_state_0, 2)) == int(state, 2):
        return before_state_0
    else:
        return before_state_1

LWE

次にLWEのところですが、実はCakeCTFでも少し紹介していて、

 
\begin{pmatrix} 
m_{1} & \cdots & m_{N} 
\end{pmatrix}
\begin{pmatrix} 
X_{11} & \cdots & X_{1M} \\ 
\vdots & \ddots & \vdots \\ 
X_{N1} & \cdots & X_{NM} 
\end{pmatrix} 
+
\begin{pmatrix} 
e_{1} & \cdots & e_{M} 
\end{pmatrix}
 \equiv
\begin{pmatrix} 
Y_{1} & \cdots & Y_{M} 
\end{pmatrix}
 \ {\rm mod} \ q

という構成で、簡単に言うと$ N$変数の式が M個あり、かつ結果に少しノイズ$ e$が加わった連立方程式です。

このうち、 m秘密鍵(以下、$ sk$と表記)、行列$ X$および$ Y$が公開鍵に相当します。$ q$も(小文字表記していますが)既知です。

都合上この行列$ X$を$ (X_1, X_2, \cdots , X_M)$と表記することにします(当然、各要素はベクトルになります)。

同様に$ Y = (Y_1, Y_2, \cdots , Y_M)$とします(こちらは各要素は整数値)。

なお、今回は$ N = 128, M = 384, q = 16411$です。

ここまでがソースコード内の# Regev KeyGenの説明になります。

暗号化

flagの暗号化は以下のようにして行います。ここでは全て1-indexedとします。

  1. $ i=1$とし、stateをランダムな384ビット値で初期化

  2. $ C_{i} = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $X_j$の総和)

  3. $ C'_{i} = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $Y_j$の総和)

  4. flagのMSBから数えて$ i$ビット目が$1$なら$ C'_i$に$8205$を加算。この値は$ \lfloor \frac{q}{2} \rfloor$のこと

  5. $ C_i, C'_i$を表示

  6. $ i$に$1$を加算し、stateをLFSRで更新。

  7. $ i > 352$になったら(つまりflagのすべてのビットに対して暗号化が完了したら)終わり。そうでない場合はStep 2.に戻る。

Step 2.とStep 3.が分かりにくいと思うので補足しておきます。

例えば$ M = 5$でstateが$ 10010$の場合、Step 2.については$ C_i = X_2 + X_5$、Step 3.については$ C'_i = Y_2 + Y_5$となります。

ソースコードとの対応でいうと、$ X, Y$がpk[0]pk[1]に対応していて、$ C_i, C'_i$がc[0]c[1]に対応しています。

ちなみに、352という値はStep 5.の出力回数をカウントすればすぐにわかります。

このあとのstateを384回更新する操作($ \ast$)をもって、ソースコード内の# clear LFSR bitsまでが終了です。

復号オラクル?

以下の操作が無限回使用可能となっています。

  1. $0$か$1$を入力する。stateは($ \ast$)終了時の状態を引き継ぐ。

  2. $ C = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $X_j$の総和)

  3. $ C' = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $Y_j$の総和)

  4. Step 1.で入力した値が$1$かつ、stateのLSBが$ 1$なら$ C'$に$8205$を加算。

  5. $ d = sk \cdot C - C' \ {\rm mod} \ q$を計算する。$ sk$はLWEで作った秘密鍵

  6. $ d$が$4102$より大きい場合は$ m=1$、そうでない場合は$ m=0$とする。この値は$ \lfloor \frac{q}{4} \rfloor$のこと

  7. 入力した値と mが等しいかどうかを表示し、stateをLFSRで更新。

なぜ見出しが「復号オラクル?」になっているかは後述します(当然ながら、これは正しく復号できるアルゴリズムではないです)。

どこが怪しい…?

ぼんやりとソースコードを眺めていて、復号オラクルの部分のソースコード# encode message# fix public keyが気になりました。

暗号化の部分では# fix ciohertextの部分で直接c[1] += msgとして$ 0$または$ 8205$を足しているのに、

復号オラクルでは何故かpk[0][1] += (q >> 1) * msgと公開鍵部分を経由して$ 0$または$ 8205$を足しています。

これが今回の脆弱性で、上記の暗号化と復号オラクルのStep 4.の挙動が(いかにも意味深な感じで)異なっています。

もう少し詳しく言うと、

  • 暗号化:入力(flagのビット)が$1$なら$ C'$に$8205$を加算、それ以外は加算しない。

  • 復号オラクル:入力が$1$かつ、stateのLSBが$ 1$なら$ C'$に$8205$を加算、それ以外は加算しない。

となっています。

なので、この2つを見比べると「入力が$ 1$かつ、stateのLSBが$ 0$」である場合に挙動が変わってきます(このケースに限って、復号がうまくいかないです)。

まとめると、「入力に$ 1$を入れて復号が成功するなら、その時点でのstateのLSBは$ 1$。失敗するならLSBは$ 0$」となります。

stateの復元

さて、最初の方で述べた通り、LSFRは状態を更新するときに右に1ビットシフトするので、

「更新前のLSBから数えて$ i+1$ビット目」と「更新後のLSBから数えて$ i$ビット目」は等しいです。

なので、先ほどの復号オラクルの問い合わせで、例えば

  • 1回目に1を入力したら成功した $ \rightarrow$ ($ \ast$)時点でのstateのLSBは1

  • 2回目(state1回更新時)に1を入力したら失敗した $ \rightarrow$ ($ \ast$)から1回更新した時点でのstateのLSBは0 $ \rightarrow$ ($ \ast$)時点でのstateのLSBから数えて$ 2$ビット目は0

...

といった感じで、384回1を入力した結果を使って($ \ast$)時点でのstate状態を完全に復元できます。

せっかくなので、一番最初に書いたlfsr_rewind()を使って一番最初のstateまで戻してしまいましょう。

def send_bit(i:int):
    conn.recvline()
    conn.sendline(str(i).encode())
    res = conn.recvline()
    #print(res)
    if b'Success' in res:
        return '1'
    else:
        return '0'

# (*)時点でのstate
mid_state = ''
for i in range(384):
    mid_state = send_bit(1) + mid_state

# 配布コードの"clear LFSR bits"直前のstate
for i in range(384):
    mid_state = lfsr_rewind(mid_state)

# 初期state
for i in range(352):
    mid_state = lfsr_rewind(mid_state)

flagの復元

ここから先、LWEを使っているので格子を使って秘密鍵を求めるのかと思うのが至極真っ当な思考回路ですが、

それならそもそも公開鍵が与えられた時点で秘密鍵も求められるはずであり、わざわざ復号オラクルなんぞ用意する必要はありません。(小糸は訝しんだ)

ということで、改めて暗号化の部分に戻ってみるわけですが、

  1. $ i=1$とし、stateをランダムな384ビット値で初期化

  2. $ C_{i} = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $X_j$の総和)

  3. $ C'_{i} = $ (stateのLSBから数えて$ j$ビット目が1である$ j$に対して $Y_j$の総和)

  4. flagのMSBから数えて$ i$ビット目が$1$なら$ C'_i$に$8205$を加算。この値は$ \lfloor \frac{q}{2} \rfloor$のこと

  5. $ C_i, C'_i$を表示

  6. $ i$に$1$を加算し、stateをLFSRで更新。

  7. $ i > 352$になったら(つまりflagのすべてのビットに対して暗号化が完了したら)終わり。そうでない場合はStep 2.に戻る。

という流れでした。実は、stateがもう完全に掌握できていて、かつ公開鍵$Y$も与えられているので、Step 3.時点での$ C'_i$の値は分かるわけです。

なので、これとStep 5.での出力を比較して、一致しているならflagの$ i$ビット目は$ 0$であり、$ 8205$の差分が生じているなら$ i$ビット目は$ 1$と判定できます。

各ビットごとでの判定は以下のようにコードを組めばいいです。

given_c_lsがStep 5での  C' _ {i} の出力、c_modが復元したstateと公開鍵をもとに計算したStep 3. 時点での C'_iの値です。

# 配布コードと同じ
c = [vector([0 for _ in range(128)]), 0]
for i in range(384):
    if (state >> i) & 1 == 1:
        c[0] += vector(pk[i][0])
        c[1] += pk[i][1]
    
c_mod = [int(_ % q) for _ in c[0]]
if given_c_ls[j][1] == c[1] % q:
    print('0')
else:
    print('1')

# state更新。忘れずに
state = lfsr(state)

Writeup

本当は復号オラクルの正当性についても述べたかったのですが、これ以上はてブで数式を書きたくなくてモチベ維持ができそうになかったので断念します。

めちゃくちゃ要望があれば書くかもしれません。

# solver.sage

from pwn import *
from Crypto.Util.number import *

conn = remote('flu.xxx', 20075)

# 復号オラクルアクセス、stateの復元
def send_bit(i:int):
    conn.recvline()
    conn.sendline(str(i).encode())
    res = conn.recvline()
    if b'Success' in res:
        return '1'
    elif b'Oh no :(' in res:
        return '0'
    # 例外処理
    else:
        print('[+] coding error')
        assert 1 == 2
        return -1

# LFSR。配布コードと同じ
def lfsr(state:int):
    # x^384 + x^8 + x^7 + x^6 + x^4 + x^3 + x^2 + x + 1
    mask   = (1 << 384) - (1 << 377) + 1
    newbit = bin(state & mask).count('1') & 1
    return (state >> 1) | (newbit << 383)

# LFSR巻き戻し
def lfsr_rewind(state:str):
    assert len(state) == 384
    before_state_0 = state[1:] + '0'
    before_state_1 = state[1:] + '1'
    if lfsr(int(before_state_0, 2)) == int(state, 2):
        return before_state_0
    elif lfsr(int(before_state_1, 2)) == int(state, 2):
        return before_state_1
    # 例外処理
    else:
        print('[+] coding error')
        assert 1 == 2
        return -1

conn.recvuntil(b'= ')
# qの値取得
q = int(conn.recvline()[:-3])
assert q == 16411

# 公開鍵取得
pk = eval(conn.recvline()[:-1].decode())
assert len(pk) == 384 and len(pk[0][0]) == 128

# 暗号化の部分の出力取得
conn.recvline()
given_c_ls = []
# flag : 352 bits
for i in range(352):
    c = eval(conn.recvline()[:-1].decode())
    given_c_ls.append(c)

# 復号オラクルアクセス、(*)時点でのstate
mid_state = ''
for i in range(384):
    print('{}/384 state collection done'.format(i))
    mid_state = send_bit(1) + mid_state

conn.close()

# 配布コードの"clear LFSR bits"直前のstate
for i in range(384):
    mid_state = lfsr_rewind(mid_state)

# 初期state
for i in range(352):
    mid_state = lfsr_rewind(mid_state)
state = int(mid_state, 2)

flag = ''
# 352ビットのflag計算
for j in range(352):
    c = [vector([0 for _ in range(128)]), 0]
    for i in range(384):
        if (state >> i) & 1 == 1:
            c[0] += vector(pk[i][0])
            c[1] += pk[i][1]
        
    c_mod = [int(_ % q) for _ in c[0]]
    # 念のため、暗号文として与えられるC_iと自分でstateから計算できるStep 2.のC_iが一致するか確認
    check = [int(x) == int(y) for x, y in zip(c_mod, given_c_ls[j][0])]
    assert all(check)

    if given_c_ls[j][1] == c[1] % q:
        print('0')
        flag += '0'
    elif given_c_ls[j][1]  == (c[1] + (q >> 1)) % q:
        print('1')
        flag += '1'
    # 例外処理
    else:
        print('[+] coding error')
        print('[+] given val : ', given_c_ls[j][1])
        print('[+] calc myself : ', c[1] % q, ' or ', (c[1] + (q >> 1)) % q) 
        assert 1 == 2
    # state更新
    state = lfsr(state)
    print('{}/352 flag found'.format(j))
    
flag = int(flag, 2)
print(long_to_bytes(flag))

flag{your_fluxmarket_stock_may_shift_up_now}