Skip to content

Commit

Permalink
Merge branch 'csf-dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
phython96 committed Jan 4, 2025
2 parents c353a45 + 58d88a7 commit 549f727
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 54 deletions.
7 changes: 1 addition & 6 deletions docs/source/inference/baseline-vpt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ Next, we create the ``env_generator`` and ``agent_generator`` separately to enab
obs_size=(128, 128),
preferred_spawn_biome="forest",
)
agent_generator = partial(
load_vpt_policy,
model_path="/path/to/foundation-model-2x.model",
weights_path="/path/to/rl-from-early-game-2x.weights"
)
agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x")
Next, we configure the worker parameters, including:
- A maximum of 12,000 steps per episode,
Expand Down
10 changes: 3 additions & 7 deletions docs/source/inference/quick-inference.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<!--
* @Date: 2024-12-02 21:23:42
* @LastEditors: muzhancun muzhancun@126.com
* @LastEditTime: 2024-12-14 01:10:23
* @LastEditors: caishaofei caishaofei@stu.pku.edu.cn
* @LastEditTime: 2025-01-04 11:44:39
* @FilePath: /MineStudio/docs/source/inference/quick-inference.md
-->

Expand All @@ -22,11 +22,7 @@ if __name__ == '__main__':
obs_size = (128, 128),
preferred_spawn_biome = "forest",
) # generate the environment
agent_generator = partial(
load_vpt_policy,
model_path = # provide the path to the model
weights_path = # provide the path to the weights
) # generate the agent
agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x") # generate the agent
worker_kwargs = dict(
env_generator = env_generator,
agent_generator = agent_generator,
Expand Down
4 changes: 3 additions & 1 deletion docs/source/models/baseline-rocket1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ Evaluating the trained ROCKET-1 in your own scripts is easy:
.. code-block:: python
import torch
from minestudio.models import load_rocket_policy
from minestudio.models import load_rocket_policy, RocketPolicy
model = load_rocket_policy('/path/to/rocket.ckpt').to('cuda')
# or
model = RocketPolicy.from_pretrained("CraftJarvis/MineStudio_ROCKET-1.12w_EMA").to("cuda")
model.eval()
input = {
'image': torch.zeros(224, 224, 3).to("cuda"),
Expand Down
7 changes: 6 additions & 1 deletion docs/source/models/quick-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ Here is an example that shows how to load the OpenAI's VPT policy in the Minecra
```python
from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import RecordCallback
from minestudio.models import load_vpt_policy
from minestudio.models import load_vpt_policy, VPTPolicy

# load the policy from the local model files
policy = load_vpt_policy(
model_path="/path/to/foundation-model-2x.model",
weights_path="/path/to/foundation-model-2x.weights"
).to("cuda")

# or load the policy from the Hugging Face model hub
policy = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x").to("cuda")

policy.eval()

env = MinecraftSim(
Expand Down
6 changes: 3 additions & 3 deletions docs/source/offline/tutorial-vpt.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ The following code snippet shows how to finetune a VPT policy to hunt animals in
# below are MineStudio dependencies
from minestudio.data import MineDataModule
from minestudio.offline import MineLightning
from minestudio.models import load_vpt_policy
from minestudio.models import load_vpt_policy, VPTPolicy
from minestudio.offline.mine_callbacks import BehaviorCloneCallback
from minestudio.offline.lightning_callbacks import SmartCheckpointCallback, SpeedMonitorCallback
```

2. Configure the policy model and the training process:
```python
policy = load_vpt_policy(model_path="/path/to/1x.model", weights_path="/path/to/1x.weights")
policy = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.foundation_model_2x")
mine_lightning = MineLightning(
mine_policy=policy,
learning_rate=0.00004,
Expand All @@ -36,7 +36,7 @@ The following code snippet shows how to finetune a VPT policy to hunt animals in
mine_data = MineDataModule(
data_params=dict(
mode='event',
dataset_dirs=['/nfs-shared-2/data/contractors/dataset_6xx'],
dataset_dirs=['10xx'],
win_len=128,
frame_width=128,
frame_height=128,
Expand Down
16 changes: 7 additions & 9 deletions minestudio/models/rocket_one/body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-10 15:52:16
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2025-01-03 10:16:13
LastEditTime: 2025-01-04 09:34:49
FilePath: /MineStudio/minestudio/models/rocket_one/body.py
'''
import torch
Expand All @@ -12,12 +12,13 @@
from typing import List, Dict, Any, Tuple, Optional

import timm
from huggingface_hub import PyTorchModelHubMixin
from minestudio.models.base_policy import MinePolicy
from minestudio.utils.vpt_lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from minestudio.utils.register import Registers

@Registers.model.register
class RocketPolicy(MinePolicy):
class RocketPolicy(MinePolicy, PyTorchModelHubMixin):

def __init__(self,
backbone: str = 'efficientnet_b0.ra_in1k',
Expand Down Expand Up @@ -118,13 +119,10 @@ def load_rocket_policy(ckpt_path: str):
return model

if __name__ == '__main__':
# ckpt_path = "/nfs-shared-2/shaofei/minestudio/save/2024-11-25/14-39-15/checkpoints/step-step=120000.ckpt"
# model = load_rocket_policy(ckpt_path).to("cuda")
model = RocketPolicy(
backbone='efficientnet_b0.ra_in1k',
hiddim=1024,
num_layers=4,
).to("cuda")
# model = load_rocket_policy("/nfs-shared-2/shaofei/minestudio/save/2025-01-02/00-54-08/weights/weight-epoch=5-step=120000-EMA.ckpt")
# model.push_to_hub("CraftJarvis/MineStudio_ROCKET-1.12w_EMA")
model = RocketPolicy.from_pretrained("CraftJarvis/MineStudio_ROCKET-1.12w_EMA").to("cuda")

num_params = sum(p.numel() for p in model.parameters())
print(f"Params (MB): {num_params / 1e6 :.2f}")

Expand Down
2 changes: 1 addition & 1 deletion minestudio/models/steve_one/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def load_steve_one_policy(ckpt_path: str) -> SteveOnePolicy:
return SteveOnePolicy.from_pretrained(ckpt_path)

if __name__ == '__main__':
model = SteveOnePolicy.from_pretrained("zbww/Steve-1").to("cuda")
model = SteveOnePolicy.from_pretrained("CraftJarvis/MineStudio_STEVE-1.official").to("cuda")
model.eval()
condition = model.prepare_condition(
{
Expand Down
20 changes: 13 additions & 7 deletions minestudio/models/vpt/body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-11 20:54:15
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-12-15 11:30:55
LastEditTime: 2025-01-04 11:04:08
FilePath: /MineStudio/minestudio/models/vpt/body.py
'''
import os
Expand All @@ -16,6 +16,7 @@
from copy import deepcopy
from typing import List, Dict, Optional, Callable, Union, Tuple, Any

from huggingface_hub import PyTorchModelHubMixin
from minestudio.utils.vpt_lib.impala_cnn import ImpalaCNN
from minestudio.utils.vpt_lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from minestudio.models.base_policy import MinePolicy
Expand Down Expand Up @@ -224,7 +225,7 @@ def initial_state(self, batchsize):
return None

@Registers.model.register
class VPTPolicy(MinePolicy):
class VPTPolicy(MinePolicy, PyTorchModelHubMixin):

def __init__(self, policy_kwargs, action_space=None, **kwargs):
super().__init__(hiddim=policy_kwargs["hidsize"], action_space=action_space, **kwargs)
Expand Down Expand Up @@ -307,10 +308,15 @@ def load_vpt_policy(model_path: str, weights_path: Optional[str] = None):
return vpt_policy

if __name__ == '__main__':
model_path = '/nfs-shared/jarvisbase/pretrained/foundation-model-2x.model'
weights_path = '/nfs-shared/jarvisbase/pretrained/rl-from-early-game-2x.weights'
policy = load_vpt_policy(model_path, weights_path)
# model = load_vpt_policy(
# model_path="/nfs-shared/jarvisbase/pretrained/foundation-model-2x.model",
# weights_path="/nfs-shared/jarvisbase/pretrained/rl-from-early-game-2x.weights"
# ).to("cuda")
# model.push_to_hub("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x")
model = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x").to("cuda")
model.eval()
dummy_input = {
"image": torch.zeros(1, 1, 128, 128, 3),
"image": torch.zeros(1, 1, 128, 128, 3).to("cuda"),
}
output, memory = policy(dummy_input, None)
output, memory = model(dummy_input, None)
print(output)
19 changes: 10 additions & 9 deletions minestudio/tutorials/inference/evaluate_vpts/hunt_animals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
'''
Date: 2024-12-13 14:31:12
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-12-13 15:23:35
FilePath: /MineStudio/minestudio/tutorials/inference/evaluate_hunt_vpt.py
LastEditTime: 2025-01-04 11:05:24
FilePath: /MineStudio/minestudio/tutorials/inference/evaluate_vpts/hunt_animals.py
'''
import ray
from rich import print
from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter

from functools import partial
from minestudio.models import load_vpt_policy
from minestudio.models import load_vpt_policy, VPTPolicy
from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import (
SpeedTestCallback,
Expand All @@ -36,12 +36,13 @@
]),
]
)
agent_generator = partial(
load_vpt_policy,
model_path="/nfs-shared/jarvisbase/pretrained/foundation-model-2x.model",
# weights_path="/nfs-shared/jarvisbase/pretrained/foundation-model-1x.weights"
weights_path="/nfs-shared-2/shaofei/minestudio/save/2024-12-13/23-01-45/weights/weight-epoch=2-step=1000.ckpt",
)
# agent_generator = partial(
# load_vpt_policy,
# model_path="/nfs-shared/jarvisbase/pretrained/foundation-model-2x.model",
# # weights_path="/nfs-shared/jarvisbase/pretrained/foundation-model-1x.weights"
# weights_path="/nfs-shared-2/shaofei/minestudio/save/2024-12-13/23-01-45/weights/weight-epoch=2-step=1000.ckpt",
# )
agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x")
worker_kwargs = dict(
env_generator=env_generator,
agent_generator=agent_generator,
Expand Down
10 changes: 3 additions & 7 deletions minestudio/tutorials/inference/evaluate_vpts/mine_diamond.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
'''
Date: 2024-11-25 08:11:33
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-12-15 14:17:07
LastEditTime: 2025-01-04 11:39:20
FilePath: /MineStudio/minestudio/tutorials/inference/evaluate_vpts/mine_diamond.py
'''
import ray
from rich import print
from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter

from functools import partial
from minestudio.models import load_vpt_policy
from minestudio.models import load_vpt_policy, VPTPolicy
from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import SpeedTestCallback

Expand All @@ -23,11 +23,7 @@
SpeedTestCallback(50),
],
)
agent_generator = partial(
load_vpt_policy,
model_path="/nfs-shared/jarvisbase/pretrained/foundation-model-2x.model",
weights_path="/nfs-shared/jarvisbase/pretrained/rl-from-early-game-2x.weights"
)
agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x")
worker_kwargs = dict(
env_generator=env_generator,
agent_generator=agent_generator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ shuffle_episodes: True
episode_continuous_batch: True

model:
backbone: 'efficientnet_b0.ra_in1k'
backbone: 'timm/vit_base_patch16_224.dino'
hiddim: 1024
num_heads: 8
num_layers: 4
Expand Down
4 changes: 2 additions & 2 deletions minestudio/utils/vpt_lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def __init__(self, hidsize: int):
super().__init__()
self.fc = nn.Linear(hidsize, 2*hidsize)
self.ff = FeedForward(hidsize, mult = 2)
# self.norm = nn.LayerNorm(hidsize)
self.norm = nn.LayerNorm(hidsize)

def forward(self, x, z):
z = self.ff(z) + z
gamma, beta = self.fc(z).chunk(2, dim=-1)
y = x * (1 + gamma) + beta
# y = self.norm(y) #! LayerNorm after FiLM
y = self.norm(y) #! LayerNorm after FiLM
return y


Expand Down

0 comments on commit 549f727

Please sign in to comment.