-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Tic Tac Transformer | ||
|
||
A tiny GPT trained to play tic-tac-toe | ||
|
||
## How does it work? | ||
|
||
We teach a language model to speak tic-tac-toe | ||
|
||
The language is simple - there are 11 tokens | ||
|
||
- **0-8**: moves on the board | ||
- **9**: start game | ||
- **10**: pad | ||
|
||
The sequence length is 10, so a game always starts with <9> and can at most fill the board | ||
|
||
Players take turns | ||
|
||
Duplicate moves are illegal | ||
|
||
**Example** | ||
|
||
seq: [9, 4, 0, 2, 1, 6, 10, 10, 10, 10] | ||
|
||
- player 1 puts an X at position 4 (the middle) | ||
- player 2 puts an O at position 3 (top left) | ||
- player 1 puts an X at position 2 (top right) | ||
- player 2 puts an O at position 1 (top middle) | ||
- player 1 puts an X at position 6 (bottom left) | ||
- padding | ||
|
||
``` | ||
[O] [O] [X] | ||
[ ] [X] [ ] | ||
[X] [ ] [ ] | ||
``` | ||
|
||
player 1 wins | ||
|
||
## Try for yourself | ||
|
||
Generate pre-training data | ||
|
||
```bash | ||
python generate_data.py | ||
``` | ||
|
||
Run pre-training | ||
|
||
```bash | ||
python train.py | ||
``` | ||
|
||
RL fine-tuning | ||
|
||
```bash | ||
python reinforcement_learn.py | ||
``` | ||
|
||
Run benchmark | ||
|
||
```bash | ||
python benchmark.py | ||
``` | ||
|
||
Play the AI! | ||
|
||
```bash | ||
python play_ai.py | ||
``` |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
from tokens import START | ||
from board_ops import check_winner, board_full, get_valid_moves | ||
import numpy as np | ||
from setup import load_from_checkpoint, device | ||
|
||
|
||
model = load_from_checkpoint() | ||
model.eval() | ||
model.to(device) | ||
|
||
with torch.no_grad(): | ||
board = np.zeros((3, 3), dtype=int) | ||
player = 1 | ||
winner = None | ||
moves = [START] | ||
while winner is None and not board_full(board): | ||
if player == 1: | ||
x = torch.tensor(moves, dtype=torch.long, device=device)[None, ...] | ||
y = model.generate(x, max_new_tokens=1, temperature=1.0, top_k=3) | ||
y = y[0][-1].item() | ||
|
||
if y not in set(range(9)) or y in moves: | ||
print(f"AI used invalid move: {y} moves: {moves}") | ||
winner = None | ||
break | ||
|
||
i, j = y // 3, y % 3 | ||
else: | ||
valid = [i * 3 + j for i, j in get_valid_moves(board)] | ||
y = None | ||
while y not in valid: | ||
y = input("Your move! (a number from 0-8): ") | ||
try: | ||
y = int(y) | ||
except: | ||
print("invalid") | ||
y = None | ||
|
||
i, j = y // 3, y % 3 | ||
|
||
moves.append(i * 3 + j) | ||
board[i][j] = player | ||
|
||
print(board) | ||
|
||
player *= -1 | ||
winner = check_winner(board) |