Skip to content

Commit

Permalink
fix: examples remove app context at the module level (apache#15546)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Jul 6, 2021
1 parent 08bda27 commit 0af5a3d
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 70 deletions.
7 changes: 4 additions & 3 deletions superset/examples/bart_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from superset import db
from superset.utils.core import get_example_database

from .helpers import get_example_data, TBL
from .helpers import get_example_data, get_table_connector_registry


def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
Expand Down Expand Up @@ -53,9 +53,10 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
)

print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl = table(table_name=tbl_name)
tbl.description = "BART lines"
tbl.database = database
db.session.merge(tbl)
Expand Down
33 changes: 19 additions & 14 deletions superset/examples/birth_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from typing import Dict, List, Tuple, Union

import pandas as pd
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import DateTime, String
from sqlalchemy.sql import column

from superset import db, security_manager
from superset import app, db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.exceptions import NoDataException
Expand All @@ -32,22 +33,24 @@
from superset.utils.core import get_example_database

from .helpers import (
config,
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
TBL,
update_slice_ids,
)

admin = security_manager.find_user("admin")
if admin is None:
raise NoDataException(
"Admin user does not exist. "
"Please, check if test users are properly loaded "
"(`superset load_test_users`)."
)

def get_admin_user() -> User:
admin = security_manager.find_user("admin")
if admin is None:
raise NoDataException(
"Admin user does not exist. "
"Please, check if test users are properly loaded "
"(`superset load_test_users`)."
)
return admin


def gen_filter(
Expand Down Expand Up @@ -103,10 +106,11 @@ def load_birth_names(
if not only_metadata and (not table_exists or force):
load_data(tbl_name, database, sample=sample)

obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
print(f"Creating table [{tbl_name}] reference")
obj = TBL(table_name=tbl_name)
obj = table(table_name=tbl_name)
db.session.add(obj)

_set_table_metadata(obj, database)
Expand Down Expand Up @@ -170,13 +174,14 @@ def create_slices(
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"groupby": [],
"row_limit": config["ROW_LIMIT"],
"row_limit": app.config["ROW_LIMIT"],
"since": "100 years ago",
"until": "now",
"viz_type": "table",
"markup_type": "markdown",
}

admin = get_admin_user()
if admin_owner:
slice_props = dict(
datasource_id=tbl.id,
Expand Down Expand Up @@ -503,7 +508,7 @@ def create_slices(

def create_dashboard(slices: List[Slice]) -> Dashboard:
print("Creating a dashboard")

admin = get_admin_user()
dash = db.session.query(Dashboard).filter_by(slug="births").first()
if not dash:
dash = Dashboard()
Expand Down
7 changes: 4 additions & 3 deletions superset/examples/country_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from .helpers import (
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
TBL,
)


Expand Down Expand Up @@ -73,9 +73,10 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
print("-" * 80)

print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name=tbl_name)
obj = table(table_name=tbl_name)
obj.main_dttm_col = "dttm"
obj.database = database
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
Expand Down
23 changes: 14 additions & 9 deletions superset/examples/deck.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice

from .helpers import get_slice_json, merge_slice, TBL, update_slice_ids
from .helpers import (
get_slice_json,
get_table_connector_registry,
merge_slice,
update_slice_ids,
)

COLOR_RED = {"r": 205, "g": 0, "b": 3, "a": 0.82}
POSITION_JSON = """\
Expand Down Expand Up @@ -170,7 +175,8 @@
def load_deck_dash() -> None:
print("Loading deck.gl dashboard")
slices = []
tbl = db.session.query(TBL).filter_by(table_name="long_lat").first()
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name="long_lat").first()
slice_data = {
"spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
"color_picker": COLOR_RED,
Expand Down Expand Up @@ -317,7 +323,7 @@ def load_deck_dash() -> None:
slices.append(slc)

polygon_tbl = (
db.session.query(TBL).filter_by(table_name="sf_population_polygons").first()
db.session.query(table).filter_by(table_name="sf_population_polygons").first()
)
slice_data = {
"datasource": "11__table",
Expand Down Expand Up @@ -449,7 +455,10 @@ def load_deck_dash() -> None:
slice_name="Arcs",
viz_type="deck_arc",
datasource_type="table",
datasource_id=db.session.query(TBL).filter_by(table_name="flights").first().id,
datasource_id=db.session.query(table)
.filter_by(table_name="flights")
.first()
.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
Expand Down Expand Up @@ -498,7 +507,7 @@ def load_deck_dash() -> None:
slice_name="Path",
viz_type="deck_path",
datasource_type="table",
datasource_id=db.session.query(TBL)
datasource_id=db.session.query(table)
.filter_by(table_name="bart_lines")
.first()
.id,
Expand All @@ -524,7 +533,3 @@ def load_deck_dash() -> None:
dash.slices = slices
db.session.merge(dash)
db.session.commit()


if __name__ == "__main__":
load_deck_dash()
12 changes: 9 additions & 3 deletions superset/examples/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from superset.models.slice import Slice
from superset.utils import core as utils

from .helpers import get_example_data, merge_slice, misc_dash_slices, TBL
from .helpers import (
get_example_data,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
)


def load_energy(
Expand All @@ -52,9 +57,10 @@ def load_energy(
)

print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl = table(table_name=tbl_name)
tbl.description = "Energy consumption"
tbl.database = database

Expand Down
7 changes: 4 additions & 3 deletions superset/examples/flights.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from superset import db
from superset.utils import core as utils

from .helpers import get_example_data, TBL
from .helpers import get_example_data, get_table_connector_registry


def load_flights(only_metadata: bool = False, force: bool = False) -> None:
Expand Down Expand Up @@ -57,9 +57,10 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
index=False,
)

tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl = table(table_name=tbl_name)
tbl.description = "Random set of flights in the US"
tbl.database = database
db.session.merge(tbl)
Expand Down
12 changes: 5 additions & 7 deletions superset/examples/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,19 @@

from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models import core as models
from superset.models.slice import Slice

BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"

# Shortcuts
DB = models.Database
misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard

TBL = ConnectorRegistry.sources["table"]

config = app.config
def get_table_connector_registry() -> Any:
return ConnectorRegistry.sources["table"]

EXAMPLES_FOLDER = os.path.join(config["BASE_DIR"], "examples")

misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard
def get_examples_folder() -> str:
return os.path.join(app.config["BASE_DIR"], "examples")


def update_slice_ids(layout_dict: Dict[Any, Any], slices: List[Slice]) -> None:
Expand Down
7 changes: 4 additions & 3 deletions superset/examples/long_lat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from .helpers import (
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
TBL,
)


Expand Down Expand Up @@ -82,9 +82,10 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
print("-" * 80)

print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name=tbl_name)
obj = table(table_name=tbl_name)
obj.main_dttm_col = "datetime"
obj.database = database
db.session.merge(obj)
Expand Down
12 changes: 6 additions & 6 deletions superset/examples/multiformat_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@
import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, String

from superset import db
from superset import app, db
from superset.models.slice import Slice
from superset.utils.core import get_example_database

from .helpers import (
config,
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
misc_dash_slices,
TBL,
)


Expand Down Expand Up @@ -75,9 +74,10 @@ def load_multiformat_time_series(
print("-" * 80)

print(f"Creating table [{tbl_name}] reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name=tbl_name)
obj = table(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = database
dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = {
Expand Down Expand Up @@ -105,7 +105,7 @@ def load_multiformat_time_series(
slice_data = {
"metrics": ["count"],
"granularity_sqla": col.column_name,
"row_limit": config["ROW_LIMIT"],
"row_limit": app.config["ROW_LIMIT"],
"since": "2015",
"until": "2016",
"viz_type": "cal_heatmap",
Expand Down
7 changes: 4 additions & 3 deletions superset/examples/paris.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from superset import db
from superset.utils import core as utils

from .helpers import get_example_data, TBL
from .helpers import get_example_data, get_table_connector_registry


def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
Expand Down Expand Up @@ -50,9 +50,10 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
)

print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl = table(table_name=tbl_name)
tbl.description = "Map of Paris"
tbl.database = database
db.session.merge(tbl)
Expand Down
16 changes: 11 additions & 5 deletions superset/examples/random_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
import pandas as pd
from sqlalchemy import DateTime, String

from superset import db
from superset import app, db
from superset.models.slice import Slice
from superset.utils import core as utils

from .helpers import config, get_example_data, get_slice_json, merge_slice, TBL
from .helpers import (
get_example_data,
get_slice_json,
get_table_connector_registry,
merge_slice,
)


def load_random_time_series_data(
Expand Down Expand Up @@ -54,9 +59,10 @@ def load_random_time_series_data(
print("-" * 80)

print(f"Creating table [{tbl_name}] reference")
obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = TBL(table_name=tbl_name)
obj = table(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = database
db.session.merge(obj)
Expand All @@ -66,7 +72,7 @@ def load_random_time_series_data(

slice_data = {
"granularity_sqla": "day",
"row_limit": config["ROW_LIMIT"],
"row_limit": app.config["ROW_LIMIT"],
"since": "2019-01-01",
"until": "2019-02-01",
"metrics": ["count"],
Expand Down
Loading

0 comments on commit 0af5a3d

Please sign in to comment.