Skip to content

Commit

Permalink
Merge branch 'csf-dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
phython96 committed Jan 5, 2025
2 parents 0691045 + bf6cc1e commit 51d5b7b
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 24 deletions.
32 changes: 32 additions & 0 deletions assets/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
FROM nvcr.io/nvidia/pytorch:24.01-py3

RUN apt-get update && \
apt-get install -y \
wget \
git \
gnutls-bin \
openssh-client \
libghc-x11-dev \
gcc-multilib \
g++-multilib \
libglew-dev \
libosmesa6-dev \
libgl1-mesa-glx \
libglfw3 \
xvfb \
mesa-utils \
libegl1-mesa \
libgl1-mesa-dev \
libglu1-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6 \
unzip \
openjdk-8-jdk

RUN pip install --upgrade pip &&\
pip install MineStudio && \
python -m minestudio.simulator.entry -y

CMD ["python", "-m", "minestudio.simulator.entry"]
18 changes: 9 additions & 9 deletions minestudio/models/groot_one/body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-25 07:03:41
LastEditors: muzhancun muzhancun@126.com
LastEditTime: 2024-12-14 01:54:42
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2025-01-04 17:05:47
FilePath: /MineStudio/minestudio/models/groot_one/body.py
'''
import torch
Expand Down Expand Up @@ -157,9 +157,10 @@ def forward(self, x: torch.Tensor, memory: List) -> Tuple[torch.Tensor, List]:
return x, memory

def initial_state(self, batch_size: int = None) -> List[torch.Tensor]:
device = next(self.parameters()).device
if batch_size is None:
return [t.squeeze(0).to(self.device) for t in self.recurrent.initial_state(1)]
return [t.to(self.device) for t in self.recurrent.initial_state(batch_size)]
return [t.squeeze(0).to(device) for t in self.recurrent.initial_state(1)]
return [t.to(device) for t in self.recurrent.initial_state(batch_size)]

@Registers.model.register
class GrootPolicy(MinePolicy):
Expand Down Expand Up @@ -281,8 +282,8 @@ def forward(self, input: Dict, memory: Optional[List[torch.Tensor]] = None) -> D
}
return latents, memory

def initial_state(self, **kwargs) -> Any:
return self.decoder.initial_state(**kwargs)
def initial_state(self, *args, **kwargs) -> Any:
return self.decoder.initial_state(*args, **kwargs)

@Registers.model_loader.register
def load_groot_policy(ckpt_path: str = None):
Expand All @@ -300,11 +301,10 @@ def load_groot_policy(ckpt_path: str = None):
return model

if __name__ == '__main__':
load_groot_policy()
model = GrootPolicy(
backbone='vit_base_patch32_clip_224.openai',
backbone='timm/vit_base_patch16_224.dino',
hiddim=1024,
freeze_backbone=True,
freeze_backbone=False,
video_encoder_kwargs=dict(
num_spatial_layers=2,
num_temporal_layers=4,
Expand Down
4 changes: 2 additions & 2 deletions minestudio/models/rocket_one/body.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-10 15:52:16
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2025-01-04 09:34:49
LastEditTime: 2025-01-04 17:09:54
FilePath: /MineStudio/minestudio/models/rocket_one/body.py
'''
import torch
Expand All @@ -21,7 +21,7 @@
class RocketPolicy(MinePolicy, PyTorchModelHubMixin):

def __init__(self,
backbone: str = 'efficientnet_b0.ra_in1k',
backbone: str = 'timm/vit_base_patch16_224.dino',
hiddim: int = 1024,
num_heads: int = 8,
num_layers: int = 4,
Expand Down
25 changes: 17 additions & 8 deletions minestudio/simulator/entry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-11 05:20:17
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2024-12-23 19:44:25
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2025-01-05 09:28:33
FilePath: /MineStudio/minestudio/simulator/entry.py
'''

Expand Down Expand Up @@ -61,18 +61,19 @@ def download_engine():
zip_ref.extractall(local_dir)
os.remove(os.path.join(local_dir, 'engine.zip'))

def check_engine():
def check_engine(skip_confirmation=False):
if not os.path.exists(os.path.join(get_mine_studio_dir(), "engine", "build", "libs", "mcprec-6.13.jar")):
response = input("Detecting missing simulator engine, do you want to download it from huggingface (Y/N)?\n")
if response == 'Y' or response == 'y':
if skip_confirmation:
download_engine()
else:
exit(0)
response = input("Detecting missing simulator engine, do you want to download it from huggingface (Y/N)?\n")
if response == 'Y' or response == 'y':
download_engine()
else:
exit(0)

class MinecraftSim(gymnasium.Env):

check_engine()

def __init__(
self,
action_type: Literal['env', 'agent'] = 'agent', # the style of the action space
Expand All @@ -87,6 +88,7 @@ def __init__(
**kwargs
) -> Any:
super().__init__()
check_engine()
self.obs_size = obs_size
self.action_type = action_type
self.render_size = render_size
Expand Down Expand Up @@ -248,6 +250,13 @@ def observation_space(self) -> spaces.Dict:

if __name__ == '__main__':
# test if the simulator works
parser = argparse.ArgumentParser()
parser.add_argument('-y', '--yes', action='store_true', help='Skip confirmation', default=False)
args = parser.parse_args()

if args.yes:
check_engine(skip_confirmation=True)

from minestudio.simulator.callbacks import SpeedTestCallback
sim = MinecraftSim(
action_type="env",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@ devices: 8
batch_size: 8
num_workers: 6
prefetch_factor: 4
split_ratio: 0.95
split_ratio: 0.90
learning_rate: 0.00004
weight_decay: 0.001
warmup_steps: 2000
save_freq: 10000
ckpt_path: null
loss_scale: 0.01 # does not matter, since AdamW will be used
bc_weight: 1.0
kl_div_weight: 0.01
shuffle_episodes: False
kl_div_weight: 0.001
shuffle_episodes: True
episode_continuous_batch: False

model:
backbone: 'vit_base_patch32_clip_224.openai'
backbone: 'timm/vit_large_patch14_clip_224.openai'
# backbone: 'vit_base_patch32_clip_224.openai'
hiddim: 1024
freeze_backbone: True
video_encoder_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "minestudio"
version="1.0.3"
version="1.0.5"
description = "A simple and efficient Minecraft development kit for AI research."
dependencies = [
"av",
Expand Down

0 comments on commit 51d5b7b

Please sign in to comment.