Stable Diffusion from the ground up

Here we build a working Stable Diffusion model from the ground up. We start with only Python and its standard library, plus Matplotlib for plots, Jupyter Notebook (where we write the code), and nbdev (which turns notebooks into modules).

from pathlib import Path
import pickle, gzip, os, math, time, shutil , matplotlib as mpl, matplotlib.pyplot as plt
Get data
MNIST_URL = 'https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'
from urllib.request import urlretrieve
urlretrieve
<function urllib.request.urlretrieve(url, filename=None, reporthook=None, data=None)>
if not path_gz.exists(): urlretrieve(MNIST_URL,path_gz)
!ls -l data
total 16656
-rw-r--r-- 1 rubanza rubanza 17051982 Jul 21 16:23 mnist.pkl.gz

Reading an `ls -l` line such as `-rw-r--r-- 1 user staff 15296311 Jul 21 10:30 mnist.pkl.gz`, the fields from left to right are:

- file permissions (`-rw-r--r--`)
- number of hard links (`1`)
- user owner (`user`)
- group owner (`staff`)
- file size in bytes (`15296311`)
- modification time (`Jul 21 10:30`)
- filename (`mnist.pkl.gz`)

path_gz
PosixPath('data/mnist.pkl.gz')
# Load the raw MNIST pickle. It holds three (images, labels) pairs; the first
# two are unpacked below as the training and validation sets.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with gzip.open(path_gz, 'rb') as f:
    dataset = pickle.load(f, encoding='latin-1')
dataset
((array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  array([5, 0, 4, ..., 8, 4, 8])),
 (array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  array([3, 8, 6, ..., 5, 6, 8])),
 (array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  array([7, 2, 1, ..., 4, 5, 6])))
with gzip.open(path_gz, 'rb') as f: ((x_train,y_train),(x_valid,y_valid),_) = pickle.load(f, encoding='latin-1')
x_train.shape,y_train.shape,x_valid.shape,y_valid.shape
((50000, 784), (50000,), (10000, 784), (10000,))
x_train
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
lst1 = list(x_train[0])
vals = lst1[200:210]
vals
[0.0,
 0.0,
 0.0,
 0.19140625,
 0.9296875,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125]
len(lst1)
784
len(lst1)
784
def chunks(x, sz):
    """Yield consecutive slices of ``x`` of length ``sz``; the last may be shorter.

    Args:
        x: any sliceable sequence (list, tuple, str, ...).
        sz: chunk length; must be a positive integer.

    Yields:
        ``x[i:i+sz]`` for ``i`` in ``0, sz, 2*sz, ...`` up to ``len(x)``.

    Raises:
        ValueError: if ``sz`` is less than 1 (raised on first iteration).
    """
    if sz < 1:
        raise ValueError(f"chunk size must be >= 1, got {sz}")
    for i in range(0, len(x), sz):
        # Slice by ``sz`` -- a hard-coded width would break any size other than 5.
        yield x[i:i + sz]
list(chunks(vals,5))
[[0.0, 0.0, 0.0, 0.19140625, 0.9296875],
 [0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125]]
val_iter = chunks(vals,5)
val_iter
<generator object chunks at 0x715981808580>
next(val_iter)
[0.0, 0.0, 0.0, 0.19140625, 0.9296875]
next(val_iter)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[78], line 1
----> 1 next(val_iter)

StopIteration: 
mpl.rcParams['image.cmap']='gray'
plt.imshow(list(chunks(lst1,28)));

plt.rcParams['image.cmap'] = 'gray'
plt.imshow(list(chunks(lst1, 28)))
plt.xlim(0, 28)  # x-axis spans columns 0-28
# imshow draws row 0 at the top; giving ylim as (bottom=28, top=0) keeps that
# orientation. plt.ylim(0, 28) would invert the axis and flip the digit upside down.
plt.ylim(28, 0)

a = [1,2,3,4,5]
len(a)
5
for i in range(0,5): 
    print(i+1)
1
2
3
4
5
from itertools import islice
it = iter(vals)
vals
[0.0,
 0.0,
 0.0,
 0.19140625,
 0.9296875,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125]
it
<list_iterator at 0x715981848730>
next(it)
0.0
next(it),next(it),next(it)
(0.0, 0.0, 0.19140625)

islice

it = iter(vals)
islice(it,28)
<itertools.islice at 0x76d844786660>
isit = islice(it,5)
isit
<itertools.islice at 0x7159819054e0>
next(isit),next(isit),next(isit)
(0.0, 0.0, 0.0)
next(isit)
0.19140625
next(isit)
0.9296875
next(isit)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[140], line 1
----> 1 next(isit)

StopIteration: 
next(isit)
0.19140625
list(islice(it,5))
[0.98828125, 0.98828125, 0.98828125]
list(islice(it,28))
[0.0,
 0.0,
 0.0,
 0.19140625,
 0.9296875,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125,
 0.98828125]
list(islice(it,28))
[]
next(it)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[122], line 1
----> 1 next(it)

StopIteration: 
islice(it,2)
<itertools.islice at 0x715981af47c0>
list(islice(it,5))
[]
list(islice(it,5))
[]

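Using a lambda: `iter(callable, sentinel)` keeps calling the lambda until it returns the sentinel `[]`, so each call produces the next 28-pixel row of the image.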

it = iter(lst1)
img = list(iter(lambda: list(islice(it, 28)), []))
img
[[0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.01171875,
  0.0703125,
  0.0703125,
  0.0703125,
  0.4921875,
  0.53125,
  0.68359375,
  0.1015625,
  0.6484375,
  0.99609375,
  0.96484375,
  0.49609375,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.1171875,
  0.140625,
  0.3671875,
  0.6015625,
  0.6640625,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.87890625,
  0.671875,
  0.98828125,
  0.9453125,
  0.76171875,
  0.25,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.19140625,
  0.9296875,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98046875,
  0.36328125,
  0.3203125,
  0.3203125,
  0.21875,
  0.15234375,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0703125,
  0.85546875,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.7734375,
  0.7109375,
  0.96484375,
  0.94140625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.3125,
  0.609375,
  0.41796875,
  0.98828125,
  0.98828125,
  0.80078125,
  0.04296875,
  0.0,
  0.16796875,
  0.6015625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0546875,
  0.00390625,
  0.6015625,
  0.98828125,
  0.3515625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.54296875,
  0.98828125,
  0.7421875,
  0.0078125,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.04296875,
  0.7421875,
  0.98828125,
  0.2734375,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.13671875,
  0.94140625,
  0.87890625,
  0.625,
  0.421875,
  0.00390625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.31640625,
  0.9375,
  0.98828125,
  0.98828125,
  0.46484375,
  0.09765625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.17578125,
  0.7265625,
  0.98828125,
  0.98828125,
  0.5859375,
  0.10546875,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0625,
  0.36328125,
  0.984375,
  0.98828125,
  0.73046875,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.97265625,
  0.98828125,
  0.97265625,
  0.25,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.1796875,
  0.5078125,
  0.71484375,
  0.98828125,
  0.98828125,
  0.80859375,
  0.0078125,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.15234375,
  0.578125,
  0.89453125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.9765625,
  0.7109375,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.09375,
  0.4453125,
  0.86328125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.78515625,
  0.3046875,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.08984375,
  0.2578125,
  0.83203125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.7734375,
  0.31640625,
  0.0078125,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0703125,
  0.66796875,
  0.85546875,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.76171875,
  0.3125,
  0.03515625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.21484375,
  0.671875,
  0.8828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.953125,
  0.51953125,
  0.04296875,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.53125,
  0.98828125,
  0.98828125,
  0.98828125,
  0.828125,
  0.52734375,
  0.515625,
  0.0625,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0]]
plt.imshow(img);

len(img)
28
it = iter(lst1)
def f():
    """Return the next 28 items of the global iterator ``it`` as a list.

    Each call resumes where the previous one stopped; once ``it`` is
    exhausted the result is ``[]``.
    """
    next_row = islice(it, 28)
    return list(next_row)
f()
[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.01171875,
 0.0703125,
 0.0703125,
 0.0703125,
 0.4921875,
 0.53125,
 0.68359375,
 0.1015625,
 0.6484375,
 0.99609375,
 0.96484375,
 0.49609375,
 0.0,
 0.0,
 0.0,
 0.0]
#img[28]

Matrix and tensors

import torch
from torch import tensor
class Matrix:
    """Minimal 2-D wrapper over a nested list that supports ``m[i, j]`` indexing."""

    def __init__(self, xs):
        # Stored by reference; no copy is made.
        self.xs = xs

    def __getitem__(self, idxs):
        # ``m[i, j]`` arrives as the tuple (i, j): pick row i, then column j.
        row, col = idxs[0], idxs[1]
        return self.xs[row][col]
img[20][15]
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[16], line 1
----> 1 img[20][15]

NameError: name 'img' is not defined
#img[27,25]
m = Matrix(img)
m[20,15]
b = [1,2,3]
b
b = tensor(b)
b
len(img)
img_tens = tensor(img)
img_tens.shape
img_tens[20,15]
tensor(0.9883)
s = [1,2,3]
s
[1, 2, 3]
def s_to_int(a):
    """Return ``a + 1``.

    Despite the name, no conversion to ``int`` happens; the argument's own
    ``+`` is used, so floats stay floats.
    """
    return a + 1
sa = list(map(s_to_int,s))
sa
[2, 3, 4]
(x_train,y_train,x_valid,y_valid) = map(tensor, (x_train,y_train,x_valid,y_valid))
x_train.shape,y_train.shape,x_valid.shape,y_valid.shape
(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))
x_train.type()
'torch.FloatTensor'

Tensor

imgs = x_train.reshape((-1,28,28))
imgs.shape
torch.Size([50000, 28, 28])
imgs_a = x_train[0].reshape((-1,28,28))
imgs_a.shape
torch.Size([1, 28, 28])
tyo = x_train[1].shape
#plt.imshow(tyo)
plt.imshow(imgs[0])

imgs[0]
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0117, 0.0703, 0.0703, 0.0703, 0.4922, 0.5312,
         0.6836, 0.1016, 0.6484, 0.9961, 0.9648, 0.4961, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1172,
         0.1406, 0.3672, 0.6016, 0.6641, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883,
         0.8789, 0.6719, 0.9883, 0.9453, 0.7617, 0.2500, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1914, 0.9297,
         0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9805,
         0.3633, 0.3203, 0.3203, 0.2188, 0.1523, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.8555,
         0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.7109, 0.9648, 0.9414,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3125,
         0.6094, 0.4180, 0.9883, 0.9883, 0.8008, 0.0430, 0.0000, 0.1680, 0.6016,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0547, 0.0039, 0.6016, 0.9883, 0.3516, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.5430, 0.9883, 0.7422, 0.0078, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0430, 0.7422, 0.9883, 0.2734, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1367, 0.9414, 0.8789, 0.6250, 0.4219, 0.0039,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.3164, 0.9375, 0.9883, 0.9883, 0.4648,
         0.0977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1758, 0.7266, 0.9883, 0.9883,
         0.5859, 0.1055, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0625, 0.3633, 0.9844,
         0.9883, 0.7305, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9727,
         0.9883, 0.9727, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1797, 0.5078, 0.7148, 0.9883,
         0.9883, 0.8086, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1523, 0.5781, 0.8945, 0.9883, 0.9883, 0.9883,
         0.9766, 0.7109, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0938, 0.4453, 0.8633, 0.9883, 0.9883, 0.9883, 0.9883, 0.7852,
         0.3047, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0898,
         0.2578, 0.8320, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.3164, 0.0078,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.6680, 0.8555,
         0.9883, 0.9883, 0.9883, 0.9883, 0.7617, 0.3125, 0.0352, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2148, 0.6719, 0.8828, 0.9883, 0.9883,
         0.9883, 0.9883, 0.9531, 0.5195, 0.0430, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.5312, 0.9883, 0.9883, 0.9883, 0.8281,
         0.5273, 0.5156, 0.0625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]])
imgs[0,20,15]
tensor(0.9883)
n, c = x_train.shape
n,c
(50000, 784)
y_train, y_train.shape
(tensor([5, 0, 4,  ..., 8, 4, 8]), torch.Size([50000]))
y_train[0].shape
torch.Size([])
min(y_train), max(y_train)
(tensor(0), tensor(9))
y_train.min(), y_train.max()
(tensor(0), tensor(9))
Random numbers

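The generator below is based on the Wichmann–Hill algorithm, which combines three multiplicative congruential generators and keeps the fractional part of their normalised sum.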

rnd_state = None
def seed(a):
    """Initialise the global Wichmann-Hill state ``rnd_state`` from the number ``a``.

    ``a`` is decomposed into three mixed-radix digits (bases 30268, 30306,
    30322). Each digit gets +1 so every component is non-zero; a zero
    component would stay zero under the multiplicative update.
    """
    global rnd_state
    parts = []
    for base in (30268, 30306, 30322):
        a, digit = divmod(a, base)
        parts.append(int(digit) + 1)
    rnd_state = tuple(parts)
seed(42)
rnd_state
(43, 1, 1)
seed(457428938475)
rnd_state
(4976, 20238, 499)
def rand():
    """Return a pseudo-random float in [0, 1) using the Wichmann-Hill generator.

    Three multiplicative congruential generators (multipliers 171/172/170,
    moduli 30269/30307/30323) each advance one step. Their normalised values
    are summed and the fractional part is returned. The module-level
    ``rnd_state`` must already hold a 3-tuple (e.g. set by ``seed``); it is
    rebound to the new state.
    """
    global rnd_state
    x, y, z = rnd_state
    rnd_state = nxt = ((171 * x) % 30269, (172 * y) % 30307, (170 * z) % 30323)
    a, b, c = nxt
    total = a / 30269 + b / 30307 + c / 30323
    return total % 1.0
rand(),rand()
(0.25420336316883324, 0.46884405296716114)
rand()
0.19540525690312815
rand(),rand(),rand(),rand()
(0.28886109883281286,
 0.8643955691976015,
 0.062341103558347655,
 0.5214729908496198)
if os.fork(): print(f'In parent: {rand()}')
else:
    print(f' In child: {rand()}')
    os._exit(os.EX_OK)
In parent: 0.44425801591077185
 In child: 0.44425801591077185
# Fork the process: parent and child each draw once from PyTorch's global RNG.
# The child inherits a copy of the parent's RNG state, so the two values are
# expected to match unless one side reseeds.
if os.fork(): print(f'In parent: {torch.rand(1)}')
else:
    try:
        # torch.rand requires a size argument; a bare torch.rand() raises.
        print(f' In child: {torch.rand(1)}')
    finally:
        # Always terminate the child, even on error, so no duplicate
        # process keeps running after the cell finishes.
        os._exit(os.EX_OK)
In parent: tensor([0.9308])
import numpy as np
if os.fork(): print(f'In parent: {np.random.rand(1)}')
else:
    print(f' In child: {np.random.rand()}')
    os._exit(os.EX_OK)
 In child: 0.6953703052450606
In parent: [0.69537031]
from random import random
if os.fork(): print(f'In parent: {random()}')
else:
    print(f' In child: {random()}')
    os._exit(os.EX_OK)
In parent: 0.006104636974047728
 In child: 0.22305777947091454
plt.plot([rand() for _ in range(50)]);

plt.hist([rand() for _ in range(1000)]);