from pathlib import Path
import pickle, gzip, os, math, time, shutil , matplotlib as mpl, matplotlib.pyplot as plt
Stable Diffusion from the ground up
Here we build a working Stable Diffusion model from the ground up, using only Python and its standard library, plus Matplotlib for plots, Jupyter Notebook (where we write our code), and nbdev (which turns notebooks into Python modules).
Get data
# Where to download the pickled MNIST archive from, and where to cache it locally.
MNIST_URL = 'https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)  # no error if the directory already exists
path_gz = path_data/'mnist.pkl.gz'
from urllib.request import urlretrieve
urlretrieve
<function urllib.request.urlretrieve(url, filename=None, reporthook=None, data=None)>
if not path_gz.exists(): urlretrieve(MNIST_URL,path_gz)
!ls -l data
total 16656
-rw-r--r-- 1 rubanza rubanza 17051982 Jul 21 16:23 mnist.pkl.gz
Reading an `ls -l` line such as `-rw-r--r-- 1 user staff 15296311 Jul 21 10:30 mnist.pkl.gz` from left to right, the columns are: file permissions, number of hard links, user owner, group owner, file size (bytes), modification time, and filename.
path_gz
PosixPath('data/mnist.pkl.gz')
#with gzip.open('data/mnist.pkl.gz','rb') as f:
#dataset = pickle.load(f, encoding='latin-1')
dataset
((array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
array([5, 0, 4, ..., 8, 4, 8])),
(array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
array([3, 8, 6, ..., 5, 6, 8])),
(array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
array([7, 2, 1, ..., 4, 5, 6])))
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only unpickle archives from a source you trust.
# The archive unpacks into three (inputs, labels) pairs; only the first two
# (train, valid) are kept and the third is discarded via `_`.
# encoding='latin-1' presumably lets Python 3 read a pickle written by Python 2 — confirm.
with gzip.open(path_gz, 'rb') as f: ((x_train,y_train),(x_valid,y_valid),_) = pickle.load(f, encoding='latin-1')
x_train.shape,y_train.shape,x_valid.shape,y_valid.shape
((50000, 784), (50000,), (10000, 784), (10000,))
x_train
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
= list(x_train[0])
lst1 = lst1[200:210]
vals vals
[0.0,
0.0,
0.0,
0.19140625,
0.9296875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125]
len(lst1)
784
len(lst1)
784
def chunks(x, sz):
    """Yield consecutive slices of ``x`` of length ``sz``.

    The final slice is shorter when ``len(x)`` is not a multiple of ``sz``.
    Works with any sliceable sequence (list, tuple, str, ...).
    """
    for i in range(0, len(x), sz):
        # Slice by ``sz`` (not a fixed width) so rows match the requested size.
        yield x[i:i + sz]
list(chunks(vals,5))
[[0.0, 0.0, 0.0, 0.19140625, 0.9296875],
[0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125]]
= chunks(vals,5) val_iter
val_iter
<generator object chunks at 0x715981808580>
next(val_iter)
[0.0, 0.0, 0.0, 0.19140625, 0.9296875]
next(val_iter)
--------------------------------------------------------------------------- StopIteration Traceback (most recent call last) Cell In[78], line 1 ----> 1 next(val_iter) StopIteration:
# mpl.rcParams and plt.rcParams are the same object, so one assignment suffices.
mpl.rcParams['image.cmap'] = 'gray'
plt.imshow(list(chunks(lst1, 28)));
# imshow draws row 0 at the top, i.e. its y-axis runs downward. Giving the
# limits as (bottom, top) = (28, 0) keeps that orientation; plt.ylim(0, 28)
# would flip the digit upside-down.
plt.xlim(0, 28)  # Force x-axis to go 0-28
plt.ylim(28, 0)  # Force y-axis to go 0-28, top to bottom
= [1,2,3,4,5]
a len(a)
5
for i in range(0,5):
print(i+1)
1
2
3
4
5
from itertools import islice
= iter(vals) it
vals
[0.0,
0.0,
0.0,
0.19140625,
0.9296875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125]
it
<list_iterator at 0x715981848730>
next(it)
0.0
next(it),next(it),next(it)
(0.0, 0.0, 0.19140625)
= iter(vals)
it 28) islice(it,
<itertools.islice at 0x76d844786660>
= islice(it,5)
isit isit
<itertools.islice at 0x7159819054e0>
next(isit),next(isit),next(isit)
(0.0, 0.0, 0.0)
next(isit)
0.19140625
next(isit)
0.9296875
next(isit)
--------------------------------------------------------------------------- StopIteration Traceback (most recent call last) Cell In[140], line 1 ----> 1 next(isit) StopIteration:
next(isit)
0.19140625
list(islice(it,5))
[0.98828125, 0.98828125, 0.98828125]
list(islice(it,28))
[0.0,
0.0,
0.0,
0.19140625,
0.9296875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125]
list(islice(it,28))
[]
next(it)
--------------------------------------------------------------------------- StopIteration Traceback (most recent call last) Cell In[122], line 1 ----> 1 next(it) StopIteration:
2) islice(it,
<itertools.islice at 0x715981af47c0>
list(islice(it,5))
[]
list(islice(it,5))
[]
Using `lambda`
# Split the flat pixel list into 28-value rows. iter(callable, sentinel) keeps
# calling the lambda until it returns the sentinel [] (iterator exhausted).
it = iter(lst1)
img = list(iter(lambda: list(islice(it, 28)), []))
img
[[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.01171875,
0.0703125,
0.0703125,
0.0703125,
0.4921875,
0.53125,
0.68359375,
0.1015625,
0.6484375,
0.99609375,
0.96484375,
0.49609375,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.1171875,
0.140625,
0.3671875,
0.6015625,
0.6640625,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.87890625,
0.671875,
0.98828125,
0.9453125,
0.76171875,
0.25,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.19140625,
0.9296875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98046875,
0.36328125,
0.3203125,
0.3203125,
0.21875,
0.15234375,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0703125,
0.85546875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.7734375,
0.7109375,
0.96484375,
0.94140625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.3125,
0.609375,
0.41796875,
0.98828125,
0.98828125,
0.80078125,
0.04296875,
0.0,
0.16796875,
0.6015625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0546875,
0.00390625,
0.6015625,
0.98828125,
0.3515625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.54296875,
0.98828125,
0.7421875,
0.0078125,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.04296875,
0.7421875,
0.98828125,
0.2734375,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.13671875,
0.94140625,
0.87890625,
0.625,
0.421875,
0.00390625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.31640625,
0.9375,
0.98828125,
0.98828125,
0.46484375,
0.09765625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.17578125,
0.7265625,
0.98828125,
0.98828125,
0.5859375,
0.10546875,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0625,
0.36328125,
0.984375,
0.98828125,
0.73046875,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.97265625,
0.98828125,
0.97265625,
0.25,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.1796875,
0.5078125,
0.71484375,
0.98828125,
0.98828125,
0.80859375,
0.0078125,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.15234375,
0.578125,
0.89453125,
0.98828125,
0.98828125,
0.98828125,
0.9765625,
0.7109375,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.09375,
0.4453125,
0.86328125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.78515625,
0.3046875,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.08984375,
0.2578125,
0.83203125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.7734375,
0.31640625,
0.0078125,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0703125,
0.66796875,
0.85546875,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.76171875,
0.3125,
0.03515625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.21484375,
0.671875,
0.8828125,
0.98828125,
0.98828125,
0.98828125,
0.98828125,
0.953125,
0.51953125,
0.04296875,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.53125,
0.98828125,
0.98828125,
0.98828125,
0.828125,
0.52734375,
0.515625,
0.0625,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0]]
; plt.imshow(img)
len(img)
28
it = iter(lst1)


def f():
    """Return the next (up to) 28 values from the shared iterator ``it`` as a list.

    Each call advances ``it``; once it is exhausted, an empty list is returned.
    """
    return list(islice(it, 28))
f()
[0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.01171875,
0.0703125,
0.0703125,
0.0703125,
0.4921875,
0.53125,
0.68359375,
0.1015625,
0.6484375,
0.99609375,
0.96484375,
0.49609375,
0.0,
0.0,
0.0,
0.0]
#img[28]
Matrix and tensors
import torch
from torch import tensor
class Matrix:
    """Minimal wrapper that allows ``m[i, j]``-style indexing into nested lists."""

    def __init__(self, xs):
        self.xs = xs  # nested sequence, e.g. a list of image rows

    def __getitem__(self, idxs):
        """Index one nesting level per index.

        ``m[i, j]`` (or ``m[[i, j]]``) returns ``xs[i][j]``; a single index such
        as ``m[i]`` or ``m[1:3]`` returns ``xs[i]``. Any number of indices is
        supported, rather than silently ignoring indices past the second.
        """
        if not isinstance(idxs, (tuple, list)):
            idxs = (idxs,)
        res = self.xs
        for i in idxs:
            res = res[i]
        return res
20][15] img[
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[16], line 1 ----> 1 img[20][15] NameError: name 'img' is not defined
#img[27,25]
= Matrix(img)
m 20,15] m[
= [1,2,3]
b b
= tensor(b)
b b
len(img)
= tensor(img)
img_tens img_tens.shape
20,15] img_tens[
tensor(0.9883)
= [1,2,3]
s s
[1, 2, 3]
def s_to_int(a):
    """Return ``a`` incremented by one (handy with ``map`` over a list)."""
    return a + 1
= list(map(s_to_int,s))
sa sa
[2, 3, 4]
= map(tensor, (x_train,y_train,x_valid,y_valid)) (x_train,y_train,x_valid,y_valid)
x_train.shape,y_train.shape,x_valid.shape,y_valid.shape
(torch.Size([50000, 784]),
torch.Size([50000]),
torch.Size([10000, 784]),
torch.Size([10000]))
type() x_train.
'torch.FloatTensor'
Tensor
= x_train.reshape((-1,28,28))
imgs imgs.shape
torch.Size([50000, 28, 28])
= x_train[0].reshape((-1,28,28))
imgs_a imgs_a.shape
torch.Size([1, 28, 28])
= x_train[1].shape tyo
#plt.imshow(tyo)
0]) plt.imshow(imgs[
0] imgs[
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0117, 0.0703, 0.0703, 0.0703, 0.4922, 0.5312,
0.6836, 0.1016, 0.6484, 0.9961, 0.9648, 0.4961, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1172,
0.1406, 0.3672, 0.6016, 0.6641, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883,
0.8789, 0.6719, 0.9883, 0.9453, 0.7617, 0.2500, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1914, 0.9297,
0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9805,
0.3633, 0.3203, 0.3203, 0.2188, 0.1523, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.8555,
0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.7109, 0.9648, 0.9414,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3125,
0.6094, 0.4180, 0.9883, 0.9883, 0.8008, 0.0430, 0.0000, 0.1680, 0.6016,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0547, 0.0039, 0.6016, 0.9883, 0.3516, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.5430, 0.9883, 0.7422, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0430, 0.7422, 0.9883, 0.2734, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1367, 0.9414, 0.8789, 0.6250, 0.4219, 0.0039,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.3164, 0.9375, 0.9883, 0.9883, 0.4648,
0.0977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1758, 0.7266, 0.9883, 0.9883,
0.5859, 0.1055, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0625, 0.3633, 0.9844,
0.9883, 0.7305, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9727,
0.9883, 0.9727, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1797, 0.5078, 0.7148, 0.9883,
0.9883, 0.8086, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1523, 0.5781, 0.8945, 0.9883, 0.9883, 0.9883,
0.9766, 0.7109, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0938, 0.4453, 0.8633, 0.9883, 0.9883, 0.9883, 0.9883, 0.7852,
0.3047, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0898,
0.2578, 0.8320, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.3164, 0.0078,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.6680, 0.8555,
0.9883, 0.9883, 0.9883, 0.9883, 0.7617, 0.3125, 0.0352, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.2148, 0.6719, 0.8828, 0.9883, 0.9883,
0.9883, 0.9883, 0.9531, 0.5195, 0.0430, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.5312, 0.9883, 0.9883, 0.9883, 0.8281,
0.5273, 0.5156, 0.0625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000]])
0,20,15] imgs[
tensor(0.9883)
= x_train.shape
n, c n,c
(50000, 784)
y_train, y_train.shape
(tensor([5, 0, 4, ..., 8, 4, 8]), torch.Size([50000]))
0].shape y_train[
torch.Size([])
min(y_train), max(y_train)
(tensor(0), tensor(9))
min(), y_train.max() y_train.
(tensor(0), tensor(9))
Random numbers
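Based on the Wichmann-Hill algorithm, which combines three small multiplicative congruential generators into one pseudo-random number generator.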
rnd_state = None


def seed(a):
    """Initialise the global Wichmann-Hill state ``rnd_state`` from the integer seed ``a``.

    The seed is peeled apart by repeated ``divmod`` with the moduli 30268,
    30306 and 30322. Each remainder is shifted up by one, so no component of
    the resulting 3-tuple is ever zero; a zero component would stay zero
    under a multiplicative update.
    """
    global rnd_state
    parts = []
    for modulus in (30268, 30306, 30322):
        a, rem = divmod(a, modulus)
        parts.append(int(rem) + 1)
    rnd_state = tuple(parts)
42)
seed( rnd_state
(43, 1, 1)
457428938475)
seed( rnd_state
(4976, 20238, 499)
def rand():
    """Advance the global Wichmann-Hill state and return a float in [0, 1).

    Each of the three components of ``rnd_state`` (set by ``seed``) is
    stepped by its own multiplicative congruential generator. The result is
    the fractional part of the sum of the three normalised components.
    """
    global rnd_state
    s1, s2, s3 = rnd_state
    s1 = 171 * s1 % 30269
    s2 = 172 * s2 % 30307
    s3 = 170 * s3 % 30323
    rnd_state = s1, s2, s3
    return (s1 / 30269 + s2 / 30307 + s3 / 30323) % 1.0
rand(),rand()
(0.25420336316883324, 0.46884405296716114)
rand()
0.19540525690312815
rand(),rand(),rand(),rand()
(0.28886109883281286,
0.8643955691976015,
0.062341103558347655,
0.5214729908496198)
pid = os.fork()
if pid:
    print(f'In parent: {rand()}')
    os.waitpid(pid, 0)  # reap the child so it does not linger as a zombie
else:
    # Child process: it must always leave through os._exit. If the body raised,
    # the child would otherwise keep running the parent's remaining code.
    try:
        # flush=True because os._exit does not flush stdio buffers.
        print(f' In child: {rand()}', flush=True)
    except BaseException:
        import traceback
        traceback.print_exc()
        os._exit(1)
    os._exit(os.EX_OK)
In parent: 0.44425801591077185
In child: 0.44425801591077185
pid = os.fork()
if pid:
    print(f'In parent: {torch.rand(1)}')
    os.waitpid(pid, 0)  # reap the child so it does not linger as a zombie
else:
    # Child process: it must always leave through os._exit. If the body raised,
    # the child would otherwise keep running the parent's remaining code.
    try:
        # torch.rand is given an explicit size, matching the parent's call.
        # flush=True because os._exit does not flush stdio buffers.
        print(f' In child: {torch.rand(1)}', flush=True)
    except BaseException:
        import traceback
        traceback.print_exc()
        os._exit(1)
    os._exit(os.EX_OK)
In parent: tensor([0.9308])
import numpy as np
pid = os.fork()
if pid:
    print(f'In parent: {np.random.rand(1)}')
    os.waitpid(pid, 0)  # reap the child so it does not linger as a zombie
else:
    # Child process: it must always leave through os._exit. If the body raised,
    # the child would otherwise keep running the parent's remaining code.
    try:
        # flush=True because os._exit does not flush stdio buffers.
        print(f' In child: {np.random.rand()}', flush=True)
    except BaseException:
        import traceback
        traceback.print_exc()
        os._exit(1)
    os._exit(os.EX_OK)
In child: 0.6953703052450606
In parent: [0.69537031]
from random import random
pid = os.fork()
if pid:
    print(f'In parent: {random()}')
    os.waitpid(pid, 0)  # reap the child so it does not linger as a zombie
else:
    # Child process: it must always leave through os._exit. If the body raised,
    # the child would otherwise keep running the parent's remaining code.
    try:
        # flush=True because os._exit does not flush stdio buffers.
        print(f' In child: {random()}', flush=True)
    except BaseException:
        import traceback
        traceback.print_exc()
        os._exit(1)
    os._exit(os.EX_OK)
In parent: 0.006104636974047728
In child: 0.22305777947091454
# Line plot of a short run of rand() values, then a histogram of many values.
plt.plot([rand() for _ in range(50)]);
plt.figure()  # new figure so the histogram is not drawn over the line plot
plt.hist([rand() for _ in range(1000)]);