Cracking Python Random Module

We often see Python's random module in CTFs and in open-source code. It is very useful when used well, but it also has a serious weakness: its random values can be predicted. This is because the random module uses an algorithm called the Mersenne Twister, which is not cryptographically secure.

Before we begin, I'd like to mention that this article is based on rbtree's blog post about the Python random module, and some of the code is quoted from it. It is also very similar to the solution of the "Unbreakable" challenge on Dreamhack, so watch out for spoilers.

Mersenne Twister Algorithm

First, to understand the weakness of the Mersenne Twister, we need to look at the C code of the algorithm. The code below is genrand_uint32 from CPython's Modules/_randommodule.c. The generator keeps an array of 624 32-bit words as its state and produces random values from it, refilling the whole array whenever it runs out.

/* Period parameters -- These are all magic.  Don't change. */
#define N 624
#define M 397
#define MATRIX_A 0x9908b0dfU    /* constant vector a */
#define UPPER_MASK 0x80000000U  /* most significant w-r bits */
#define LOWER_MASK 0x7fffffffU  /* least significant r bits */
static uint32_t
genrand_uint32(RandomObject *self)
{
    uint32_t y;
    static const uint32_t mag01[2] = {0x0U, MATRIX_A};
    /* mag01[x] = x * MATRIX_A  for x=0,1 */
    uint32_t *mt;
    mt = self->state;
    
    if (self->index >= N) { /* generate N words at one time */
        int kk;
        
        for (kk=0;kk<N-M;kk++) {
            y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
            mt[kk] = mt[kk+M] ^ (y >> 1) ^ mag01[y & 0x1U];
        }
        
        for (;kk<N-1;kk++) {
            y = (mt[kk]&UPPER_MASK)|(mt[kk+1]&LOWER_MASK);
            mt[kk] = mt[kk+(M-N)] ^ (y >> 1) ^ mag01[y & 0x1U];
        }
        
        y = (mt[N-1]&UPPER_MASK)|(mt[0]&LOWER_MASK);
        mt[N-1] = mt[M-1] ^ (y >> 1) ^ mag01[y & 0x1U];
        self->index = 0;
    }
    
    y = mt[self->index++];
    y ^= (y >> 11);
    y ^= (y << 7) & 0x9d2c5680U;
    y ^= (y << 15) & 0xefc60000U;
    y ^= (y >> 18);
    
    return y;
}

As the code shows, each call takes one word of the 624-word state array and scrambles it with a series of shifts, ANDs and XORs before returning it. Once all 624 words have been used, the entire array is regenerated in one pass, in a way similar to an LFSR. We can write this LFSR-like logic concisely in Python like this. XD

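N = 624  # size of the state array
M = 397

MATRIX_A = 0x9908b0df
mag01 = [0, MATRIX_A]
UPPER_MASK = 0x80000000  # only the msb
LOWER_MASK = 0x7fffffff  # all bits except the msb

def twist(mt, index):
    if index >= N:  # all N words used up: generate N new words at once
        for kk in range(N):
            # msb of mt[kk] joined with the lower 31 bits of the next word
            y = (mt[kk] & UPPER_MASK) | (mt[(kk + 1) % N] & LOWER_MASK)
            # (kk + M) % N covers both C loops and the final mt[N-1] update
            mt[kk] = mt[(kk + M) % N] ^ (y >> 1) ^ mag01[y & 0x1]
        index = 0
    return index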
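To check that this Python version of the twist really matches CPython, here is a small sketch. It assumes the CPython state format used later in this article: version 3, a tuple of the 624 state words plus the current index, and None. Setting the index to 624 forces the generator to twist before its next output. The helper names twist and temper are ours, not part of the random module.

```python
import random

N, M = 624, 397
MATRIX_A = 0x9908b0df
UPPER_MASK, LOWER_MASK = 0x80000000, 0x7fffffff


def twist(mt):
    """Return the next state array (same recurrence as the C code)."""
    mt = list(mt)  # work on a copy so the caller's list is untouched
    for kk in range(N):
        y = (mt[kk] & UPPER_MASK) | (mt[(kk + 1) % N] & LOWER_MASK)
        mt[kk] = mt[(kk + M) % N] ^ (y >> 1) ^ (MATRIX_A if y & 1 else 0)
    return mt


def temper(y):
    """Final scrambling step applied to each output word."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9d2c5680
    y ^= (y << 15) & 0xefc60000
    y ^= y >> 18
    return y


# an arbitrary state: 624 random 32-bit words
src = random.Random(2024)
mt = [src.getrandbits(32) for _ in range(N)]

# index 624 means "all words used up", so the next call twists first
gen = random.Random()
gen.setstate((3, tuple(mt + [N]), None))

predicted = [temper(w) for w in twist(mt)]
actual = [gen.getrandbits(32) for _ in range(N)]
print(predicted == actual)  # True
```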

As you all know, XOR is a reversible operation, unlike AND. So if every computation that produces the next state consists only of XORs of state bits, it can be inverted. (Shifts and ANDs with fixed masks such as UPPER_MASK only decide which bits take part in those XORs, so they are fine.)

I know what you are thinking: the generator also contains an AND in mag01[y & 0x1U], and its result is used to pick a value from a table. In fact, this AND can be converted into a multiplication.

Let's look at the expression mag01[y & 1]. It evaluates to MATRIX_A if y is odd and to 0 if y is even, so it equals $(y \mathbin{\&} 1) \times \text{MATRIX\_A}$, the lowest bit of y multiplied by a constant. Multiplying a constant by a single bit just means "XOR in the constant if the bit is 1", so this term is also an XOR of state bits. Therefore we can reverse-calculate the whole state-update equation.
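The argument can be checked numerically. The sketch below (the helper name twist is ours) first confirms that mag01[y & 1] equals (y & 1) * MATRIX_A for an even and an odd y. It then checks that the whole twist is XOR-linear: twisting a ^ b gives the same array as XORing the twists of a and b. This linearity is what the equation-solving approach later in the article relies on.

```python
import random

N, M = 624, 397
MATRIX_A = 0x9908b0df
UPPER_MASK, LOWER_MASK = 0x80000000, 0x7fffffff
mag01 = [0, MATRIX_A]


def twist(mt):
    """One full twist, written with the multiplication form of mag01."""
    mt = list(mt)
    for kk in range(N):
        y = (mt[kk] & UPPER_MASK) | (mt[(kk + 1) % N] & LOWER_MASK)
        mt[kk] = mt[(kk + M) % N] ^ (y >> 1) ^ ((y & 1) * MATRIX_A)
    return mt


# the table lookup and the multiplication agree for even and odd y
for y in (0x12345678, 0x12345679):
    assert mag01[y & 1] == (y & 1) * MATRIX_A

# linearity over bits: twist(a ^ b) == twist(a) ^ twist(b), word by word
rng = random.Random(0)
a = [rng.getrandbits(32) for _ in range(N)]
b = [rng.getrandbits(32) for _ in range(N)]
lhs = twist([x ^ y for x, y in zip(a, b)])
rhs = [x ^ y for x, y in zip(twist(a), twist(b))]
print(lhs == rhs)  # True
```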

Temper & Untemper

Next, we need to reverse the last step, which is called tempering. Let's look at the tempering process.

Actually, there are many reverse-temper (untemper) implementations on the internet, so we don't really have to work it out ourselves. But that would feel like posting without effort, so let's take a quick look.

# temper.py
def temper(y):
    # y is one 32-bit state word (y = mt[self->index++] in the C code)
    y ^= (y >> 11)
    y ^= (y << 7) & 0x9d2c5680
    y ^= (y << 15) & 0xefc60000
    y ^= (y >> 18)
    return y
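As a quick sanity check, tempering each state word in Python should reproduce the generator's outputs when the index is 0, because then no twist happens before the first 624 outputs. This sketch assumes the same CPython state tuple layout, (3, 624 words plus index, None), that this article uses with random.setstate.

```python
import random


def temper(y):
    y ^= y >> 11
    y ^= (y << 7) & 0x9d2c5680
    y ^= (y << 15) & 0xefc60000
    y ^= y >> 18
    return y


src = random.Random(42)
mt = [src.getrandbits(32) for _ in range(624)]  # an arbitrary state

gen = random.Random()
gen.setstate((3, tuple(mt + [0]), None))  # index 0: outputs come straight from mt

outputs = [gen.getrandbits(32) for _ in range(624)]
print(outputs == [temper(w) for w in mt])  # True
```

The inverse, untemper, follows.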
# untemper.py
TemperingMaskB = 0x9d2c5680
TemperingMaskC = 0xefc60000

def untemper(y):
    # undo the four tempering steps in reverse order
    y = undoTemperShiftL(y)
    y = undoTemperShiftT(y)
    y = undoTemperShiftS(y)
    y = undoTemperShiftU(y)
    return y
    
def undoTemperShiftL(y):
    # undo y ^= y >> 18: the top 18 bits were left untouched, so one XOR suffices
    last14 = y >> 18
    final = y ^ last14
    return final
    
def undoTemperShiftT(y):
    # undo y ^= (y << 15) & C: C has no bits below bit 17, so one XOR suffices
    first17 = y << 15
    final = y ^ (first17 & TemperingMaskC)
    return final
    
def undoTemperShiftS(y):
    # undo y ^= (y << 7) & B: each pass recovers 7 more low bits
    a = y << 7
    b = y ^ (a & TemperingMaskB)
    c = b << 7
    d = y ^ (c & TemperingMaskB)
    e = d << 7
    f = y ^ (e & TemperingMaskB)
    g = f << 7
    h = y ^ (g & TemperingMaskB)
    i = h << 7
    final = y ^ (i & TemperingMaskB)
    return final
    
def undoTemperShiftU(y):
    # undo y ^= y >> 11: each pass recovers 11 more high bits
    a = y >> 11
    b = y ^ a
    c = b >> 11
    final = y ^ c
    return final

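It's hard to explain with my shallow explanatory power, but the key is that each tempering step leaves some bits unchanged (for example, y ^= y >> 18 never touches the top 18 bits). Those preserved bits let us recover the bits next to them, and repeating this recovers the whole word, so tempering can be undone… lol 🙂

Anyway, we now know that untempering is possible, so we can “theoretically” recover every bit of the state from the outputs.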
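The same "preserved bits" idea can be written generically. In this sketch (all helper names are ours), each inverse starts from the tempered value and repeatedly re-applies the step. Every pass fixes shift more correct bits, so after about 32 / shift passes the original word is back. A round trip over many random words checks the result.

```python
import random

B, C = 0x9d2c5680, 0xefc60000


def temper(y):
    y ^= y >> 11
    y ^= (y << 7) & B
    y ^= (y << 15) & C
    y ^= y >> 18
    return y


def undo_xor_shift_right(y, shift):
    """Invert y ^= y >> shift for a 32-bit y."""
    x = y
    for _ in range(32 // shift + 1):
        x = y ^ (x >> shift)  # each pass fixes `shift` more high bits
    return x


def undo_xor_shift_left_mask(y, shift, mask):
    """Invert y ^= (y << shift) & mask for a 32-bit y."""
    x = y
    for _ in range(32 // shift + 1):
        x = y ^ ((x << shift) & mask)  # each pass fixes `shift` more low bits
    return x


def untemper(y):
    # undo the four steps in reverse order
    y = undo_xor_shift_right(y, 18)
    y = undo_xor_shift_left_mask(y, 15, C)
    y = undo_xor_shift_left_mask(y, 7, B)
    y = undo_xor_shift_right(y, 11)
    return y


rng = random.Random(1)
words = [rng.getrandbits(32) for _ in range(10000)]
print(all(untemper(temper(w)) == w for w in words))  # True
```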

Implementation

With untemper, if we know 624 consecutive results of getrandbits(32), we can recover the generator's internal state (all 624 words) and reproduce every output that follows.

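import random
from untemper import untemper  # untemper() from untemper.py above

# observe 1000 consecutive 32-bit outputs
outputs = [ random.getrandbits(32) for _ in range(1000) ]

# untemper the first 624 outputs and install them as the state, with index 0
recovered_state = (3, tuple([ untemper(v) for v in outputs[:624] ] + [0]), None)
random.setstate(recovered_state)

# the cloned state reproduces all 1000 outputs, including the 376 after the first 624
for i in range(1000):
    assert outputs[i] == random.getrandbits(32)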
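The window of 624 outputs does not have to start at a twist boundary. The twist recurrence only ever looks at a sliding window of 624 consecutive words, so any 624 consecutive outputs will do. The sketch below uses separate random.Random instances so the global generator is not touched, and its untemper is a compact rewrite of the one above. It skips some outputs, observes the next 624, and then predicts the future ones. Setting the index to 624 makes the clone twist immediately, so its next output is the victim's next output.

```python
import random


def untemper(y):
    y ^= y >> 18                     # top 18 bits were untouched: one pass
    y ^= (y << 15) & 0xefc60000      # mask has no bits below 17: one pass
    x = y
    for _ in range(5):               # y ^= (y << 7) & B: 7 bits per pass
        x = y ^ ((x << 7) & 0x9d2c5680)
    y = x
    for _ in range(3):               # y ^= y >> 11: 11 bits per pass
        x = y ^ (x >> 11)
    return x


victim = random.Random(31337)
for _ in range(100):
    victim.getrandbits(32)           # outputs we never see

observed = [victim.getrandbits(32) for _ in range(624)]

clone = random.Random()
clone.setstate((3, tuple([untemper(v) for v in observed] + [624]), None))

predicted = [clone.getrandbits(32) for _ in range(1000)]
actual = [victim.getrandbits(32) for _ in range(1000)]
print(predicted == actual)                 # True
print(clone.random() == victim.random())   # True: random() draws from the same words
```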

But I'm still hungry. What if we can't see all 32 bits of each output, only some of them? Is there still a way to recover the state? The answer is yes.

The idea is to build a system of equations over bits. If we treat every bit of the Mersenne Twister state as a variable, we get a system of equations in $624 \times 32 = 19968$ variables. Since every step of the generator is just XOR, each output bit is the XOR of some of these variables, so the system is linear over GF(2) and a computer can solve it. rbtree implemented this with Python integers and bitwise XOR using his great brain lol, so I will quote it… XD (Thx rbtree..)

class Twister:
    N = 624
    M = 397
    A = 0x9908b0df
    
    def __init__(self):
        self.state = [ [ (1 << (32 * i + (31 - j))) for j in range(32) ] for i in range(self.N)]
        self.index = 0
    
    @staticmethod
    def _xor(a, b):
        return [x ^ y for x, y in zip(a, b)]
    
    @staticmethod
    def _and(a, x):
        return [ v if (x >> (31 - i)) & 1 else 0 for i, v in enumerate(a) ]
    
    @staticmethod
    def _shiftr(a, x):
        return [0] * x + a[:-x]
    
    @staticmethod
    def _shiftl(a, x):
        return a[x:] + [0] * x
        
    def get32bits(self):
        if self.index >= self.N:
            for kk in range(self.N):
                y = self.state[kk][:1] + self.state[(kk + 1) % self.N][1:]
                z = [ y[-1] if (self.A >> (31 - i)) & 1 else 0 for i in range(32) ]
                self.state[kk] = self._xor(self.state[(kk + self.M) % self.N], self._shiftr(y, 1))
                self.state[kk] = self._xor(self.state[kk], z)
            self.index = 0
        y = self.state[self.index]
        y = self._xor(y, self._shiftr(y, 11))
        y = self._xor(y, self._and(self._shiftl(y, 7), 0x9d2c5680))
        y = self._xor(y, self._and(self._shiftl(y, 15), 0xefc60000))
        y = self._xor(y, self._shiftr(y, 18))
        self.index += 1
        return y
    
    def getrandbits(self, bit):
        return self.get32bits()[:bit]

At a glance, this is the same algorithm as the Python random module. The difference is that every bit is a variable: rbtree represents bit $j$ (counting from the MSB) of word $i$ as the integer $1 \ll (32i + 31 - j)$, so each bit moves through the algorithm independently. Each output bit then ends up as an integer whose set bits mark exactly which state bits are XORed together to produce it. So we can now see which bits of the Mersenne Twister state each output bit of the random module depends on.
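To see the "bits as variables" trick in isolation, here is a cut-down sketch that applies it to a single word and only to tempering. The helper names are ours, modeled on rbtree's _xor, _and, _shiftr and _shiftl. Each output bit becomes an integer mask of the input bits it XORs together. Evaluating a mask against a concrete word is just the parity of mask & word, and the result matches the real temper.

```python
import random


def temper(y):
    y ^= y >> 11
    y ^= (y << 7) & 0x9d2c5680
    y ^= (y << 15) & 0xefc60000
    y ^= y >> 18
    return y


# symbolic word: entry j (MSB first) is a mask naming the variable "bit 31 - j"
def sym_word():
    return [1 << (31 - j) for j in range(32)]


def sym_xor(a, b):
    return [x ^ y for x, y in zip(a, b)]


def sym_shiftr(a, n):
    return [0] * n + a[:-n]


def sym_shiftl(a, n):
    return a[n:] + [0] * n


def sym_and(a, mask):
    return [v if (mask >> (31 - i)) & 1 else 0 for i, v in enumerate(a)]


def sym_temper(y):
    y = sym_xor(y, sym_shiftr(y, 11))
    y = sym_xor(y, sym_and(sym_shiftl(y, 7), 0x9d2c5680))
    y = sym_xor(y, sym_and(sym_shiftl(y, 15), 0xefc60000))
    y = sym_xor(y, sym_shiftr(y, 18))
    return y


def evaluate(masks, word):
    # each output bit is the XOR (parity) of the input bits its mask selects
    bits = [bin(m & word).count("1") & 1 for m in masks]
    return int("".join(map(str, bits)), 2)


masks = sym_temper(sym_word())
rng = random.Random(7)
words = [rng.getrandbits(32) for _ in range(1000)]
print(all(evaluate(masks, w) == temper(w) for w in words))  # True
```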

Now, with these equations and the real outputs, we can solve the system by Gaussian elimination over GF(2), using the lowest set bit of each equation (eq & -eq) as its pivot. I couldn't have thought of this with my shallow knowledge, but I think it's a really clever approach. Take a look.

class Solver:
    def __init__(self):
        self.equations = []
        self.outputs = []
    
    def insert(self, equation, output):
        for eq, o in zip(self.equations, self.outputs):
            lsb = eq & -eq
            if equation & lsb:
                equation ^= eq
                output ^= o
        
        if equation == 0:
            return
        lsb = equation & -equation
        for i in range(len(self.equations)):
            if self.equations[i] & lsb:
                self.equations[i] ^= equation
                self.outputs[i] ^= output
    
        self.equations.append(equation)
        self.outputs.append(output)
    
    def is_solvable(self):
        # full rank: every one of the 624 * 32 state bits is pinned down
        return len(self.equations) == 624 * 32
    
    def solve(self):
        if not self.is_solvable():
            raise ValueError("Not solvable: not enough independent equations")
        
        num = 0
        for i, eq in enumerate(self.equations):
            assert eq == (eq & -eq), "Should be reduced now"
            if self.outputs[i]:
                num |= eq
        
        # bit 32*i + b of num is bit b of state word i, so no reversal is needed
        state = [ (num >> (32 * i)) & 0xFFFFFFFF for i in range(624) ]
        return state

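So we can recover the whole state!! (If we have enough equations.) Note that is_solvable above requires all $624 \times 32$ variables to be pinned down. When only some bits of each output are known, that may never happen, so the final version below drops the check and sets any remaining free variables to 0.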
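Here is the same pivot trick on a toy system, to make the elimination step concrete. It is a sketch: insert and solve below are stand-alone functions that mirror rbtree's Solver methods, with free variables set to 0 as in the final version. The 4-bit secret and the equation masks are made up for illustration.

```python
def parity(x):
    return bin(x).count("1") & 1


def insert(eqs, outs, equation, output):
    # eliminate every existing pivot (lowest set bit) from the new equation
    for eq, o in zip(eqs, outs):
        pivot = eq & -eq
        if equation & pivot:
            equation ^= eq
            output ^= o
    if equation == 0:
        return  # linearly dependent on earlier equations
    # clear the new pivot from the stored equations, then store it
    pivot = equation & -equation
    for i in range(len(eqs)):
        if eqs[i] & pivot:
            eqs[i] ^= equation
            outs[i] ^= output
    eqs.append(equation)
    outs.append(output)


def solve(eqs, outs):
    num = 0
    for eq, o in zip(eqs, outs):
        if o:
            num |= eq & -eq  # pivot variable = output, free variables = 0
    return num


secret = 0b1011  # hypothetical 4-bit "state": x3=1, x2=0, x1=1, x0=1
masks = [0b0011, 0b0110, 0b1100, 0b1000, 0b1111]  # which bits each equation XORs

eqs, outs = [], []
for m in masks:
    insert(eqs, outs, m, parity(m & secret))
print(len(eqs), bin(solve(eqs, outs)))  # 4 0b1011

# with too few equations we still get *a* solution consistent with them
eqs2, outs2 = [], []
insert(eqs2, outs2, 0b0011, 1)  # only x1 ^ x0 = 1 is known
guess = solve(eqs2, outs2)
print(bin(guess), parity(0b0011 & guess))  # 0b1 1
```

The fifth mask is redundant, so it reduces to 0 and is dropped, which is why only four equations remain.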

Result

The final code is as follows.

class Twister:
    N = 624
    M = 397
    A = 0x9908b0df
    def __init__(self):
        self.state = [ [ (1 << (32 * i + (31 - j))) for j in range(32) ] for i in range(624)]
        self.index = 0
    
    @staticmethod
    def _xor(a, b):
        return [x ^ y for x, y in zip(a, b)]
    
    @staticmethod
    def _and(a, x):
        return [ v if (x >> (31 - i)) & 1 else 0 for i, v in enumerate(a) ]
    
    @staticmethod
    def _shiftr(a, x):
        return [0] * x + a[:-x]
    
    @staticmethod
    def _shiftl(a, x):
        return a[x:] + [0] * x
        
    def get32bits(self):
        if self.index >= self.N:
            for kk in range(self.N):
                y = self.state[kk][:1] + self.state[(kk + 1) % self.N][1:]
                z = [ y[-1] if (self.A >> (31 - i)) & 1 else 0 for i in range(32) ]
                self.state[kk] = self._xor(self.state[(kk + self.M) % self.N], self._shiftr(y, 1))
                self.state[kk] = self._xor(self.state[kk], z)
            self.index = 0
        y = self.state[self.index]
        y = self._xor(y, self._shiftr(y, 11))
        y = self._xor(y, self._and(self._shiftl(y, 7), 0x9d2c5680))
        y = self._xor(y, self._and(self._shiftl(y, 15), 0xefc60000))
        y = self._xor(y, self._shiftr(y, 18))
        self.index += 1
        return y
    
    def getrandbits(self, bit):
        return self.get32bits()[:bit]

class Solver:
    def __init__(self):
        self.equations = []
        self.outputs = []
    
    def insert(self, equation, output):
        for eq, o in zip(self.equations, self.outputs):
            lsb = eq & -eq
            if equation & lsb:
                equation ^= eq
                output ^= o
        
        if equation == 0:
            return
            
        lsb = equation & -equation
        for i in range(len(self.equations)):
            if self.equations[i] & lsb:
                self.equations[i] ^= equation
                self.outputs[i] ^= output
    
        self.equations.append(equation)
        self.outputs.append(output)
    
    def solve(self):
        num = 0
        for i, eq in enumerate(self.equations):
            if self.outputs[i]:
                # Assume every free variable is 0
                num |= eq & -eq
        
        state = [ (num >> (32 * i)) & 0xFFFFFFFF for i in range(624) ]
        return state
        
import random

num = 1247  # number of observed outputs (2 * 624 - 1)
bit = 30    # only the top 30 bits of each 32-bit output are known

twister = Twister()
outputs = [ random.getrandbits(bit) for _ in range(num) ]     # real outputs
equations = [ twister.getrandbits(bit) for _ in range(num) ]  # symbolic outputs
solver = Solver()

# one linear equation per observed bit, MSB first
for i in range(num):
    for j in range(bit):
        solver.insert(equations[i][j], (outputs[i] >> (bit - 1 - j)) & 1)

state = solver.solve()
recovered_state = (3, tuple(state + [0]), None)
random.setstate(recovered_state)

# the recovered state reproduces every observed output
for i in range(num):
    assert outputs[i] == random.getrandbits(bit)