Showing 8 changed files with 567 additions and 0 deletions.
@@ -0,0 +1,187 @@

import numpy as np
import torch
import sympy

# MATLAB engine support is optional; fall back to sympy-only mode if absent.
ISINSTALLMATLAB = True
try:
    import matlab
except ModuleNotFoundError:
    ISINSTALLMATLAB = False
    matlab = None

__all__ = ['poly']


class poly(torch.nn.Module):
    """SymNet polynomial network. Each hidden layer appends the product of two
    learned linear combinations as a new channel; the fitted network can be
    read back as a symbolic expression via sympy or a MATLAB engine."""
    def __init__(self, hidden_layers, channel_num, channel_names=None, normalization_weight=None):
        super(poly, self).__init__()
        self.hidden_layers = hidden_layers
        self.channel_num = channel_num
        if channel_names is None:
            channel_names = ['u'+str(i) for i in range(self.channel_num)]
        self.channel_names = channel_names
        layer = []
        for k in range(hidden_layers):
            # Layer k maps the channel_num+k current channels to two scalars.
            module = torch.nn.Linear(channel_num+k, 2).to(dtype=torch.float64)
            module.weight.data.fill_(0)
            module.bias.data.fill_(0)
            self.add_module('layer'+str(k), module)
            layer.append(getattr(self, 'layer'+str(k)))
        module = torch.nn.Linear(channel_num+hidden_layers, 1).to(dtype=torch.float64)
        module.weight.data.fill_(0)
        module.bias.data.fill_(0)
        self.add_module('layer_final', module)
        layer.append(getattr(self, 'layer_final'))
        self.layer = tuple(layer)
        nw = torch.ones(channel_num).to(dtype=torch.float64)
        if normalization_weight is not None and not isinstance(normalization_weight, torch.Tensor):
            normalization_weight = np.array(normalization_weight)
            normalization_weight = torch.from_numpy(normalization_weight).to(dtype=torch.float64)
            normalization_weight = normalization_weight.view(channel_num)
            nw = normalization_weight
        self.register_buffer('_nw', nw)

    @property
    def channels(self):
        channels = sympy.symbols(self.channel_names)
        return channels

    def renormalize(self, nw):
        if nw is not None and not isinstance(nw, torch.Tensor):
            nw = np.array(nw)
            nw = torch.from_numpy(nw)
        nw1 = nw.view(self.channel_num)
        nw1 = nw1.to(self._nw)
        nw0 = self._nw
        scale = nw0/nw1
        self._nw.data = nw1
        # Rescale the input columns of every layer so the overall map is unchanged.
        for L in self.layer:
            L.weight.data[:, :self.channel_num] *= scale
        return None

    def _cast2numpy(self, layer):
        weight, bias = layer.weight.data.cpu().numpy(), layer.bias.data.cpu().numpy()
        return weight, bias

    def _cast2matsym(self, layer, eng):
        # Push a layer's weight and bias into the MATLAB workspace as sym objects.
        weight, bias = self._cast2numpy(layer)
        weight, bias = weight.tolist(), bias.tolist()
        weight, bias = matlab.double(weight), matlab.double(bias)
        eng.workspace['weight'], eng.workspace['bias'] = weight, bias
        eng.workspace['weight'] = eng.eval("sym(weight,'d')")
        eng.workspace['bias'] = eng.eval("sym(bias,'d')")
        return None

    def _cast2symbol(self, layer):
        weight, bias = self._cast2numpy(layer)
        weight, bias = sympy.Matrix(weight), sympy.Matrix(bias)
        return weight, bias

    def _sympychop(self, o, calprec):
        # Drop terms whose coefficient magnitude is below 10**(-calprec).
        cdict = o.expand().as_coefficients_dict()
        o = 0
        for k, v in cdict.items():
            if abs(v) > 0.1**calprec:
                o = o+k*v
        return o

    def _matsymchop(self, o, calprec, eng):
        # MATLAB counterpart of _sympychop, operating on the engine workspace.
        eng.eval('[c,t] = coeffs('+o+');', nargout=0)
        eng.eval('c = double(c);', nargout=0)
        eng.eval('c(abs(c)<1e-'+calprec+') = 0;', nargout=0)
        eng.eval(o+" = sum(sym(c, 'd').*t);", nargout=0)
        return None

    def expression(self, calprec=6, eng=None, isexpand=True):
        if eng is None:
            # Pure-sympy path: replay the forward pass symbolically.
            channels = sympy.symbols(self.channel_names)
            for i in range(self.channel_num):
                channels[i] = self._nw[i].item()*channels[i]
            channels = sympy.Matrix([channels, ])
            for k in range(self.hidden_layers):
                weight, bias = self._cast2symbol(self.layer[k])
                o = weight*channels.transpose()+bias
                if isexpand:
                    o[0] = self._sympychop(o[0], calprec)
                    o[1] = self._sympychop(o[1], calprec)
                channels = list(channels)+[o[0]*o[1], ]
                channels = sympy.Matrix([channels, ])
            weight, bias = self._cast2symbol(self.layer[-1])
            o = (weight*channels.transpose()+bias)[0]
            if isexpand:
                o = o.expand()
                o = self._sympychop(o, calprec)
            return o
        else:
            # MATLAB-engine path: same replay inside the engine workspace.
            calprec = str(calprec)
            eng.clear(nargout=0)
            eng.syms(self.channel_names, nargout=0)
            channels = ""
            for c in self.channel_names:
                channels = channels+" "+c
            eng.eval('syms'+channels, nargout=0)
            channels = "["+channels+"].'"
            eng.workspace['channels'] = eng.eval(channels)
            eng.workspace['nw'] = matlab.double(self._nw.data.cpu().numpy().tolist())
            eng.eval("channels = channels.*nw.';", nargout=0)
            for k in range(self.hidden_layers):
                self._cast2matsym(self.layer[k], eng)
                eng.eval("o = weight*channels+bias';", nargout=0)
                eng.eval('o = o(1)*o(2);', nargout=0)
                if isexpand:
                    eng.eval('o = expand(o);', nargout=0)
                    self._matsymchop('o', calprec, eng)
                eng.eval('channels = [channels;o];', nargout=0)
            self._cast2matsym(self.layer[-1], eng)
            eng.eval("o = weight*channels+bias';", nargout=0)
            if isexpand:
                eng.eval("o = expand(o);", nargout=0)
                self._matsymchop('o', calprec, eng)
            return eng.workspace['o']

    def coeffs(self, calprec=6, eng=None, o=None, iprint=0):
        if eng is None:
            if o is None:
                o = self.expression(calprec, eng=None, isexpand=True)
            cdict = o.as_coefficients_dict()
            t = np.array(list(cdict.keys()))
            c = np.array(list(cdict.values()), dtype=np.float64)
            # Sort terms by descending coefficient magnitude.
            I = np.abs(c).argsort()[::-1]
            t = list(t[I])
            c = c[I]
            if iprint > 0:
                print(o)
        else:
            if o is None:
                self.expression(calprec, eng=eng, isexpand=True)
            else:
                eng.workspace['o'] = eng.expand(o)
            eng.eval('[c,t] = coeffs(o);', nargout=0)
            eng.eval('c = double(c);', nargout=0)
            eng.eval("[~,I] = sort(abs(c), 'descend'); c = c(I); t = t(I);", nargout=0)
            eng.eval('m = cell(numel(t),1);', nargout=0)
            eng.eval('for i=1:numel(t) m(i) = {char(t(i))}; end', nargout=0)
            if iprint > 0:
                eng.eval('disp(o)', nargout=0)
            t = list(eng.workspace['m'])
            c = np.array(eng.workspace['c']).flatten()
        return t, c

    def symboleval(self, inputs, eng=None, o=None):
        if isinstance(inputs, torch.Tensor):
            inputs = inputs.data.cpu().numpy()
        if isinstance(inputs, np.ndarray):
            inputs = list(inputs)
        assert len(inputs) == len(self.channel_names)
        if eng is None:
            if o is None:
                o = self.expression()
            return o.subs(dict(zip(self.channels, inputs)))
        else:
            if o is None:
                o = self.expression(eng=eng)
            channels = "["
            for c in self.channel_names:
                channels = channels+" "+c
            channels = channels+"].'"
            eng.workspace['channels'] = eng.eval(channels)
            eng.workspace['tmp'] = o
            eng.workspace['tmpv'] = matlab.double(inputs)
            eng.eval("tmpresults = double(subs(tmp,channels.',tmpv));", nargout=0)
            return np.array(eng.workspace['tmpresults'])

    def forward(self, inputs):
        outputs = inputs*self._nw
        for k in range(self.hidden_layers):
            o = self.layer[k](outputs)
            # Append the product of the two linear outputs as a new channel.
            outputs = torch.cat([outputs, o[..., :1]*o[..., 1:]], dim=-1)
        outputs = self.layer[-1](outputs)
        return outputs[..., 0]
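
Editor's note: a minimal usage sketch of the class above on synthetic data. The channel names are illustrative; a freshly built model must be randomly initialized (e.g. with symnet.initparams.initexpr, shown later in this commit), since all layers are zero-filled at construction.

import torch
import symnet.expr as expr
from symnet.initparams import initexpr

# Two input channels, two multiplicative hidden layers.
model = expr.poly(hidden_layers=2, channel_num=2, channel_names=['u', 'u_x'])
initexpr(model)  # break the all-zero initialization

x = torch.randn(16, 2, dtype=torch.float64)  # last dim indexes channels
y = model(x)                                 # shape (16,)

o = model.expression(calprec=6)  # sympy expression in u, u_x
t, c = model.coeffs(calprec=6)   # terms and coefficients, largest |c| first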
@@ -0,0 +1,144 @@

import numpy as np
import torch
import symnet.expr as expr
from symnet.preproc_input import prepare_batches
from symnet.prepare_left_side import init_left_term, get_left_pool
from symnet.initparams import initexpr
from symnet.loss import loss
from symnet.preproc_output import *

import seaborn as sns
import matplotlib.pyplot as plt
from sympy import Symbol, Pow, Mul


def clean_names(left_name, names: list):
    """Remove the chosen left-hand-side token from the input names;
    return the remaining names and the removed token's index."""
    new_names = names.copy()
    idx = None
    if left_name in new_names:
        idx = new_names.index(left_name)
        new_names.remove(left_name)
    return new_names, idx


def train_model(input_names, x_train, y_train, sparsity):
    """Fit a 2-hidden-layer SymNet with L-BFGS under sparse regularization."""

    def closure():
        lbfgs.zero_grad()
        tloss = loss(model, y_train, x_train, block=1, sparsity=sparsity)
        tloss.backward()
        return tloss

    model = expr.poly(2, channel_num=len(input_names), channel_names=input_names)
    initexpr(model)
    lbfgs = torch.optim.LBFGS(params=model.parameters(), max_iter=2000, line_search_fn='strong_wolfe')
    model.train()
    lbfgs.step(closure)
    last_step_loss = loss(model, y_train, x_train, block=1, sparsity=sparsity)
    return model, last_step_loss.item()


def right_matrices_coef(matrices, names: list[str], csym, tsym):
    """Evaluate each symbolic term in tsym on the token matrices and scale by its coefficient."""
    token_matrix = {}
    for i in range(len(names)):
        token_matrix[Symbol(names[i])] = matrices[i]

    right_side = []
    u_ux_ind = None  # index of the u*du/dx2 term; assumed to be present in tsym
    for i in range(len(csym)):
        total_mx = 1
        if isinstance(tsym[i], Mul):
            if tsym[i] == Mul(Symbol("u"), Symbol("du/dx2")):
                u_ux_ind = i
            lbls = tsym[i].args
            for lbl in lbls:
                if isinstance(lbl, Symbol):
                    total_mx *= token_matrix.get(lbl)
                else:
                    # Power factor inside a product: multiply the base in lbl.args[1] times.
                    for j in range(lbl.args[1]):
                        total_mx *= token_matrix.get(lbl.args[0])
        elif isinstance(tsym[i], Symbol):
            total_mx *= token_matrix.get(tsym[i])
        elif isinstance(tsym[i], Pow):
            for j in range(tsym[i].args[1]):
                total_mx *= token_matrix.get(tsym[i].args[0])
        total_mx *= csym[i]
        right_side.append(total_mx)

    u_ux = 1
    for lbl in (Symbol("u"), Symbol("du/dx2")):
        u_ux *= token_matrix.get(lbl)
    right_u_ux = csym[u_ux_ind] * u_ux
    return right_side, right_u_ux, u_ux


def select_model1(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens):
    models = []
    losses = []
    for left_side_name in left_pool:
        m_input_names, idx = clean_names(left_side_name, input_names)
        x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens)
        model, last_loss = train_model(m_input_names, x_train, y_train, sparsity)

        tsym, csym = model.coeffs(calprec=16)
        losses.append(last_loss)
        models.append(model)

    idx = losses.index(min(losses))
    return models[idx], left_pool[idx]


def select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens):
    """Train one model per candidate left-hand side; keep the one with the lowest loss."""
    models = []
    losses = []
    for left_side_name in left_pool:
        m_input_names, idx = clean_names(left_side_name, input_names)
        x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens)
        model, last_loss = train_model(m_input_names, x_train, y_train, sparsity)
        losses.append(last_loss)
        models.append(model)

    idx = losses.index(min(losses))
    return models[idx], left_pool[idx]


def save_fig(csym, add_left=True):
    distr = np.fabs(csym.copy())
    if add_left:
        distr = np.append(distr, (distr[0] + distr[1]) / 2)
    distr.sort()
    distr = distr[::-1]

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_ylim(0, np.max(distr) + 0.01)
    sns.barplot(x=np.arange(len(distr)), y=distr, orient="v", ax=ax)
    plt.grid()
    # plt.show()
    plt.yticks(fontsize=50)
    plt.savefig(f'symnet_distr{len(distr)}.png', transparent=True)


def get_csym_tsym(u, derivs, shape, input_names, pool_names, sparsity=0.1, additional_tokens=None,
                  max_deriv_order=None):
    """
    Can process only one variable (u).
    """
    left_pool = get_left_pool(max_deriv_order)
    model, left_side_name = select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens)
    tsym, csym = model.coeffs(calprec=16)
    # save_fig(csym)
    pool_sym_ls = cast_to_symbols(pool_names)
    csym_pool_ls = get_csym_pool(tsym, csym, pool_sym_ls, left_side_name)
    # save_fig(np.array(csym_pool_ls), add_left=False)
    return dict(zip(pool_sym_ls, csym_pool_ls)), pool_sym_ls
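
Editor's note: a self-contained sketch of driving train_model above on synthetic tensors. The token names, shapes, and sparsity value are assumptions for illustration; the diff does not show this file's name, so the import path is hypothetical.

import torch
# hypothetical import path for the file above:
# from symnet.pipeline import train_model

input_names = ['u', 'du/dx1']                        # illustrative token names
x_train = torch.randn(256, 2, dtype=torch.float64)   # token values per sample
y_train = torch.randn(256, dtype=torch.float64)      # left-hand-side samples

model, final_loss = train_model(input_names, x_train, y_train, sparsity=0.1)
print(final_loss)
print(model.expression(calprec=6))  # recovered symbolic right-hand side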
@@ -0,0 +1,6 @@

import torch


def initexpr(model):
    # Replace the zero initialization from poly.__init__ with small Gaussian noise.
    for p in model.parameters():
        p.data = torch.randn(*p.shape, dtype=p.dtype, device=p.device)*1e-1
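
Editor's note: in effect every parameter is redrawn from N(0, 0.1^2), breaking the all-zero symmetry of a freshly constructed poly before L-BFGS fitting. A quick sanity check, using the import paths already seen in this commit:

import symnet.expr as expr
from symnet.initparams import initexpr

model = expr.poly(2, channel_num=2)
initexpr(model)
assert all(p.abs().max().item() > 0 for p in model.parameters())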
@@ -0,0 +1,31 @@

import torch


def _sparse_loss(model):
    """
    SymNet regularization: smoothed L1 penalty over all parameters.
    """
    loss = 0
    s = 1e-3
    for p in list(model.parameters()):
        p = p.abs()
        # Quadratic below the threshold s, linear above it (Huber-style).
        loss = loss+((p<s).to(p)*0.5/s*p**2).sum()+((p>=s).to(p)*(p-s/2)).sum()
    return loss


def loss(model, u_left, u_right, block, sparsity):
    stepnum = block if block >= 1 else 1

    dataloss = 0
    sparseloss = _sparse_loss(model)

    u_der = u_left
    for steps in range(1, stepnum + 1):
        u_dertmp = model(u_right)
        dataloss = dataloss + \
            torch.mean((u_dertmp - u_der) ** 2)
        # layerweight[steps-1]*torch.mean(((uttmp-u_obs[steps])/(steps*dt))**2)
        # ut = u_right
    loss = dataloss + stepnum * sparsity * sparseloss
    return loss
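
Editor's note: the regularizer above is a smoothed L1 penalty, quadratic for |p| < s = 1e-3 and linear with slope 1 beyond, so its gradient stays bounded near zero while still pushing small weights toward sparsity. A worked check on a toy module (toy values are illustrative; symnet.loss is the import path used by the training file above):

import torch
from symnet.loss import _sparse_loss

toy = torch.nn.Linear(2, 1).double()
with torch.no_grad():
    toy.weight.fill_(1e-4)  # |p| < s: each entry contributes 0.5/s * p**2
    toy.bias.fill_(1.0)     # |p| >= s: contributes p - s/2

# expected: 2 * (0.5/1e-3 * 1e-8) + (1.0 - 5e-4) ≈ 0.99951
print(_sparse_loss(toy).item())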