diff --git a/urbansim/urbanchoice/mnl.py b/urbansim/urbanchoice/mnl.py index f3b2b30a..463eb252 100644 --- a/urbansim/urbanchoice/mnl.py +++ b/urbansim/urbanchoice/mnl.py @@ -36,8 +36,7 @@ def mnl_probs(data, beta, numalts): utilities.reshape(numalts, utilities.size() // numalts) # https://stats.stackexchange.com/questions/304758/softmax-overflow - if clamp: - utilities.mat -= utilities.mat.max(0) + utilities = utilities.subtract(utilities.max(0)) exponentiated_utility = utilities.exp(inplace=True) if clamp: diff --git a/urbansim/urbanchoice/pmat.py b/urbansim/urbanchoice/pmat.py index d25148b2..96e1e889 100644 --- a/urbansim/urbanchoice/pmat.py +++ b/urbansim/urbanchoice/pmat.py @@ -79,6 +79,12 @@ def cumsum(self, axis): # elif self.typ == 'cuda': # return PMAT(misc.cumsum(self.mat,axis=axis)) + def max(self, axis): + if self.typ == 'numpy': + return PMAT(np.max(self.mat, axis=axis)) + elif self.typ == 'cuda': + return PMAT(self.mat.max(axis=axis)) + def argmax(self, axis): if self.typ == 'numpy': return PMAT(np.argmax(self.mat, axis=axis))