diff --git a/mpds_ml_labs/test_app.py b/mpds_ml_labs/test_app.py index 5d77b8b..a48dccf 100755 --- a/mpds_ml_labs/test_app.py +++ b/mpds_ml_labs/test_app.py @@ -8,13 +8,14 @@ from mpds_client import MPDSDataRetrieval, APIError -from prediction import prop_semantics +from prediction import prop_models from struct_utils import detect_format, poscar_to_ase, symmetrize, get_formula, sgn_to_crsystem from cif_utils import cif_to_ase +from common import API_KEY, API_ENDPOINT req = httplib2.Http() -client = MPDSDataRetrieval() +client = MPDSDataRetrieval(api_key=API_KEY, endpoint=API_ENDPOINT) def make_request(address, data={}, httpverb='POST', headers={}): @@ -59,7 +60,7 @@ def make_request(address, data={}, httpverb='POST', headers={}): raise RuntimeError(answer['error']) formulae_categ, lattices_categ = get_formula(ase_obj), sgn_to_crsystem(ase_obj.info['spacegroup'].no) - for prop_id, pdata in prop_semantics.items(): + for prop_id, pdata in prop_models.items(): try: resp = client.get_dataframe({ 'formulae': formulae_categ, @@ -67,7 +68,7 @@ def make_request(address, data={}, httpverb='POST', headers={}): 'props': pdata['name'] }) except APIError as e: - prop_semantics[prop_id]['factual'] = None + prop_models[prop_id]['factual'] = None if e.code == 1: continue else: @@ -75,13 +76,13 @@ def make_request(address, data={}, httpverb='POST', headers={}): resp['Value'] = resp['Value'].astype('float64') # to treat values out of bounds given as str resp = resp[resp['Units'] == pdata['units']] - prop_semantics[prop_id]['factual'] = np.median(resp['Value']) + prop_models[prop_id]['factual'] = np.median(resp['Value']) for prop_id, pdata in answer['prediction'].items(): print("{0:40} = {1:6}, factual {2:8} (MAE = {3:4}), {4}".format( - prop_semantics[prop_id]['name'], - pdata['value'], - prop_semantics[prop_id]['factual'] or 'absent', + prop_models[prop_id]['name'], + 'conductor' if pdata['value'] == 0 and prop_id == 'w' else pdata['value'], + prop_models[prop_id]['factual'] or 'absent', pdata['mae'], - prop_semantics[prop_id]['units'] + prop_models[prop_id]['units'] )) diff --git a/mpds_ml_labs/test_ml.py b/mpds_ml_labs/test_ml.py index 03487ae..f3f69fc 100755 --- a/mpds_ml_labs/test_ml.py +++ b/mpds_ml_labs/test_ml.py @@ -3,7 +3,7 @@ from struct_utils import detect_format, poscar_to_ase, symmetrize from cif_utils import cif_to_ase -from prediction import ase_to_ml_model, load_ml_model, prop_semantics +from prediction import ase_to_prediction, load_ml_models, prop_models from common import ML_MODELS, DATA_PATH @@ -11,16 +11,19 @@ if sys.argv[1:]: inputs = [f for f in sys.argv[1:] if os.path.isfile(f)] - models, structures = \ - [f for f in inputs if f.endswith('.pkl')], [f for f in inputs if not f.endswith('.pkl')] + models, structures = [ + f for f in inputs if f.endswith('.pkl') + ], [ + f for f in inputs if not f.endswith('.pkl') + ] if not models: models = ML_MODELS if not structures: - structures = [os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if os.path.isfile(os.path.join(DATA_PATH, f))] + structures = [os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if os.path.isfile(os.path.join(DATA_PATH, f)) and 'settings.ini' not in f] -active_ml_model = load_ml_model(models) +active_ml_models = load_ml_models(models) for fname in structures: print @@ -50,15 +53,15 @@ print(error) continue - prediction, error = ase_to_ml_model(ase_obj, active_ml_model) + prediction, error = ase_to_prediction(ase_obj, active_ml_models) if error: print(error) continue for prop_id, pdata in prediction.items(): print("{0:40} = {1:6} (MAE = {2:4}), {3}".format( - prop_semantics[prop_id]['name'], - pdata['value'], + prop_models[prop_id]['name'], + 'conductor' if pdata['value'] == 0 and prop_id == 'w' else pdata['value'], pdata['mae'], - prop_semantics[prop_id]['units'] + prop_models[prop_id]['units'] ))