diff --git a/README.md b/README.md
index 0d8bcf6..e833b80 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,8 @@
# PyTorch in JavaScript
- JS-PyTorch is a Deep Learning **JavaScript library** built from scratch, to closely follow PyTorch's syntax.
-- It contains a fully functional [Tensor](src/tensor.ts) object, which can track gradients, Deep Learning [Layers](src/layers.ts) and functions, and an **Automatic Differentiation** engine.
+- This library has **GPU support**, using GPU.js.
+- It contains a gradient-tracking [Tensor](src/tensor.ts) object, Deep Learning [Layers](src/layers.ts) and functions, and an **Automatic Differentiation** engine.
- Feel free to try out the Web Demo!
> **Note:** You can install the package locally with: `npm install js-pytorch`
@@ -94,7 +95,7 @@ const { torch } = require("js-pytorch");
// Instantiate Tensors:
let x = torch.randn([8, 4, 5]);
-let w = torch.randn([8, 5, 4], (requires_grad = true));
+let w = torch.randn([8, 5, 4], (requires_grad = true), (device = 'gpu));
let b = torch.tensor([0.2, 0.5, 0.1, 0.0], (requires_grad = true));
// Make calculations:
@@ -115,6 +116,7 @@ console.log(b.grad);
const { torch } = require("js-pytorch");
const nn = torch.nn;
const optim = torch.optim;
+const device = 'gpu';
// Define training hyperparameters:
const vocab_size = 52;
@@ -126,15 +128,15 @@ const batch_size = 8;
// Create Transformer decoder Module:
class Transformer extends nn.Module {
- constructor(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p) {
+ constructor(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p, device) {
super();
// Instantiate Transformer's Layers:
this.embed = new nn.Embedding(vocab_size, hidden_size);
this.pos_embed = new nn.PositionalEmbedding(n_timesteps, hidden_size);
- this.b1 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps,dropout_p);
- this.b2 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps,dropout_p);
+ this.b1 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device);
+ this.b2 = new nn.Block(hidden_size, hidden_size, n_heads, n_timesteps, dropout_p, device);
this.ln = new nn.LayerNorm(hidden_size);
- this.linear = new nn.Linear(hidden_size, vocab_size);
+ this.linear = new nn.Linear(hidden_size, vocab_size, device);
}
forward(x) {
@@ -149,7 +151,7 @@ class Transformer extends nn.Module {
}
// Instantiate your custom nn.Module:
-const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);
+const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p, device);
// Define loss function and optimizer:
const loss_func = new nn.CrossEntropyLoss();
@@ -182,6 +184,25 @@ for (let i = 0; i < 40; i++) {
}
```
+### Saving and Loading models:
+
+```typescript
+// Instantiate your model:
+const model = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);
+
+// Train the model:
+trainModel(model);
+
+// Save model to JSON file:
+torch.save(model, 'model.json')
+
+// To load, instantiate placeHolder using the original model's architecture:
+const placeHolder = new Transformer(vocab_size, hidden_size, n_timesteps, n_heads, dropout_p);
+// Load weights into placeHolder:
+const newModel = torch.load(placeHolder, 'model.json')
+```
+
+
## 3. Distribution & Devtools
diff --git a/assets/demo/bundle.js b/assets/demo/bundle.js
new file mode 100644
index 0000000..54befdd
Binary files /dev/null and b/assets/demo/bundle.js differ
diff --git a/assets/demo/demo.css b/assets/demo/demo.css
index 46932d3..97ffe76 100644
--- a/assets/demo/demo.css
+++ b/assets/demo/demo.css
@@ -141,6 +141,20 @@ button:active {
background-color: #821414;
}
+.device-button {
+ width: 10%;
+ background-color: #818181;
+ padding-left: 1%;
+}
+
+.device-button:hover {
+ background-color: #5f5f5f;
+}
+
+.device-button:active {
+ background-color: #3a3a3a;
+}
+
.icon {
width: 72%;
margin: auto;
@@ -201,4 +215,3 @@ input {
font-size: 16px;
color:#1b1b1b;
}
-
diff --git a/assets/demo/demo.html b/assets/demo/demo.html
index c6dd3b7..2b19aa6 100644
--- a/assets/demo/demo.html
+++ b/assets/demo/demo.html
@@ -6,28 +6,28 @@
Welcome to JS-Torch's Web Demo! You can choose the Model Hyperparameters on the left, set the Model Layers on the right (number of layers and hidden dimension on each).
+Welcome to JS-PyTorch's Web Demo! You can choose the Model Hyperparameters on the left, set the Model Layers on the right (number of layers and hidden dimension on each).
Iteration:
Total Training Examples:
Loss:
+Device: