
Commit

⚡ Format+bugfix
joey00072 committed Jan 24, 2024
1 parent 2a8dcdc commit d192738
Showing 4 changed files with 69 additions and 84 deletions.
41 changes: 20 additions & 21 deletions gpt.py
@@ -21,7 +21,7 @@
learning_rate = 3e-4
device = "cuda" # "cuda" if torch.cuda.is_available() else "cpu"
eval_iters = 100
n_embd = 128*2
n_embd = 128 * 2
n_head = 4
n_layer = 2
dropout = 0.2
@@ -117,7 +117,7 @@ def forward(self, x):
def estimate_loss():
out = {}
model.eval()

for split in ["train", "val"]:
losses = []
for k in range(eval_iters):
@@ -129,7 +129,7 @@ def estimate_loss():
targets = targets.view(B * T)
loss = F.cross_entropy(logits, targets)
losses.append(loss.item())
out[split] = sum(losses)/len(losses)
out[split] = sum(losses) / len(losses)
model.train()
return out

@@ -165,22 +165,19 @@ def forward(self, x: torch.Tensor):
q = q.reshape(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
v = v.reshape(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)


attn = self.attention(k,q,v,self.mask)
attn = self.attention(k, q, v, self.mask)
v = attn.transpose(1, 2).reshape(B, T, C)
x = self.proj(v)
return x

@staticmethod
def attention(k,q,v,mask):
B,n_head,T,C = k.shape
def attention(k, q, v, mask):
B, n_head, T, C = k.shape
wei = (q @ k.transpose(-1, -2)) * (C**-0.5)
wei = mask[:, :, :T, :T] + wei
wei = F.softmax(wei, dim=-1)
x = wei @ v
return x




class MLP(nn.Module):
@@ -238,9 +235,10 @@ def forward(self, x):

return logits


@torch.no_grad()
def generate(model, idx, max_new_tokens):
idx = torch.zeros((1,block_size)).to(device).long()
idx = torch.zeros((1, block_size)).to(device).long()
for i in range(max_new_tokens):
idx_cond = idx[:, -block_size:]
logits = model(idx_cond)
@@ -249,7 +247,7 @@ def generate(model, idx, max_new_tokens):
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
model.train()
return idx[:,block_size:]
return idx[:, block_size:]


model_args = ModelArgs(
@@ -268,35 +266,36 @@ def generate(model, idx, max_new_tokens):
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


for iter in range(1,max_iters):
for iter in range(1, max_iters):
if iter % eval_interval == 0 or iter == max_iters - 1:
print("="*50)
print("=" * 50)
losses = estimate_loss()
print(
f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
)
context = torch.zeros((1, 1)).to(device).long()
print(tokenizer.decode(generate(model,context, max_new_tokens=500)[0].tolist()))
print(
tokenizer.decode(generate(model, context, max_new_tokens=500)[0].tolist())
)
optimizer.zero_grad()
print("-"*50)
print("-" * 50)

data, targets = get_batch("train")
logits = model(data)
# print(sum(model.token_embedding.weight.reshape(-1).tolist()))


B, T, C = logits.shape
logits = logits.view(B * T, C)
targets = targets.view(B * T)

loss = F.cross_entropy(logits, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()
if iter%50==0:

if iter % 50 == 0:
print(f"{iter=} {loss.item()=}")

context = torch.zeros((1, 1)).to(device).long()
print(tokenizer.decode(generate(model,context, max_new_tokens=500)[0].tolist()))
print(tokenizer.decode(generate(model, context, max_new_tokens=500)[0].tolist()))
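The attention helper in gpt.py above is only reformatted by this commit, but it is the core computation of the model. As a reference, here is a minimal self-contained sketch of the same masked scaled-dot-product attention in plain PyTorch; the sizes and the torch.triu mask construction below are illustrative assumptions, not taken from this repository:

import torch
import torch.nn.functional as F

B, n_head, T, C = 2, 4, 8, 16  # batch, heads, sequence length, per-head dim
k = torch.randn(B, n_head, T, C)
q = torch.randn(B, n_head, T, C)
v = torch.randn(B, n_head, T, C)

# causal mask: 0 on and below the diagonal, -inf strictly above it
mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)

wei = (q @ k.transpose(-1, -2)) * (C**-0.5)  # (B, n_head, T, T) attention scores
wei = mask[:, :, :T, :T] + wei               # future positions pushed to -inf
wei = F.softmax(wei, dim=-1)                 # each row sums to 1 over visible positions
out = wei @ v                                # (B, n_head, T, C)
print(out.shape)                             # torch.Size([2, 4, 8, 16])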
24 changes: 13 additions & 11 deletions mnist.py
@@ -41,12 +41,15 @@ def download_mnist():
url = base_url + file
response = requests.get(url, stream=True)
if response.status_code == 200:
with open(file_path, 'wb') as f:
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Downloaded {file}")
else:
print(f"Failed to download {file}. HTTP Response Code: {response.status_code}")
print(
f"Failed to download {file}. HTTP Response Code: {response.status_code}"
)


def load_mnist() -> tuple:
def read_labels(filename: str) -> np.array:
@@ -89,17 +92,18 @@ def forward(self, x: tt.Tensor) -> tt.Tensor:
x = tt.tanh(self.l1(x))
return self.l2(x)


@tt.no_grad()
def test(model: Network, test_images: tt.Tensor, test_labels: tt.Tensor):
preds = model.forward(test_images)
pred_indices = tt.argmax(preds, axis=-1).numpy()
test_labels = test_labels.numpy()
correct = 0
for p,t in zip(pred_indices.reshape(-1),test_labels.reshape(-1)):
if p==t:
correct+=1
accuracy= correct/ len(test_labels)

correct = 0
for p, t in zip(pred_indices.reshape(-1), test_labels.reshape(-1)):
if p == t:
correct += 1
accuracy = correct / len(test_labels)
print(f"Test accuracy: {accuracy:.2%}")


@@ -131,9 +135,7 @@ def train(
download_mnist()
(train_images, train_labels), (test_images, test_labels) = load_mnist()

train_labels, test_labels = map(
tt.tensor, [train_labels, test_labels]
)
train_labels, test_labels = map(tt.tensor, [train_labels, test_labels])

train_images = tt.tensor(train_images.reshape(-1, 28 * 28) / 255).float()
test_images = tt.tensor(test_images.reshape(-1, 28 * 28) / 255).float()
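The accuracy check in mnist.py above counts matches element by element in a loop. For reference, a vectorized NumPy expression computes the same quantity; this is a sketch with made-up predictions, not code from the repository:

import numpy as np

pred_indices = np.array([3, 1, 4, 1, 5])
test_labels = np.array([3, 1, 4, 0, 5])

# elementwise equality gives a boolean array; its mean is the accuracy
accuracy = (pred_indices.reshape(-1) == test_labels.reshape(-1)).mean()
print(f"Test accuracy: {accuracy:.2%}")  # Test accuracy: 80.00%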
76 changes: 32 additions & 44 deletions test_tinytorch.py
@@ -453,12 +453,12 @@ def test_sigmoid():

def softmax(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
m, _ = x.max(axis=dim, keepdims=True)
e_x = (x - m ).exp()
e_x = (x - m).exp()
return e_x / e_x.sum(axis=dim, keepdims=True)


def test_softmax():
x = np.random.rand(5,3) # Create a random 5x3 matrix
x = np.random.rand(5, 3) # Create a random 5x3 matrix

# Convert to PyTorch tensor
x_t = torch.tensor(x, requires_grad=True)
@@ -485,10 +485,8 @@ def test_softmax():
), "Gradients do not match between PyTorch and tinytorch."




def test_softmax2():
x = np.random.rand(5,3) # Create a random 5x3 matrix
x = np.random.rand(5, 3) # Create a random 5x3 matrix
x[-1][-1] = -np.inf

# Convert to PyTorch tensor
@@ -516,67 +514,57 @@ def test_softmax2():
), "Gradients do not match between PyTorch and tinytorch."



def test_attention():

def softmax(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
m, _ = x.max(axis=dim, keepdims=True)
e_x = (x - m ).exp()
e_x = (x - m).exp()
return e_x / e_x.sum(axis=dim, keepdims=True)

def attention(k,q,v,mask):
B,n_head,T,C = k.shape
def attention(k, q, v, mask):
B, n_head, T, C = k.shape
wei = (q @ k.transpose(-1, -2)) * (C**-0.5)
wei = mask[:, :, :T, :T] + wei
wei = softmax(wei, dim=-1)
x = wei @ v
return x
B,n_head,T,C = 3,5,7,9

B, n_head, T, C = 3, 5, 7, 9
seq_len = 20
k = np.random.rand(*(B,n_head,T,C))
q = np.random.rand(*(B,n_head,T,C))
v = np.random.rand(*(B,n_head,T,C))



mask = (
np.tril(np.zeros((1, 1, seq_len, seq_len)))
+ np.triu(
-np.inf * np.ones((1, 1, seq_len, seq_len)),
k=1,
)
)

kt = torch.tensor(k,requires_grad=True)
qt = torch.tensor(q,requires_grad=True)
vt = torch.tensor(v,requires_grad=True)
maskt = torch.tensor(mask,requires_grad=True)
outt = attention(kt,qt,vt,maskt)
k = np.random.rand(*(B, n_head, T, C))
q = np.random.rand(*(B, n_head, T, C))
v = np.random.rand(*(B, n_head, T, C))

mask = np.tril(np.zeros((1, 1, seq_len, seq_len))) + np.triu(
-np.inf * np.ones((1, 1, seq_len, seq_len)),
k=1,
)

kt = torch.tensor(k, requires_grad=True)
qt = torch.tensor(q, requires_grad=True)
vt = torch.tensor(v, requires_grad=True)
maskt = torch.tensor(mask, requires_grad=True)
outt = attention(kt, qt, vt, maskt)
outt.sum().backward()
ktt = tinytorch.tensor(k,requires_grad=True)
qtt = tinytorch.tensor(q,requires_grad=True)
vtt = tinytorch.tensor(v,requires_grad=True)
masktt = tinytorch.tensor(mask,requires_grad=True)
outtt = attention(ktt,qtt,vtt,masktt)

ktt = tinytorch.tensor(k, requires_grad=True)
qtt = tinytorch.tensor(q, requires_grad=True)
vtt = tinytorch.tensor(v, requires_grad=True)
masktt = tinytorch.tensor(mask, requires_grad=True)
outtt = attention(ktt, qtt, vtt, masktt)
outtt.sum().backward()




assert np.allclose(
outt.detach().numpy(), outtt.detach().numpy() , atol=1e-5
outt.detach().numpy(), outtt.detach().numpy(), atol=1e-5
), "Gradients do not match between PyTorch and tinytorch."

assert np.allclose(
kt.grad.numpy(), ktt.grad.data, atol=1e-5
), "Gradients do not match between PyTorch and tinytorch."

assert np.allclose(
qt.grad.numpy(), qtt.grad.data, atol=1e-5
), "Gradients do not match between PyTorch and tinytorch."

assert np.allclose(
vt.grad.numpy(), vtt.grad.data, atol=1e-5
), "Gradients do not match between PyTorch and tinytorch."

12 changes: 4 additions & 8 deletions tinytorch.py
@@ -433,7 +433,7 @@ def backward(ctx, grad):
grad_x[slice_args] = grad.data
else:
for s in np.array(slice_args).reshape(-1):
grad_x[s] += grad.data[s]
grad_x[slice_args] = grad_x[slice_args] + grad.data[s]
return Tensor(grad_x), None


@@ -484,6 +484,7 @@ def backward(ctx: Function, grad: Tensor) -> Tensor:

return Tensor(grad_x), None


class Power(Function):
@staticmethod
def forward(x, y):
@@ -895,11 +896,6 @@ def step(self):


if __name__ == "__main__":
x = tensor(2, requires_grad=True)

def f(x):
return (x + 1) / x

z = f(x)
z.backward()
x = Parameter(Tensor([0, 1]))
x[[0, 0]].sum().backward()
print(x.grad)
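The new __main__ block indexes the parameter with a repeated index and backpropagates through the sum, which exercises gradient accumulation for repeated indices: y = x[0] + x[0], so dy/dx[0] = 2, dy/dx[1] = 0, and x.grad should print as [2., 0.]. A PyTorch reference check of the same expectation, as a sketch independent of tinytorch:

import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
x[[0, 0]].sum().backward()  # gather x[0] twice, sum, backprop
print(x.grad)               # tensor([2., 0.])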
