From b1d56ff2503900d30d5996c40b90bbf2c49b3d2f Mon Sep 17 00:00:00 2001 From: phosseini Date: Mon, 4 Nov 2024 00:48:44 -0800 Subject: [PATCH] adding test file for converter --- tests/test_crest.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_crest.py diff --git a/tests/test_crest.py b/tests/test_crest.py new file mode 100644 index 0000000..418ee07 --- /dev/null +++ b/tests/test_crest.py @@ -0,0 +1,26 @@ +import os +import sys + +path = os.getcwd() +sys.path.append('{}/src'.format('/'.join(path.split('/')[:-1]))) + +import unittest +from crest import Converter + + +class TestCREST(unittest.TestCase): + converter = Converter() + + def test_converter(self): + df, mis = self.converter.convert2crest(dataset_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9], save_file=True) + + print("samples: " + str(len(df))) + print("+ causal: {}".format(len(df.loc[df["label"] == 1]))) + print("- non-causal: {}".format(len(df.loc[df["label"] == 0]))) + print("train: {}".format(len(df.loc[df["split"] == 0]))) + print("dev: {}".format(len(df.loc[df["split"] == 1]))) + print("test: {}".format(len(df.loc[df["split"] == 2]))) + + +if __name__ == '__main__': + unittest.main()