by hand flake8 corrections
beckynevin committed Mar 25, 2024
1 parent 38aa58b commit dae8cf6
Showing 3 changed files with 37 additions and 198 deletions.
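
All three files receive the same kind of fix: lines flagged by flake8's line-length check (E501) are wrapped by hand, either by putting one argument per line inside the call's existing parentheses or, for long string literals, by continuing the string with a backslash. A representative before/after, modeled on the argparse change in train.py below (a standalone illustration, not part of the commit itself):

import argparse

parser = argparse.ArgumentParser()

# Before: a single call that runs past 79 characters (flake8 E501).
# parser.add_argument("--data_source", type=str, help="Data used to train the model")

# After: the same call wrapped by hand, one argument per line.
parser.add_argument(
    "--data_source",
    type=str,
    help="Data used to train the model",
)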
191 changes: 4 additions & 187 deletions src/scripts/DeepEnsemble.py
@@ -13,7 +13,8 @@ def parse_args():
type=str,
nargs="?",
default="/repo/embargo",
help="Butler Repository path from which data is transferred. Input str. Default = '/repo/embargo'",
help="Butler Repository path from which data is transferred. \
Input str. Default = '/repo/embargo'",
)
parser.add_argument(
"torepo",
@@ -86,193 +87,9 @@ def parse_args():
# Define embargo and destination butler
# If move is true, then you'll need write
# permissions from the fromrepo (embargo)
butler = Butler(namespace.fromrepo, writeable=namespace.move)
registry = butler.registry
dest_butler = Butler(namespace.torepo, writeable=True)
dest_registry = dest_butler.registry
datasetTypeList = namespace.datasettype
print("whats the datasettypelist in here", datasetTypeList)
collections = namespace.collections
move = namespace.move
dest_uri_prefix = namespace.desturiprefix
# Dataset to move
dataId = {"instrument": namespace.instrument}
# Define embargo period
embargo_period = astropy.time.TimeDelta(
namespace.embargohours * 3600.0, format="sec"
)
if namespace.nowtime != "now":
now = astropy.time.Time(namespace.nowtime, scale="tai", format="iso")
else:
now = astropy.time.Time.now().tai

dest_butler = namespace.torepo
if namespace.log == "True":
CliLog.initLog(longlog=True)
# CliLog.initLog(longlog=True)
logger = logging.getLogger("lsst.transfer.embargo")
logger.info("from path: %s", namespace.fromrepo)
logger.info("to path: %s", namespace.torepo)
# the timespan object defines a "forbidden" region of time
# starting at the nowtime minus the embargo period
# and terminating in anything in the future
# this forbidden timespan will be de-selected
# for moving any exposure that overlaps with it
# documentation here:
# https://community.lsst.org/t/constructing-a-where-for-query-dimension-records/6478
timespan_embargo = Timespan(now - embargo_period, None)
# The Dimensions query
# If (now - embargo period, now) does not overlap
# with observation time interval: move
# Else: don't move
# Save data Ids of these observations into a list
datalist_exposure = []
collections_exposure = []
datalist_no_exposure = []
collections_no_exposure = []

for i, dtype in enumerate(datasetTypeList):
if any(
dim in ["exposure", "visit"]
for dim in registry.queryDatasetTypes(dtype)[0].dimensions.names
):
datalist_exposure.append(dtype)
collections_exposure.append(collections[i])
else:
datalist_no_exposure.append(dtype)
collections_no_exposure.append(collections[i])

# sort out which dtype goes into which list
if namespace.log == "True":
logger.info("datalist_exposure to move: %s", datalist_exposure)
logger.info("datalist_no_exposure to move: %s", datalist_no_exposure)

# because some dtypes don't have an exposure dimension
# we will need a different option to move those
# ie deepcoadds
# but first we will move all dtypes that have an
# exposure or visit dimension (ie calexp and raw)

if datalist_exposure: # if there is anything in the list
# first, run all of the exposure types through
outside_embargo = [
dt.id
for dt in registry.queryDimensionRecords(
"exposure",
dataId=dataId,
datasets=datalist_exposure,
collections=collections_exposure,
where="NOT exposure.timespan OVERLAPS\
timespan_embargo",
bind={"timespan_embargo": timespan_embargo},
)
]
# Query the DataIds after embargo period
datasetRefs_exposure = registry.queryDatasets(
datalist_exposure,
dataId=dataId,
collections=collections_exposure,
where="exposure.id IN (exposure_ids)",
bind={"exposure_ids": outside_embargo},
).expanded()

if namespace.log == "True":
ids_to_move = [dt.dataId.mapping["exposure"] for dt in datasetRefs_exposure]
logger.info("exposure ids to move: %s", ids_to_move)

# raw dtype requires special handling for the transfer,
# so separate by dtype:
for dtype in datalist_exposure:
if dtype == "raw":
# first check that the destination uri is defined
assert (
dest_uri_prefix
), f"dest_uri_prefix needs to be specified to transfer raw datatype, {dest_uri_prefix}"
# define a new filedataset_list using URIs
dest_uri = lsst.resources.ResourcePath(dest_uri_prefix)
source_uri = butler.get_many_uris(datasetRefs_exposure)
filedataset_list = []
for key, value in source_uri.items():
source_path_uri = value[0]
source_path = source_path_uri.relative_to(value[0].root_uri())
new_dest_uri = dest_uri.join(source_path)
if os.path.exists(source_path):
if namespace.log == "True":
logger.info("source path uri already exists")
else:
new_dest_uri.transfer_from(source_path_uri, transfer="copy")
filedataset_list.append(
lsst.daf.butler.FileDataset(new_dest_uri, key)
)

# register datasettype and collection run only once
try:
dest_butler.registry.registerDatasetType(
list(datasetRefs_exposure)[0].datasetType
)
dest_butler.registry.registerRun(list(datasetRefs_exposure)[0].run)

# ingest to the destination butler
dest_butler.transfer_dimension_records_from(
butler, datasetRefs_exposure
)
dest_butler.ingest(*filedataset_list, transfer="direct")
except IndexError:
# this will be thrown if nothing is being moved
if namespace.log == "True":
logger.info("nothing in datasetRefs_exposure")

else:
dest_butler.transfer_from(
butler,
source_refs=datasetRefs_exposure,
transfer="copy",
skip_missing=True,
register_dataset_types=True,
transfer_dimensions=True,
)
if namespace.log == "True":
ids_moved = [
dt.dataId.mapping["exposure"]
for dt in dest_registry.queryDatasets(
datasetType=datalist_exposure, collections=collections_exposure
)
]
logger.info("exposure ids moved: %s", ids_moved)
if datalist_no_exposure:
# this is for datatypes that don't have an exposure
# or visit dimension
# ie deepcoadds need to be queried using an ingest
# date keyword
datasetRefs_no_exposure = registry.queryDatasets(
datasetType=datalist_no_exposure,
collections=collections_no_exposure,
where="ingest_date <= timespan_embargo_begin",
bind={"timespan_embargo_begin": timespan_embargo.begin},
)
if namespace.log == "True":
ids_to_move = [dt.id for dt in datasetRefs_no_exposure]
logger.info("ingest ids to move: %s", ids_to_move)
dest_butler.transfer_from(
butler,
source_refs=datasetRefs_no_exposure,
transfer="copy",
skip_missing=True,
register_dataset_types=True,
transfer_dimensions=True,
)
if namespace.log == "True":
ids_moved = [
dt.id
for dt in dest_registry.queryDatasets(
datasetType=datalist_no_exposure,
collections=collections_no_exposure,
)
]
logger.info("ingest ids moved: %s", ids_moved)

if move == "True":
# prune both datasettypes
# is there a way to do this at the same time?
if datalist_exposure:
butler.pruneDatasets(refs=datasetRefs_exposure, unstore=True, purge=True)
if datalist_no_exposure:
butler.pruneDatasets(refs=datasetRefs_no_exposure, unstore=True, purge=True)
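
The large hunk above (reduced from 193 lines to 9) is LSST Butler transfer-embargo code. For context, a minimal sketch of the embargo-window query pattern it uses, with placeholder repo path, instrument, dataset type, collection, and embargo length; only calls that appear verbatim in the diff above are used:

import astropy.time
from lsst.daf.butler import Butler, Timespan

butler = Butler("/repo/embargo")  # repo path is a placeholder
registry = butler.registry

# Forbidden window: from (now - embargo period) onward into the future.
embargo_period = astropy.time.TimeDelta(80 * 3600.0, format="sec")  # 80 h is illustrative
now = astropy.time.Time.now().tai
timespan_embargo = Timespan(now - embargo_period, None)

# Exposures whose timespan does NOT overlap the forbidden window may be moved.
outside_embargo = [
    rec.id
    for rec in registry.queryDimensionRecords(
        "exposure",
        dataId={"instrument": "LATISS"},   # instrument is a placeholder
        datasets="raw",                    # dataset type is a placeholder
        collections="LATISS/raw/all",      # collection is a placeholder
        where="NOT exposure.timespan OVERLAPS timespan_embargo",
        bind={"timespan_embargo": timespan_embargo},
    )
]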
7 changes: 5 additions & 2 deletions src/scripts/models.py
@@ -66,7 +66,9 @@ def model_setup_DE(loss_type, DEVICE): # , INIT_LR=0.001):
if loss_type == "var_loss":
# model = de_var().to(DEVICE)
Layer = MuVarLayer
lossFn = torch.nn.GaussianNLLLoss(full=False, eps=1e-06, reduction="mean")
lossFn = torch.nn.GaussianNLLLoss(full=False,
eps=1e-06,
reduction="mean")
if loss_type == "bnll_loss":
# model = de_var().to(DEVICE)
Layer = MuVarLayer
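
For context on the reformatted lossFn above, a minimal sketch of how torch.nn.GaussianNLLLoss is invoked: the call order is (input, target, var), matching the lossFn(pred[:, 0].flatten(), y, pred[:, 1].flatten()) call in train.py below. Tensor names and shapes here are illustrative only:

import torch

mu = torch.randn(32)         # predicted means
var = torch.rand(32) + 0.1   # predicted variances, must be positive
y = torch.randn(32)          # targets
lossFn = torch.nn.GaussianNLLLoss(full=False, eps=1e-06, reduction="mean")
loss = lossFn(mu, y, var)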
@@ -182,7 +184,8 @@ def loss_sder(y, y_pred, coeff):
)
u_ep = 1 / np.sqrt(nu.detach().numpy())

return torch.mean(torch.log(var) + (1.0 + coeff * nu) * error**2 / var), u_al, u_ep
return torch.mean(torch.log(var) + (1.0 + coeff * nu) * error**2 / var), \
u_al, u_ep
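
A hedged restatement of the scalar objective in the return statement above, written as a standalone helper for clarity; error, var, and nu are tensors and coeff is a float, exactly mirroring the expression in the diff:

import torch

def sder_objective(error, var, nu, coeff):
    # Batch mean of log(var) + (1 + coeff * nu) * error^2 / var
    return torch.mean(torch.log(var) + (1.0 + coeff * nu) * error ** 2 / var)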


# from martius lab
37 changes: 28 additions & 9 deletions src/scripts/train.py
@@ -109,7 +109,9 @@ def train_DER(
color="black",
)
else:
ax1.scatter(y, pred[:, 0].flatten().detach().numpy(), color="grey")
ax1.scatter(y,
pred[:, 0].flatten().detach().numpy(),
color="grey")
loss_this_epoch.append(loss[0].item())

# zero out the gradients
@@ -168,7 +170,8 @@ def train_DER(
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
path_to_model + "/" + str(model_name) + "_epoch_" + str(epoch) + ".pt",
path_to_model + "/" + str(model_name)
+ "_epoch_" + str(epoch) + ".pt",
)
endTime = time.time()
if verbose:
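
The checkpoint written by the torch.save call in this hunk can be read back with torch.load; only the two uncertainty keys visible here are assumed, and the path pieces are placeholders standing in for the train_DER arguments:

import torch

path_to_model = "models"   # placeholder
model_name = "DER"         # placeholder
epoch = 99                 # placeholder
checkpoint = torch.load(path_to_model + "/" + str(model_name) + "_epoch_" + str(epoch) + ".pt")
print(checkpoint["std_u_al_validation"], checkpoint["std_u_ep_validation"])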
@@ -287,7 +290,9 @@ def train_DE(
if loss_type == "no_var_loss":
loss = lossFn(pred.flatten(), y)
if loss_type == "var_loss":
loss = lossFn(pred[:, 0].flatten(), y, pred[:, 1].flatten())
loss = lossFn(pred[:, 0].flatten(),
y,
pred[:, 1].flatten())
if loss_type == "bnll_loss":
"""
if e/EPOCHS < 0.2:
@@ -318,7 +323,10 @@ def train_DE(
except ValueError:
pass
loss = lossFn(
pred[:, 0].flatten(), pred[:, 1].flatten(), y, beta=beta_epoch
pred[:, 0].flatten(),
pred[:, 1].flatten(),
y,
beta=beta_epoch
)
if plot or savefig:
if (e % (EPOCHS - 1) == 0) and (e != 0):
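
The bnll_loss branch above calls lossFn with a beta value (beta_epoch) that appears to vary with epoch. The loss itself is presumably the beta-NLL of Seitzer et al. (the "from martius lab" comment in models.py points that way); a minimal sketch of that weighting, under the assumption that the repository's implementation follows the published form and uses the same (mean, variance, target, beta) argument order as the call above:

import torch

def beta_nll(mu, var, target, beta=0.5):
    # Per-sample Gaussian negative log-likelihood ...
    nll = 0.5 * (torch.log(var) + (target - mu) ** 2 / var)
    # ... reweighted by var**beta, with the gradient through the weight stopped.
    return (var.detach() ** beta * nll).mean()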
@@ -393,7 +401,10 @@ def train_DE(
# best_weights = copy.deepcopy(model.state_dict())
# print('validation loss', mse)
if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
ax1.plot(range(0, 1000), range(0, 1000), color="black", ls="--")
ax1.plot(range(0, 1000),
range(0, 1000),
color="black",
ls="--")
if loss_type == "no_var_loss":
ax1.scatter(
y_val,
@@ -451,7 +462,9 @@ def train_DE(
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5", facecolor="lightgrey", alpha=0.5
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)

@@ -466,7 +479,9 @@ def train_DE(
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5", facecolor="lightgrey", alpha=0.5
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)
ax1.set_ylabel("Prediction")
Expand Down Expand Up @@ -605,9 +620,13 @@ def train_DE(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_source", type=str, help="Data used to train the model")
parser.add_argument("--data_source",
type=str,
help="Data used to train the model")
parser.add_argument(
"--n_epochs", type=int, help="Integer number of epochs to train the model"
"--n_epochs",
type=int,
help="Integer number of epochs to train the model"
)

args = parser.parse_args()
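
With only the two flags visible in this hunk, a typical invocation would look something like python src/scripts/train.py --data_source <dataset> --n_epochs 100; the script may accept additional arguments not shown in the diff.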
