MSS + MSS Revenge

The first challenge I started with was MSS from the Crypto category. Writing out the math on paper helped me understand the challenge, and I tested an idea that let me decrypt the flag with an input the developer had not thought of. This turned out to be an unintended solution. Around a day later, another challenge called "MSS Revenge" popped up, and sure enough, the unintended solution was patched with a single if statement.
After working through the math, and with a hint from the flag text of the first challenge, I eventually figured out the intended solution as well, and it was a very satisfying one.

The Challenge

For this challenge, a server.py script was given as well as a host and port where this server was listening. The main() function contains a simple command-line interface with a few commands:

Python

def main():
    ...
    query = json.loads(input(show_menu()))
    if 'command' in query:
        cmd = query['command']
        if cmd == 'get_share':
            if 'x' in query:
                x = int(query['x'])
                share = mss.get_share(x)
                print(json.dumps(share))
            else:
                print('\n[-] Please send your user ID.')
        elif cmd == 'encrypt_flag':
            enc_flag = mss.encrypt_flag(FLAG)
            print(f'\n[+] Here is your encrypted flag : {json.dumps(enc_flag)}.')
        elif cmd == 'exit':
            print('\n[+] Thank you for using our service. Bye! :)')
            break
        else:
            print('\n[-] Unknown command:(')

All input is read as a JSON object. The first command, get_share, takes one parameter 'x' and returns the output of mss.get_share(x). The mss object is created at the start of main() from the MSS class defined at the top of the file:

Python

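import os
from hashlib import sha256
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Util.number import bytes_to_long


class MSS: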
    def __init__(self, BITS, d, n):
        self.d = d
        self.n = n
        self.BITS = BITS
        self.key = bytes_to_long(os.urandom(BITS//8))
        self.coeffs = [self.key] + [bytes_to_long(os.urandom(self.BITS//8)) for _ in range(self.d)]

    def poly(self, x):
        return sum([self.coeffs[i] * x**i for i in range(self.d+1)])

    def get_share(self, x):
        if x > 2**15:
            return {'approved': 'False', 'reason': 'This scheme is intended for less users.'}
        elif self.n < 1:
            return {'approved': 'False', 'reason': 'Enough shares for today.'}
        else:
            self.n -= 1
            return {'approved': 'True', 'x': x, 'y': self.poly(x)}
    
    def encrypt_flag(self, m):
        key = sha256(str(self.key).encode()).digest()
        iv = os.urandom(16)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        ct = cipher.encrypt(pad(m, 16))
        return {'iv': iv.hex(), 'enc_flag': ct.hex()}


def main():
    mss = MSS(256, 30, 19)
    ...

Some seemingly arbitrary parameters are given to the __init__() function: BITS=256 (the size in bits of each random number), d=30 (the degree of the polynomial) and n=19 (how many shares we may request). It generates d+1 = 31 random numbers in self.coeffs. The first of these is special: it is also stored as self.key, and the SHA-256 hash of its decimal string is used as the AES key in the encrypt_flag() function. We'll need to recover this number to decrypt the flag we get after sending the encrypt_flag command in the CLI.

The get_share() function is especially interesting to us, because it takes our input and returns output. Its checks make sure you cannot call it more than 19 times, and that the input is not too large (at most 2¹⁵ = 32768). If both checks pass, x is handed to the poly(x) function, which evaluates an interesting expression.

sum([self.coeffs[i] * x**i for i in range(self.d+1)]) iterates 31 times (i from 0 to 30): it multiplies each coefficient, including the key at index 0, by our input x raised to the power i. These products are then all summed up and returned, so poly() evaluates a degree-30 polynomial at x.
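Written out as a formula, with c_i standing for self.coeffs[i] (so c_0 is the key), poly() computes the polynomial below. The bound on the coefficients follows from each one being built from BITS//8 = 32 random bytes:

```latex
\mathrm{poly}(x) = \sum_{i=0}^{30} c_i\, x^i = c_0 + c_1 x + c_2 x^2 + \dots + c_{30} x^{30},
\qquad 0 \le c_i < 2^{256}
```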

So in the end, we need to find some way of recovering self.key, the first coefficient. Then we will request and decrypt the flag with it.
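Since the server only hands out 19 shares, it helps to try inputs offline first. Below is a minimal local stand-in for poly(), a sketch that assumes secrets.randbits(256) as a substitute for bytes_to_long(os.urandom(32)) (both give values in the range 0 to 2²⁵⁶); the variable names are hypothetical:

```python
import secrets

BITS, D = 256, 30  # same parameters as MSS(256, 30, 19)

# Hypothetical local coefficients: secrets.randbits(BITS) stands in for
# bytes_to_long(os.urandom(BITS // 8)); both give 0 <= c < 2**256.
coeffs = [secrets.randbits(BITS) for _ in range(D + 1)]
key = coeffs[0]  # plays the role of self.key

def poly(x):
    # Same expression as MSS.poly()
    return sum([coeffs[i] * x**i for i in range(D + 1)])

# Check candidate inputs locally without spending any of the 19 real shares
print(poly(1) == sum(coeffs))  # True: x = 1 adds up every coefficient
print(poly(-1) == sum(c if i % 2 == 0 else -c for i, c in enumerate(coeffs)))  # True: alternating signs
```

Any input idea from the next section can be checked against this stand-in before it is tried on the real server.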

The Math

Let's play around with the math for a bit to see if we can create anything interesting. A good first step is to think of edge-case inputs that produce interesting results. The first candidate is x=1; let's see what that gives:

poly(1) = c_0*1^0 + c_1*1^1 + ... + c_30*1^30
        = c_0     + c_1     + ... + c_30

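That just sums up all the coefficients. Let's try another one, x=0:

poly(0) = c_0*0^0 + c_1*0^1 + ... + c_30*0^30
        = c_0

Wait, that's the key itself! Python evaluates 0**0 as 1, while every other term is multiplied by 0, so with the poly() shown above, x=0 alone already returns the key. (The sequel, MSS Revenge, blocks this too by rejecting any x < 1.) Also interesting is what happens one step above x=1, at x=2: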

poly(2) = c_0*2^0 + c_1*2^1 + c_2*2^2 + ... + c_30*2^30
        = c_0*1   + c_1*2   + c_2*4   + ... + c_30*1073741824

Those numbers get big quite quickly! Remember that the key is c_0 in this notation, and it is completely overshadowed by the much larger terms that come after it. But looking at the code, there is one more interesting kind of input to try: negative numbers, like x=-1:

poly(-1) = c_0*(-1)^0 + c_1*(-1)^1 + c_2*(-1)^2 + ... + c_30*(-1)^30
         = c_0        - c_1        + c_2        - ... + c_30

This alternating pattern makes me think of canceling some terms, but c_0 on its own still seems hard to extract from the output, because all the other random coefficients are added to or subtracted from it. Let's take one more look at the source code:

Python

def get_share(self, x):
    if x > 2**15:
        ...

Hold on, we just input a negative number and it worked... There is only a check for whether a number is too large and positive, but no check for a number that is very large and negative! That's another interesting edge case, so let's put in a giant negative number just because we can, like -10¹⁰⁰ (minus a googol):

poly(-10^100) = c_0*(-10^100)^0 + c_1*(-10^100)^1 + c_2*(-10^100)^2 + ... + c_30*(-10^100)^30
              = c_0             - c_1*10^100      + c_2*10^200      - ... + c_30*10^3000

The numbers at the end get ginormous: roughly 100*30 = 3000 digits! Python handles that fine, because its integers have arbitrary precision. If we take a good look, however, we might notice that the terms are now so far apart that they no longer overlap: each coefficient is a 256-bit number with at most 78 decimal digits, while consecutive powers of 10¹⁰⁰ are 100 digits apart (the alternating signs only cause some borrowing, visible as runs of 9s). If we hardcode a self.key value like 1234567890...1234567, the result looks like this:

242455279049161590705046742997867567062104714784870962098114375255984579239699999999999999999999999895681511444545085880991
303002744684803589968807038845812869137073193845477199000000000000000000000003446343133726784127776804771223737235828138327
532668347523217407951644619056799999999999999999999999319261630826221433788143135623238334112293459076337689849258842864883
...
000746791943369099523082822504705797111028308332063739267731687451571235118105479999999999999999999999942322767840374745438
494991779878216793802623771292301963642565572085835956817000000000000000000000006861930715276834002794744515546738328968605
577237446074244134329211423173096499999999999999999999999064679495894165997726039372243623958600883665524507648766145180014
0115111639200000000000000000000000*12345678901234567890123456789012345678901234567890123456789012345678901234567*

That's the key we set, right there in our terminal! The large negative number bypasses the check and keeps all the coefficients apart from each other. Now we just need to run this against the real server and use the key to decrypt the flag.
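The same trick can be rehearsed offline before touching the server. This sketch assumes a local stand-in with random 256-bit coefficients (secrets.randbits() in place of os.urandom()), and it also shows why the final reduction should use the positive 10¹⁰⁰: Python's % takes the sign of the divisor.

```python
import secrets

D = 30
# Hypothetical local coefficients; coeffs[0] stands in for self.key
coeffs = [secrets.randbits(256) for _ in range(D + 1)]
key = coeffs[0]

def poly(x):
    # Same expression as MSS.poly()
    return sum([coeffs[i] * x**i for i in range(D + 1)])

X = -10**100
print(X > 2**15)          # False: the server's only size check lets X through

y = poly(X)
M = 10**100               # spacing between consecutive coefficient "blocks"

# Every term except c_0 is a multiple of M, and 0 <= c_0 < 2**256 < M,
# so the least non-negative remainder mod M is exactly the key.
recovered = y % M
print(recovered == key)   # True

# Pitfall: with a negative divisor, Python's % returns a value in (-M, 0],
# which for a nonzero key is key - M rather than the key itself.
wrong = y % X
print(wrong == key - M)   # True for any nonzero key
```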

One-shot solution script

Queries are made in a JSON format. We just need to send the -10¹⁰⁰ value like this:

JSON

{"command": "get_share", "x": -10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000}

From the response we can extract the key as shown above, and then we ask for the encrypted flag:

JSON

{"command": "encrypt_flag"}

The decryption itself is then done by copying the original encrypt_flag() function, changing "encrypt" to "decrypt" and "pad" to "unpad". Finally, we get a script like this, using pwntools to interact with the server:

Python

from pwn import *
from hashlib import sha256
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
import json

r = remote("94.237.51.68", 45769)


def query(command, data=None):
    # Build a fresh payload each call instead of mutating a shared default dict
    payload = dict(data or {}, command=command)
    r.sendlineafter(b"query = ", json.dumps(payload).encode())
    return json.loads(r.recvline())


X = -10**100

result = query('get_share', {'x': X})

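# Coefficients don't overlap anymore, so reducing modulo 10**100 leaves only the lowest (first) one.
# Reduce by the positive abs(X): Python's % takes the sign of the divisor, so `% X` would give key - 10**100.
key_n = result['y'] % abs(X)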
# Decryption logic using extracted key
key = sha256(str(key_n).encode()).digest()
success(f"Key: {key_n} -> {key}")

r.sendlineafter(b"query = ", json.dumps({'command': 'encrypt_flag'}).encode())
r.recvuntil(b"flag : ")
result = json.loads(r.recvuntil(b'.', drop=True))

iv             = bytes.fromhex(result['iv'])
encrypted_flag = bytes.fromhex(result['enc_flag'])
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
flag = cipher.decrypt(encrypted_flag)

print(unpad(flag, AES.block_size))

When we run this script, it gets the 3000-digit number and extracts the key from it, decrypting the flag!
HTB{thr3sh0ld_t00_sm4ll_______CRT_t00_str0nk!}

... 'threshold too small', 'CRT too strong'?!? What's that all about?

MSS Revenge

'CRT' stands for the Chinese Remainder Theorem, which is at the heart of the intended solution to this challenge. I had quickly realized that I'd found an unintended solution, but didn't bother looking for the intended one, since I had other challenges to do. About a day later, though, another challenge called "MSS Revenge" popped up, and sure enough, its only difference was a fix for the unintended solution we used:

Diff

def get_share(self, x):
-   if x > 2**15:
+   if x < 1 or x > 2**15:

Okay, okay, I'll play your game. At least we know the solution now has something to do with the CRT. First, what does that even mean?

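Wikipedia tells us it is a way of reconstructing a number from its remainders when divided by several other numbers, and my book gives a more practical formulation:

Given some moduli n_1, n_2, n_3, ..., n_k and some samples c_1, c_2, c_3, ..., c_k of the unknown x mod these n's, we can find x efficiently. Strictly speaking, the moduli must be pairwise coprime, and x is only determined modulo N = n_1*n_2*...*n_k, so we recover x exactly when we know it is smaller than N.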

c_1 = x mod n_1
c_2 = x mod n_2
c_3 = x mod n_3
...
c_k = x mod n_k

This is interesting because it does not seem trivial to recover x from just the c's and n's. We just have to apply it to our problem in some useful way.
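To make that concrete, here is a small hand-rolled CRT in Python. It is a textbook construction for illustration, not sympy's implementation, and it assumes the moduli are pairwise coprime; pow(value, -1, modulus) computes a modular inverse on Python 3.8+:

```python
from math import prod

def crt(moduli, residues):
    """Return (x, N) with x % moduli[i] == residues[i] % moduli[i] and 0 <= x < N.

    Assumes pairwise coprime moduli. Textbook construction:
    x = sum(r_i * N_i * (N_i^-1 mod n_i)) with N = prod(moduli), N_i = N // n_i.
    """
    N = prod(moduli)
    x = 0
    for n_i, r_i in zip(moduli, residues):
        N_i = N // n_i
        x += r_i * N_i * pow(N_i, -1, n_i)  # N_i * inverse is 1 mod n_i, 0 mod the others
    return x % N, N

# Which x leaves remainders 2, 3 and 2 when divided by 3, 5 and 7?
print(crt([3, 5, 7], [2, 3, 2]))  # (23, 105)
```

For moduli 3, 5 and 7, 23 is the only number in [0, 105) with remainders 2, 3 and 2; any larger solution differs from it by a multiple of 105.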

Applying the CRT

The value we want to recover is the key, the first coefficient, so we would somehow need several remainders of this number. The challenge does not use a modulo operation anywhere, so how to get them is not obvious.
We can't just go around applying modulos to arbitrary responses from the server, because, remember, the other coefficients are still in the way. If we could somehow get rid of them and obtain a remainder of only the first coefficient, the CRT could be used to recover it.

One neat property of the modulo operation is that it simplifies sums of products: if we choose the modulus to be the multiplier, that term always reduces to 0, and we can then ignore it. For example:

(a + b*42) % 42 = a % 42

If only every term except the first were a multiple of some number we choose. Wait. That's exactly what the powers in the poly() equation do for us!
When we put in a number like x=2, every term except the first is a multiple of 2, since only the first term is multiplied by 2⁰ = 1. So when we reduce the final result modulo 2, all other terms cancel and we are left with the key modulo 2. Repeating this with different x values gives us enough samples for the CRT.
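A quick local check of that cancellation, using a hypothetical stand-in for the coefficients (secrets.randbits() replacing os.urandom()). The identity holds for any positive x; choosing primes only matters later, to keep the moduli coprime:

```python
import secrets

D = 30
# Hypothetical local coefficients; coeffs[0] stands in for self.key
coeffs = [secrets.randbits(256) for _ in range(D + 1)]
key = coeffs[0]

def poly(x):
    # Same expression as MSS.poly()
    return sum([coeffs[i] * x**i for i in range(D + 1)])

# Each term c_i * x**i with i >= 1 is a multiple of x, so it vanishes mod x
for x in (2, 3, 32749):
    print(x, poly(x) % x == key % x)  # each line ends in True
```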

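All moduli should be pairwise 'coprime', meaning no two of them share a factor. An easy way to guarantee this is to use prime numbers, and to pick them as large as possible while staying below the limit: the CRT only determines the key modulo the product of the moduli, so with just 19 shares that product has to exceed 2²⁵⁶. Nineteen primes just below 2¹⁵ give a product of about 285 bits, which is enough: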

Python

from Crypto.Util.number import isPrime

# Start from the max, and go down. Only continue if n is prime
for n in range(2**15, 1, -1):
    if not isPrime(n):
        continue

    result = query('get_share', {'x': n})

We will save these n's and c's in two lists and pass them to sympy.ntheory.modular.crt() to recover the key:

Python

from Crypto.Util.number import isPrime
from sympy.ntheory.modular import crt

moduli = []
results = []
for n in range(2**15, 1, -1):
    if not isPrime(n):
        continue

    result = query('get_share', {'x': n})
    if result['approved'] == 'False':
        break

    moduli.append(n)
    results.append(result['y'])

# crt() returns (solution, product of moduli); the solution is the key because key < product
key_n = int(crt(moduli, results)[0])
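The whole attack can be rehearsed offline before spending the real shares. This self-contained sketch rests on a few assumptions: a hypothetical LocalMSS class that mimics the patched server (same parameters and checks, with secrets.randbits() instead of os.urandom()), trial division instead of isPrime(), and a hand-rolled crt() instead of sympy's. It also prints the bit length of the product of the moduli, to confirm it exceeds the 256-bit key:

```python
import secrets
from math import isqrt, prod

class LocalMSS:
    """Hypothetical stand-in for the MSS Revenge server: MSS(256, 30, 19)."""

    def __init__(self, bits=256, d=30, n=19):
        self.d, self.n = d, n
        self.coeffs = [secrets.randbits(bits) for _ in range(d + 1)]
        self.key = self.coeffs[0]

    def get_share(self, x):
        # Same checks as the patched get_share()
        if x < 1 or x > 2**15 or self.n < 1:
            return {'approved': 'False'}
        self.n -= 1
        y = sum(self.coeffs[i] * x**i for i in range(self.d + 1))
        return {'approved': 'True', 'x': x, 'y': y}

def is_prime(n):
    # Trial division is plenty for n <= 2**15
    return n >= 2 and all(n % f for f in range(2, isqrt(n) + 1))

def crt(moduli, residues):
    # Textbook CRT for pairwise coprime moduli; returns (x, N) with 0 <= x < N
    N = prod(moduli)
    x = sum(r * (N // m) * pow(N // m, -1, m) for m, r in zip(moduli, residues))
    return x % N, N

mss = LocalMSS()
moduli, results = [], []
for n in range(2**15, 1, -1):
    if not is_prime(n):
        continue
    share = mss.get_share(n)
    if share['approved'] == 'False':
        break  # all 19 shares used up
    moduli.append(n)
    results.append(share['y'])  # y % n == key % n

key_n, N = crt(moduli, results)
print(len(moduli), N.bit_length())  # 19 285: the product of the moduli exceeds 2**256
print(key_n == mss.key)             # True
```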

Finally, the same code from before can be used to request and decrypt the flag from the server:

Python

...
key = sha256(str(key_n).encode()).digest()
success(f"Key: {key_n} -> {key}")

r.sendlineafter(b"query = ", json.dumps({'command': 'encrypt_flag'}).encode())
r.recvuntil(b"flag : ")
result = json.loads(r.recvuntil(b'.', drop=True))

iv             = bytes.fromhex(result['iv'])
encrypted_flag = bytes.fromhex(result['enc_flag'])
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
flag = cipher.decrypt(encrypted_flag)

print(unpad(flag, AES.block_size))