Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lalepee committed Oct 10, 2023
1 parent 192d431 commit caba4db
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 21 deletions.
2 changes: 0 additions & 2 deletions TwinCache_Connector/twincache_connector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Cosmo Tech corporation.
# Licensed under the MIT license.
import os
import logging
import time

from CosmoTech_Acceleration_Library.Modelops.core.io.model_exporter import ModelExporter
from CosmoTech_Acceleration_Library.Modelops.core.io.model_metadata import ModelMetadata

logger = logging.getLogger(__name__)

Expand Down
6 changes: 2 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import cosmotech_api
import json
from cosmotech_api.api import scenario_api
from cosmotech_api.model.scenario import Scenario
from azure.identity import DefaultAzureCredential

from TwinCache_Connector.twincache_connector import TwinCacheConnector
from auth.authentication import Authentication
Expand Down Expand Up @@ -57,13 +55,13 @@ def create_connector() -> TwinCacheConnector:

def get_parametered_queries() -> list:
twin_cache_filtering_queries = os.getenv("SUBSET_QUERY")
logger.debug(f'Filtering queries receved: {twin_cache_filtering_queries}')
if twin_cache_filtering_queries:
logger.debug(f'Filtering queries receved: {twin_cache_filtering_queries}')
return twin_cache_filtering_queries.split(';')

twin_cache_filtering_queries_name = os.getenv("SCENARIO_SUBSET_QUERY_NAME")
logger.debug("twin_cache_filtering_queries_name %s", twin_cache_filtering_queries_name)
if twin_cache_filtering_queries_name:
logger.debug(f"twin_cache_filtering_queries_name: {twin_cache_filtering_queries_name}")
# get query parameters
default_cred = Authentication(os.getenv('IDENTITY_PROVIDER'))
configuration = cosmotech_api.Configuration(host=os.getenv('CSM_API_URL'),
Expand Down
20 changes: 10 additions & 10 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch
import main
import cosmotech_api
import auth


@patch('main.TwinCacheConnector')
Expand Down Expand Up @@ -37,8 +38,8 @@ def queries_var_env():


@patch('cosmotech_api.api.scenario_api.ScenarioApi')
@patch('main.DefaultAzureCredential')
def test_get_parametered_queries_with_empty_env_var(cred, mock_scenario_api, queries_var_env):
@patch('auth.authentication.Authentication.get_token')
def test_get_parametered_queries_with_empty_env_var(mock_auth, mock_scenario_api, queries_var_env):
# create scenario parameter value
scenario_parameter_values = cosmotech_api.model.scenario_run_template_parameter_value.ScenarioRunTemplateParameterValue(
parameter_id="scenario_subset", value="")
Expand All @@ -56,8 +57,8 @@ def test_get_parametered_queries_with_empty_env_var(cred, mock_scenario_api, que


@patch('cosmotech_api.api.scenario_api.ScenarioApi')
@patch('main.DefaultAzureCredential')
def test_get_parametered_queries_with_env_var(cred, mock_scenario_api, queries_var_env):
@patch('auth.authentication.Authentication.get_token')
def test_get_parametered_queries_with_env_var(mock_auth, mock_scenario_api, queries_var_env):
# create scenario parameter value
scenario_parameter_values = cosmotech_api.model.scenario_run_template_parameter_value.ScenarioRunTemplateParameterValue(
parameter_id="scenario_subset", value="[\"query1\", \"query2\"]")
Expand All @@ -75,8 +76,8 @@ def test_get_parametered_queries_with_env_var(cred, mock_scenario_api, queries_v


@patch('cosmotech_api.api.scenario_api.ScenarioApi')
@patch('main.DefaultAzureCredential')
def test_get_parametered_queries_with_too_many_matching_env_var(cred, mock_scenario_api, queries_var_env):
@patch('auth.authentication.Authentication.get_token')
def test_get_parametered_queries_with_too_many_matching_env_var(mock_auth, mock_scenario_api, queries_var_env):
# create scenario parameter value
scenario_parameter_values_1 = cosmotech_api.model.scenario_run_template_parameter_value.ScenarioRunTemplateParameterValue(
parameter_id="scenario_subset", value="[\"query1\", \"query2\"]")
Expand All @@ -102,8 +103,7 @@ def empty_queries_var_env():


@patch('cosmotech_api.api.scenario_api.ScenarioApi')
@patch('main.DefaultAzureCredential')
def test_get_parametered_queries_with_env_var_empty(cred, mock_scenario_api, empty_queries_var_env):
def test_get_parametered_queries_with_env_var_empty(mock_scenario_api, empty_queries_var_env):
# create scenario parameter value
scenario_parameter_values = cosmotech_api.model.scenario_run_template_parameter_value.ScenarioRunTemplateParameterValue(
parameter_id="scenario_subset", value="[\"query1\", \"query2\"]")
Expand All @@ -121,8 +121,8 @@ def test_get_parametered_queries_with_env_var_empty(cred, mock_scenario_api, emp


@patch('cosmotech_api.api.scenario_api.ScenarioApi')
@patch('main.DefaultAzureCredential')
def test_get_parametered_queries_with_env_var_not_found(cred, mock_scenario_api, queries_var_env):
@patch('auth.authentication.Authentication.get_token')
def test_get_parametered_queries_with_env_var_not_found(mock_auth, mock_scenario_api, queries_var_env):
# create scenario parameter value
scenario_parameter_values = cosmotech_api.model.scenario_run_template_parameter_value.ScenarioRunTemplateParameterValue(
parameter_id="scenario_noTheRightName", value="[\"query1\", \"query2\"]")
Expand Down
7 changes: 2 additions & 5 deletions tests/test_twincache.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import pytest
from unittest.mock import patch
from TwinCache_Connector.twincache_connector import TwinCacheConnector


@patch('TwinCache_Connector.twincache_connector.ModelMetadata')
@patch('TwinCache_Connector.twincache_connector.ModelExporter')
def test_twincache_main(mock_model_exporter, mock_model_metadata):
def test_twincache_main(mock_model_exporter):
twincache = TwinCacheConnector('host', 3333, 'name')
twincache.run()
twincache.m_exporter.export_all_data.assert_called_once()


@patch('TwinCache_Connector.twincache_connector.ModelMetadata')
@patch('TwinCache_Connector.twincache_connector.ModelExporter')
def test_twincache_with_queries(mock_model_exporter, mock_model_metadata):
def test_twincache_with_queries(mock_model_exporter):
twincache = TwinCacheConnector('host', 3333, 'name')
twincache.run(['query1', 'query2'])
twincache.m_exporter.export_from_queries.assert_called_once_with(['query1', 'query2'])

0 comments on commit caba4db

Please sign in to comment.