-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathMamba.py
50 lines (36 loc) · 1.55 KB
/
Mamba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba
from layers.Embed import DataEmbedding
class Model(nn.Module):
    """Forecasting model: embedding -> single Mamba block -> linear projection.

    Inputs are instance-normalized before encoding and de-normalized after
    projection. Only the ``short_term_forecast`` and ``long_term_forecast``
    tasks are implemented; ``forward`` raises ``NotImplementedError`` for any
    other ``task_name``.
    """

    # Task names served by ``forward``; anything else is rejected explicitly
    # instead of silently returning None.
    _SUPPORTED_TASKS = ("short_term_forecast", "long_term_forecast")

    def __init__(self, configs):
        """Build the embedding, the Mamba mixer and the output projection.

        Args:
            configs: Configuration namespace. Reads ``task_name``,
                ``pred_len``, ``enc_in``, ``c_out``, ``d_model``, ``d_ff``,
                ``d_conv``, ``expand``, ``embed``, ``freq`` and ``dropout``.
        """
        super().__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        # Kept as attributes for backward compatibility; neither value is
        # passed to the Mamba block below.
        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)  # TODO implement "auto"

        self.embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout
        )
        # NOTE: the generic ``d_ff`` config field is reused as the SSM state size.
        self.mamba = Mamba(
            d_model=configs.d_model,
            d_state=configs.d_ff,
            d_conv=configs.d_conv,
            expand=configs.expand,
        )
        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        """Normalize, encode with Mamba, project, then undo the normalization.

        Args:
            x_enc: Input series. Assumed (batch, seq_len, enc_in); mean and
                variance are taken over dim 1 — TODO confirm dim 1 is time.
            x_mark_enc: Time-feature marks forwarded to ``DataEmbedding``.

        Returns:
            The projected sequence rescaled back to the input's mean/std.
            Presumably (batch, seq_len, c_out) — confirm against
            ``DataEmbedding``/``Mamba`` output shapes.
        """
        # Instance normalization over dim 1. Statistics are detached so no
        # gradient flows through them.
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        # Population variance plus a small epsilon keeps constant series from
        # dividing by zero.
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        x = self.mamba(x)
        x_out = self.out_layer(x)

        # De-normalize. This relies on broadcasting, so the projected channel
        # count must be 1 or equal to the input channel count.
        return x_out * std_enc + mean_enc

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run the model for the configured task.

        Args:
            x_enc: Encoder input series (see ``forecast``).
            x_mark_enc: Encoder time-feature marks.
            x_dec: Unused; accepted for interface compatibility.
            x_mark_dec: Unused; accepted for interface compatibility.
            mask: Unused; accepted for interface compatibility.

        Returns:
            The last ``pred_len`` steps (along dim 1) of the forecast output.

        Raises:
            NotImplementedError: If ``task_name`` is not a forecasting task.
        """
        if self.task_name not in self._SUPPORTED_TASKS:
            raise NotImplementedError(
                f"Mamba model does not implement task '{self.task_name}'; "
                f"supported tasks: {', '.join(self._SUPPORTED_TASKS)}"
            )
        x_out = self.forecast(x_enc, x_mark_enc)
        return x_out[:, -self.pred_len:, :]