#!/usr/bin/env python

# Copyright (C) 2017 Alex Nitz, Duncan Macleod
#               2022 Shichao Wu
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Generate a bank of templates using a brute force stochastic method.
"""
import numpy
import logging
import argparse
import numpy.random
from scipy.stats import gaussian_kde

import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
import pycbc.pool
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
from pycbc.distributions.utils import draw_samples_from_config, prior_from_config
from pycbc.io import HFile

parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--output-file', required=True,
    help='Output file name for template bank.')
parser.add_argument('--input-file',
    help='Bank to use as a starting point.')
parser.add_argument('--input-config',
    help='Draw parameters from the given configure file.')
parser.add_argument('--params',
    help='list of paramaters to use', nargs='+')
parser.add_argument('--min',
    help='list of the minimum parameter values', nargs='+', type=float)
parser.add_argument('--max',
    help='list of the maximum parameter values', nargs='+', type=float)
parser.add_argument('--approximant',  required=False,
    help='The waveform approximant to place.')
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
    help='size of waveform buffer in seconds')
parser.add_argument('--full-resolution-buffer-length', default=None, type=float,
    help='Size of the waveform buffer in seconds for generating time-domain signals at full resolution before conversion to the frequency domain.')
parser.add_argument('--max-signal-length', type= float,
                    help="When specified, it cuts the maximum length of the waveform model to the lengh provided")
parser.add_argument('--sample-rate', default=2048, type=float,
    help='sample rate in seconds')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
parser.add_argument('--enable-sigma-bound', action='store_true')
parser.add_argument('--tau0-threshold', type=float)
parser.add_argument('--permissive', action='store_true',
    help='Allow waveform generator to fail.')
parser.add_argument('--placement-iterations', default=1000, type=int,
    help='Specify the number of attempts the bank should make when placing points. Use this option if the bank fails to place any points.')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--tolerance', type=float)
parser.add_argument('--max-mtotal', type=float)
parser.add_argument('--min-mchirp', type=float)
parser.add_argument('--max-mchirp', type=float)
parser.add_argument('--fixed-params', type=str, nargs='*')
parser.add_argument('--fixed-values', type=float, nargs='*')
parser.add_argument('--use-cross', action='store_true')
parser.add_argument('--max-q', type=float)
parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--nprocesses', type=int, default=1,
    help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

pycbc.init_logging(args.verbose)
numpy.random.seed(args.seed)

# Read the .ini file if it's in the input.
if args.input_config is not None:
    config_parser = pycbc.types.config.InterpolatingConfigParser()
    file = open(args.input_config, 'r')
    config_parser.read_file(file)
    file.close()

    variable_args, static_args = read_params_from_config(
        config_parser, prior_section='prior',
        vargs_section='variable_params',
        sargs_section='static_params')

    if any(config_parser.get_subsections('waveform_transforms')):
        waveform_transforms = transforms.read_transforms_from_config(
                config_parser, 'waveform_transforms')
    else:
        waveform_transforms = None

    dists_joint = prior_from_config(cp=config_parser)

fdict = {}
if args.fixed_params:
    fdict = {p: v for (p, v) in zip(args.fixed_params, args.fixed_values)}

class Shrinker(object):
    def __init__(self, data):
        self.data = data

    def pop(self):
        if len(self.data) == 0:
            return None
        l = self.data[-1]
        self.data = self.data[:-1]
        return l

class TriangleBank(object):
    """ A bank of templates that uses the triangle inequality to estimate
    matches based on prior ones.
    """
    def __init__(self, p=None):
        self.waveforms = p if p is not None else []
        self.tbins = {}

    def __len__(self):
        return len(self.waveforms)

    def activelen(self):
        i = 0
        for w in self.waveforms:
            if isinstance(w, pycbc.types.FrequencySeries):
                i += 1
        return i

    def insert(self, hp):
        self.waveforms.append(hp)

        for b in [hp.tbin - 1, hp.tbin, hp.tbin + 1]:
            if b in self.tbins:
                self.tbins[b].append(len(self)-1)
            else:
                self.tbins[b] = [len(self)-1]

    def __getitem__(self, index):
        return self.waveforms[index]

    def keys(self):
        return self.waveforms[0].params.keys()

    def key(self, k):
        return numpy.array([p.params[k] for p in self.waveforms])

    def sigma_match_bound(self, sig):
        if not hasattr(self, 'sigma'):
            self.sigma = None
        if self.sigma is None or len(self.sigma) != len(self):
            self.sigma = numpy.array([h.s for h in bank.waveforms])
        return numpy.minimum(sig / self.sigma, self.sigma / sig)

    def range(self):
        if not hasattr(self, 'r'):
            self.r = None
        if self.r is None or len(self.r) != len(self):
            self.r = numpy.arange(0, len(self))
        return self.r

    def culltau0(self, threshold):
        cull = numpy.where(self.tau0() < threshold)[0]

        class dumb(object):
            pass
        for c in cull:
            d = dumb()
            d.tau0 = self.waveforms[c].tau0
            d.params = self.waveforms[c].params
            d.s = self.waveforms[c].s
            self.waveforms[c] = d

    def tau0(self):
        if not hasattr(self, 't0'):
            self.t0 = None
        if self.t0 is None or len(self.t0) != len(self):
            self.t0 = numpy.array([h.tau0 for h in self])
        return self.t0

    def __contains__(self, hp):
        mmax = 0
        mnum = 0
        #Apply sigmas maximal match.
        if args.enable_sigma_bound:
            matches = self.sigma_match_bound(hp.s)
            r = self.range()[matches > hp.threshold]
        else:
            matches = numpy.ones(len(self))
            r = self.range()

        msig = len(r)

        #Apply tau0 threshold
        if args.tau0_threshold:
            hp.tau0 = pycbc.conversions.tau0_from_mass1_mass2(
                                            hp.params['mass1'],
                                            hp.params['mass2'],
                                            args.tau0_cutoff_frequency)
            hp.tbin = int(hp.tau0 / args.tau0_threshold)

            if hp.tbin in self.tbins:
                r = numpy.array(self.tbins[hp.tbin])
            else:
                r = r[:0]

        mtau = len(r)

        # Try to do some actual matches
        inc = Shrinker(r*1)
        while 1:
            j = inc.pop()
            if j is None:
                hp.matches = matches[r]
                hp.indices = r
                logging.info("TADD MaxMatch:%0.3f Size:%i "
                             "AfterSigma:%i AfterTau0:%i Matches:%i"
                              % (mmax, len(self), msig, mtau, mnum))
                return False

            hc = self[j]
            m = hp.gen.match(hp, hc)
            matches[j] = m
            mnum += 1

            # Update bounding match values, apply triangle inequality
            maxmatches = hc.matches - m + 1.10
            update = numpy.where(maxmatches < matches[hc.indices])[0]
            matches[hc.indices[update]] = maxmatches[update]

            # Update where to calculate matches
            skip_threshold = 1 - (1 - hp.threshold) * 2.0
            inc.data = inc.data[matches[inc.data] > skip_threshold]

            if m > hp.threshold:
                return True
            if m > mmax:
                mmax = m

    def check_params(self, gen, params, threshold):
        num_added = 0
        total_num = len(tuple(params.values())[0])
        waveform_cache = []

        pool = pycbc.pool.choose_pool(args.nprocesses)
        for return_wf in pool.imap_unordered(
                wf_wrapper,
                ({k: params[k][idx] for k in params} for idx in range(total_num))
            ):
            waveform_cache += [return_wf]
        pool.close_pool()
        del pool

        for hp in waveform_cache:
            if hp is not None:
                hp.gen = gen
                hp.threshold = threshold
                if hp not in self:
                    num_added += 1
                    self.insert(hp)
            else:
                logging.info("Waveform generation failed!")
                continue

        return bank, num_added / total_num

def decimate_frequency_domain(template, target_df):
    """
    Returns a frequency-domain waveform resampled to a lower frequency resolution 
    (delta_f) by decimation.

    Parameters
    ----------
    template : pycbc.types.FrequencySeries
        The input frequency-domain signal to be decimated.
    target_df : float
        The target frequency resolution (delta_f) for the decimated signal.

    Returns
    ----------
    decimated_template : pycbc.types.FrequencySeries
        A new FrequencySeries object with the decimated data and the specified 
        target delta_f.
    """
    # Calculate the decimation factor
    decimation_factor = int(target_df / template.delta_f)

    if decimation_factor < 1:
        raise ValueError("Target delta_f must be greater than or equal to the original delta_f.")

    # Decimate the data by selecting every 'decimation_factor'-th point
    decimated_signal = template.data[::decimation_factor]

    # Create a new FrequencySeries object with the decimated data and the target delta_f
    decimated_template = pycbc.types.FrequencySeries(decimated_signal, delta_f=target_df)
    return decimated_template

class GenUniformWaveform(object):
    def __init__(self, buffer_length, sample_rate, f_lower):
        self.f_lower = f_lower
        self.delta_f = 1.0 / buffer_length
        tlen = int(buffer_length * sample_rate)
        self.flen = tlen // 2 + 1
        psd = pycbc.psd.from_cli(args, self.flen, self.delta_f, self.f_lower)
        self.kmin = int(f_lower * buffer_length)
        self.w = ((1.0 / psd[self.kmin:-1]) ** 0.5).astype(numpy.float32)
        qtilde = pycbc.types.zeros(tlen, numpy.complex64)
        q = pycbc.types.zeros(tlen, numpy.complex64)
        self.qtilde_view = qtilde[self.kmin:self.flen - 1]
        self.ifft = pycbc.fft.IFFT(qtilde, q)
        self.md = q._data[-100:]
        self.md2 = q._data[0:100]

    def generate(self, **kwds):
        kwds.update(fdict)
        if args.max_signal_length is not None:
                flow = numpy.arange(self.f_lower, 100, .1)[::-1]
                length = spa_length_in_time(mass1=kwds['mass1'], mass2=kwds['mass2'], f_lower=flow, phase_order=-1)
                maxlen = args.max_signal_length
                x = numpy.searchsorted(length, maxlen) - 1
                l = length[x]
                f = flow[x]
        else:
                f = self.f_lower

        kwds['f_lower'] = f
        if hasattr(kwds['approximant'], 'decode'):
            kwds['approximant'] = kwds['approximant'].decode()

        if kwds['approximant'] in pycbc.waveform.fd_approximants():
            if args.full_resolution_buffer_length is not None:
                # Generate the frequency-domain waveform at full frequency resolution
                high_hp, high_hc = pycbc.waveform.get_fd_waveform(delta_f=1 / args.full_resolution_buffer_length,
                                                    **kwds)
                # Decimate the generated signal to a reduced frequency resolution
                hp = decimate_frequency_domain(high_hp, 1 / args.buffer_length)
                hc = decimate_frequency_domain(high_hc, 1 / args.buffer_length)
            else:
                hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
                                                    **kwds)
            if  args.use_cross:
                hp = hc

            if 'fratio' in kwds:
                hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])

        else:
            dt = 1.0 / args.sample_rate
            hp = pycbc.waveform.get_waveform_filter(
                        pycbc.types.zeros(self.flen, dtype=numpy.complex64),
                        delta_f=self.delta_f, delta_t=dt,
                        **kwds)

        hp.resize(self.flen)
        hp = hp.astype(numpy.complex64)
        hp[self.kmin:-1] *= self.w
        s = float(1.0 / pycbc.filter.sigmasq(hp,
                                             low_frequency_cutoff=f) ** 0.5)
        hp *= s
        hp.params = kwds
        hp.view = hp[self.kmin:-1]
        hp.s = (1.0 / s) ** 2.0
        return hp

    def match(self, hp, hc):
        pycbc.filter.correlate(hp.view, hc.view, self.qtilde_view)
        self.ifft.execute()
        m = max(abs(self.md).max(), abs(self.md2).max())
        return m * 4.0 * self.delta_f

r = 0
if not args.tolerance:
    tolerance = (1 - args.minimal_match) / 10
else:
    tolerance = args.tolerance

size = int(1.0 / tolerance)

gen = GenUniformWaveform(args.buffer_length,
    args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

def wf_wrapper(p):
    try:
        hp = gen.generate(**p)
        return hp
    except Exception as e:
        print(e)
        return None

if args.input_file:
    f = HFile(args.input_file, 'r')
    params = {k: f[k][:] for k in f}
    bank, _ = bank.check_params(gen, params, args.minimal_match)


def draw(rtype):

    if rtype == 'uniform':
        if args.input_config is None:
            params = {name: numpy.random.uniform(pmin, pmax, size=size)
                      for name, pmin, pmax in zip(args.params, args.min, args.max)}
        else:
            # `draw_samples_from_config` has its own fixed seed, so must overwrite it.
            random_seed = numpy.random.randint(low=0, high=2**32-1)
            samples = draw_samples_from_config(args.input_config, size, random_seed)
            params = {name: samples[name] for name in samples.fieldnames}
            # Add `static_args` back.
            if static_args is not None:
                for k in static_args.keys():
                    params[k] = numpy.array([static_args[k]]*size)

    elif rtype == 'kde':
        trail = 300
        if trail > len(bank):
            trail = len(bank)
        p = bank.keys()
        p = [k for k in p if k not in fdict]
        p.remove('approximant')
        p.remove('f_lower')
        if args.input_config is not None:
            p = variable_args
        bdata = numpy.array([bank.key(k)[-trail:] for k in p])
        kde = gaussian_kde(bdata)
        points = kde.resample(size=size)
        params = {k: v for k, v in zip(p, points)}

        # Add `static_args` back, some transformations may need them.
        if args.input_config is not None and static_args is not None:
            for k in static_args.keys():
                params[k] = numpy.array([static_args[k]]*size)

        # Apply `waveform_transforms` defined in the .ini file to samples.
        if args.input_config is not None and waveform_transforms is not None:
            params = transforms.apply_transforms(params, waveform_transforms)

    if args.approximant is not None:
        params['approximant'] = numpy.array([args.approximant]*size)

    # Filter out stuff (kde method may also generate samples outside boundaries).
    l = None
    if args.input_config is None:
        for name, pmin, pmax in zip(args.params, args.min, args.max):
            nl = (params[name] < pmax) & (params[name] > pmin)
            l = (nl & l) if l is not None else nl

        if args.max_q:
            q =  numpy.maximum(params['mass1'] / params['mass2'], params['mass2'] / params['mass1'])
            l &= q < args.max_q

        if args.max_mtotal:
            l &= params['mass1'] + params['mass2'] < args.max_mtotal

        if args.max_mchirp:
            from pycbc.conversions import mchirp_from_mass1_mass2
            mc = mchirp_from_mass1_mass2(params['mass1'], params['mass2'])
            l &= mc < args.max_mchirp

        if args.min_mchirp:
            from pycbc.conversions import mchirp_from_mass1_mass2
            mc = mchirp_from_mass1_mass2(params['mass1'], params['mass2'])
            l &= mc > args.min_mchirp

    else:
        l = dists_joint.contains(params)

    params = {k: params[k][l] for k in params}
    return params

def cdraw(rtype, ts, te):
    from pycbc.conversions import tau0_from_mass1_mass2

    p = draw(rtype)
    if  len(p[list(p.keys())[0]]) > 0:
        t = tau0_from_mass1_mass2(p['mass1'], p['mass2'],
                                  args.tau0_cutoff_frequency)
        l = (t < te) & (t > ts)
        p = {k: p[k][l] for k in p}

    i = 0
    while len(p[list(p.keys())[0]]) < size:

        tp = draw(rtype)
        p = {k: numpy.concatenate([p[k], tp[k]]) for k in p}

        if  len(p[list(p.keys())[0]]) > 0:
            t = tau0_from_mass1_mass2(p['mass1'], p['mass2'],
                                      args.tau0_cutoff_frequency)
            l = (t < te) & (t > ts)
            p = {k: p[k][l] for k in p}

        i += 1
        if i > args.placement_iterations:
            break

    if len(p[list(p.keys())[0]]) == 0:
        return None

    return p

tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

go = True

region = 0
while tau0s < args.tau0_end:
    conv = 1
    r = 0
    while conv > tolerance:
        # Standard Round
        r += 1
        params = cdraw('uniform', tau0s, tau0e)
        if params is None:
            if len(bank) > 0:
                go = False
            break

        blen = len(bank)
        bank, uconv = bank.check_params(gen, params, args.minimal_match)
        logging.info("%s: Round (U): %s Size: %s conv: %s added: %s",
                     region, r, len(bank), uconv, len(bank) - blen)
        if r > 10:
            conv = uconv
        kloop = 0

        while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
            r += 1
            kloop += 1
            params = cdraw('kde', tau0s, tau0e)
            blen = len(bank)
            bank, kconv = bank.check_params(gen, params, args.minimal_match)
            logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
                         region, kloop, r, len(bank), kconv, len(bank) - blen)


            if uconv:
                logging.info('Ratio of convergences: %2.3f' % (kconv / (uconv)))
                logging.info('Progress: {:.0%} completed'.format(tau0e/args.tau0_end))

            if kloop == 1:
                okconv = kconv

            if kconv <= tolerance:
                conv = kconv
                break

    bank.culltau0(tau0s - args.tau0_threshold * 2.0)
    logging.info("Region Done %3.1f-%3.1f, %s stored", tau0s, tau0e, bank.activelen())
    region += 1
    tau0s += args.tau0_crawl / 2
    tau0e += args.tau0_crawl / 2

o = HFile(args.output_file, 'w')
o.attrs['minimal_match'] = args.minimal_match
for k in bank.keys():
    val = bank.key(k)
    if val.dtype.char == 'U':
        val = val.astype('bytes')
    o[k] = val