Skip to content

Commit

Permalink
Fix SVD sign instability
Browse files Browse the repository at this point in the history
Signed-off-by: Karthik Uppuluri <karthik.uppuluri@fmr.com>
  • Loading branch information
kuppulur authored Dec 4, 2024
1 parent 1565a97 commit 7de9ff0
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ def _reset_seed(self, seed=1234):
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def _test_fit_transform(self, tw_model, expected, atol=1e-6):
def _test_fit_transform(self, tw_model, expected, svd=False, atol=1e-6):
predicted = tw_model.fit_transform(docs)
torch.set_printoptions(precision=10)
if not torch.allclose(predicted, expected.to(device), atol=atol):
print(predicted)
print(expected)
self.assertTrue(torch.allclose(predicted, expected.to(device), atol=atol))
# torch.set_printoptions(precision=10)
if svd:
self.assertTrue(torch.allclose(np.abs(predicted), np.abs(expected.to(device)), atol=atol))
else:
self.assertTrue(torch.allclose(predicted, expected.to(device), atol=atol))

def _test_fit_before_transform(self, tw_model, expected, atol=1e-6):
def _test_fit_before_transform(self, tw_model, expected, svd=False, atol=1e-6):
tw_model.fit(docs)
# torch.set_printoptions(precision=10)
# print(tw_model.transform(docs))
self.assertTrue(torch.allclose(tw_model.transform(docs), expected.to(device), atol=atol))
self.assertTrue(torch.allclose(tw_model(docs), expected.to(device), atol=atol))
if svd:
self.assertTrue(torch.allclose(np.abs(tw_model.transform(docs)), np.abs(expected.to(device)), atol=atol))
self.assertTrue(torch.allclose(np.abs(tw_model(docs)), np.abs(expected.to(device)), atol=atol))
else:
self.assertTrue(torch.allclose(tw_model.transform(docs), expected.to(device), atol=atol))
self.assertTrue(torch.allclose(tw_model(docs), expected.to(device), atol=atol))

def _get_test_path(self, *names):
cwd = os.getcwd()
Expand Down

0 comments on commit 7de9ff0

Please sign in to comment.