import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import re
import sys
from .functions import *
import torch.fx

grayscale = torchvision.transforms.Grayscale(num_output_channels=1)


def convert_data_for_quaternion(batch):
    """
    converts batches of RGB images in 4 channels for QNNs
    """
    assert all(batch[i][0].size(0) == 3 for i in range(len(batch)))
    inputs, labels = [], []
    for i in range(len(batch)):
        inputs.append(torch.cat([batch[i][0], grayscale(batch[i][0])], 0))
        labels.append(batch[i][1])

    return torch.stack(inputs), torch.LongTensor(labels)


# does not find an application yet
def apply_quaternion_gradient(model, layers):
    """
    hooks real-valued gradients and transforms them into one for 
    quaternion gradient descent
    
    @type model: nn.Module
    """

    for n, ((_, layer), parameter) in enumerate(zip(model.named_children(), model.parameters())):

        layer_name = re.match("^\w+", str(layer)).group()
        if layer_name in layers and len(parameter.shape) > 1 and n != 1:
            parameter.register_hook(to_conj)

    return model


@torch.fx.wrap
def check_shapes(x):
    if x.dim() in [3, 5]:
        x = torch.cat([*x.chunk()], 2).squeeze()
    return x


def convert_to_quaternion(Net, verbose=False, spinor=False):
    """
    converts a real_valued initialized Network to a quaternion one
    
    @type Net: nn.Module
    @type verbose: bool
    @type spinor: bool
    """
    last_module = len([mod for mod in Net.children()])
    layers = ["Linear", "Conv1d", "Conv2d", "Conv3d",
              "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"]
    for n, (name, layer) in enumerate(Net.named_children()):
        layer_name = re.match("^\w+", str(layer)).group()

        if n != last_module - 1:
            if layer_name in layers[1:]:

                params = re.findall("(?<!\w)\d+(?<=\w)", str(layer))
                in_features, out_features, kernel_size, stride = \
                    int(params[0]), int(params[1]), (int(params[2]), int(params[3])), (int(params[4]), int(params[5]))

                assert in_features % 4 == 0, "number of in_channels must be divisible by 4"
                assert out_features % 4 == 0, "number of out_channels must be divisible by 4"

                init_func = initialize_conv
                args = (in_features // 4, out_features // 4, kernel_size)

            elif layer_name == layers[0]:

                params = re.findall("(?<==)\w+", str(layer))
                in_features, out_features, bias = int(params[0]), int(params[1]), bool(params[2])

                assert in_features % 4 == 0, "number of in_channels must be divisible by 4"
                assert out_features % 4 == 0, "number of out_channels must be divisible by 4"

                init_func = initialize_linear
                args = (in_features // 4, out_features // 4)

            else:
                continue

            quaternion_weight = init_func(*args)

            if spinor:
                weight = quaternion_weight._real_rot_repr
            else:
                weight = quaternion_weight._real_repr

            getattr(Net, name).weight = nn.Parameter(weight)
            if getattr(Net, name).bias != None:
                getattr(Net, name).bias = nn.Parameter(torch.zeros(out_features))

            traced = torch.fx.symbolic_trace(layer)

            for node in traced.graph.nodes:
                if node.op == 'placeholder':
                    with traced.graph.inserting_after(node):
                        new_node = traced.graph.call_function(
                            check_shapes, args=(node,))

                if any(lay in node.name for lay in ["conv", "lin"]):
                    with traced.graph.inserting_before(node):
                        all_nodes = [node for node in traced.graph.nodes]
                        new_node = traced.graph.call_function(node.target,
                                                              (all_nodes[1], *node.args[1:]), node.kwargs)
                        node.replace_all_uses_with(new_node)
                    traced.graph.erase_node(node)

                if node.op == 'output':
                    all_nodes = [node for node in traced.graph.nodes]
                    with traced.graph.inserting_before(node):
                        new_node = traced.graph.call_function(
                            Q, args=(node.prev,))
                        node.replace_all_uses_with(new_node)
                    traced.graph.erase_node(node)
                    with traced.graph.inserting_after(node):
                        new_node = traced.graph.output(node.prev, )

            if verbose:
                print("-" * 20, layer_name, "-" * 20, sep="\n")
                print(torch.fx.GraphModule(layer, traced.graph))

            traced.graph.lint()
            setattr(Net, name, torch.fx.GraphModule(layer, traced.graph))

    return Net