Errors while computing lines using recursive regression #11

Open
avyavkumar opened this issue Feb 6, 2022 · 11 comments

@avyavkumar

Hi,

While running the algorithm for generating lines using recursive regression, I noticed the following exceptions -

R[write to console]: Error in xtx_in %*% t(v) : non-conformable arguments

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xtx_in %*% t(v) : non-conformable arguments

I am unable to find the root cause of this exception - it would be great if a workaround or a solution can be suggested!

Another exception that comes up commonly is

Traceback (most recent call last):
  File "<ipython-input-16-1a65a5bb4343>", line 66, in <module>
    correct_soft_label_KNN, soft_label_points_KNN = classify_with_soft_label_KNN(lines, test_classifications, labeled_centroids, labeled_test_data)
  File "<ipython-input-14-443ac67806c7>", line 8, in classify_with_soft_label_KNN
    distX, distY = get_line_prototypes(line, labeled_centroids[0])
  File "<ipython-input-7-e7dbf273c244>", line 137, in get_line_prototypes
    distY[0,line], distY[1,line] = x.value[0:n], x.value[n:]
TypeError: 'NoneType' object is not subscriptable

I suspect that cvxpy is unable to find a solution in this case, which leaves x.value as None. Again, if a workaround or a best practice can be suggested, that would be great. Thanks!

@ilia10000
Owner

The first error happens when the dimensions of xtx_in and v don't match.
Can you try printing the dimensions of both to see if that's the case?

The second problem can happen because the optimization problem isn't very stable in its current form.
Possible solutions include relaxing some of the constraints (though this may lead to solutions that don't actually separate the classes as desired), changing the data, or using a different optimization method (e.g. fitting the soft-label prototypes using some sort of iterative method like gradient descent).

@avyavkumar
Author

avyavkumar commented Feb 12, 2022

As per the discussion here, I modified the code to use ginv rather than solve. The following blocks were modified

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, ]
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
    xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in
  }
  
}
add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x, tol = sqrt(.Machine$double.eps))
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}

It looks like the "non-conformable" issue is arising due to changes in the second block - xtx_in <- ginv(t(x) %*% x, tol = sqrt(.Machine$double.eps)) is used instead of xtx_in <- solve(t(x) %*% x). The computed xtx_in via ginv is passed in beta <- beta_w(xtx_in, x, y, w) which calls xwx <- function(xtx_in, x, w) internally.

Note that this behaviour seems to occur randomly: it is more frequent when a higher number of lines needs to be generated, but it can also occur when, for example, only 3 lines are generated. If there is a workaround for this dimensionality issue, please let me know. Thanks!
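
For what it's worth, swapping solve for ginv should not by itself change the shape of xtx_in: for a square input both return a matrix of the same dimensions, and for a full-rank matrix they agree numerically. A minimal sketch with a made-up 5 x 3 design matrix (the data is hypothetical, not taken from this issue):

```r
library(MASS)  # provides ginv(), as used in the thread

# Hypothetical full-rank design matrix: intercept plus two predictors.
x <- cbind(1, c(1, 2, 3, 4, 5), c(2, 1, 4, 3, 6))

xtx <- t(x) %*% x
inv_solve <- solve(xtx)  # ordinary inverse
inv_ginv  <- ginv(xtx)   # Moore-Penrose pseudoinverse

print(dim(inv_solve))
print(dim(inv_ginv))
print(isTRUE(all.equal(inv_solve, inv_ginv)))
```

If that holds for the real data too, the mismatch more likely comes from how rows of x are subset inside xwx than from ginv itself.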

@ilia10000
Owner

ilia10000 commented Feb 18, 2022

Can you check the dimensions of xtx_in and t(v) and see if they are different in the cases where the non-conformable error comes up?
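
One caveat when printing dimensions: dim() returns a vector (rows, columns), or NULL for a plain vector, so pasting it directly into a message yields one string per element. A small helper along these lines gives one readable line per object (dim_str is a hypothetical name, not part of the repo, and the shapes are made up):

```r
# Hypothetical helper: format the shape of a matrix or vector as one string.
dim_str <- function(m) {
  d <- dim(m)
  if (is.null(d)) {
    paste0("vector of length ", length(m))
  } else {
    paste(d, collapse = " x ")
  }
}

# Made-up shapes for illustration only.
xtx_in <- diag(3)
v <- matrix(0, nrow = 2, ncol = 3)

message("xtx_in: ", dim_str(xtx_in), "; t(v): ", dim_str(t(v)))
# The product xtx_in %*% t(v) needs ncol(xtx_in) == nrow(t(v)).
message("conformable: ", ncol(xtx_in) == nrow(t(v)))
```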

@avyavkumar
Author

I captured some outputs below -

Running 1 / 336 with 4 classes
Fitting 1 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 16 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 5 and dim of xtx_in is 768"  
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 11 and dim of xtx_in is 768" 
Fitting 2 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 16 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 5 and dim of xtx_in is 768"  
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 11 and dim of xtx_in is 768" 
Fitting 3 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 17 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 1 and dim of xtx_in is 768"  
[2] "Dim of t(v): 768 and dim of xtx_in is 768"
R[write to console]: Error in xtx_in %*% t(v) : non-conformable arguments


Error in xtx_in %*% t(v) : non-conformable arguments
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xtx_in %*% t(v) : non-conformable arguments


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-49-7991660fada5>", line 69, in <module>
    lines = [line_order_no_endpoints(centroids=labeled_centroids[0], active_classes=np.array(line)) for line in find_lines_R_multiD(dat=labeled_training_data_np, labels=labeled_training_data[1] , dims=dimensions, centroids=labeled_centroids[0], k=required_lines)]
  File "<ipython-input-9-115a5284a88a>", line 478, in find_lines_R_multiD
    get_ipython().magic('R -i df -i k -i max_diff -i dims -o result1 result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)')
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2160, in magic
    return self.run_line_magic(magic_name, magic_arg_s)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2081, in run_line_magic
    result = fn(*args,**kwargs)
  File "<decorator-gen-119>", line 2, in R
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py", line 188, in <lambda>
    call = lambda f, *a, **k: f(*a, **k)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 783, in R
    raise e
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 756, in R
    text_result, result, visible = self.eval(line)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 273, in eval
    warning_or_other_msg)
rpy2.ipython.rmagic.RInterpreterError: Failed to parse and evaluate line 'result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)'.
R error message: 'Error in xtx_in %*% t(v) : non-conformable arguments'

Running 2 / 336 with 9 classes
Fitting 1 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 14 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 

It looks like in all cases except the failing ones, the dimensions of t(v) are (768 x n); in the failing cases the dimensions appear to be reversed, i.e. they are (1 x 768).
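
A likely explanation, sketched on made-up data: when exactly one row has w == 0, x[w == 0, ] collapses to a plain vector, and t() of a vector is a 1 x p row matrix, which matches the reversed shape in the failing case. (Each pair of printed lines above reads as rows first, then columns, because paste0 is vectorised over dim().) The matrix sizes below are illustrative, not those of the real data.

```r
# Made-up 4 x 3 matrix; only the last row has w == 0.
x <- matrix(as.numeric(1:12), nrow = 4, ncol = 3)
w <- c(1, 1, 1, 0)
xtx_in <- diag(3)  # stand-in for the 3 x 3 inverse of t(x) %*% x

v_dropped <- x[w == 0, ]                # single row collapses to a vector
v_kept    <- x[w == 0, , drop = FALSE]  # stays a 1 x 3 matrix

print(dim(t(v_dropped)))  # row matrix: 1 x 3
print(dim(t(v_kept)))     # column matrix: 3 x 1

# The dropped version cannot be multiplied by the 3 x 3 matrix.
res <- tryCatch(xtx_in %*% t(v_dropped),
                error = function(e) conditionMessage(e))
print(res)
print(dim(xtx_in %*% t(v_kept)))
```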

The complete R block I am using looks like this -

%%R
library(MASS)

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, ]
  print(paste0("Dim of t(v): ", dim(t(v)), " and dim of xtx_in is ", dim(xtx_in)))
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
      tryCatch({xtx_in + xtx_in %*% t(v) %*% solve(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in},
               error = function(e) {xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in})
  } 
}

xwy <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, ]) %*% y[w == 0]
  }
}

beta_w <- function(xtx_in, x, y, w) {
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}

two_norm <- function(a, b) {
  sqrt(sum((a - b)^2))
}


group_classes <- function(data, label, k) {
  mu <- t(sapply(unique(label), function(ii) {
    colMeans(data[label == ii, , drop = F])
    }))
  
  mu_dist <- dist(mu)
  cluster <- cutree(hclust(mu_dist, method = "complete"), k = k)
  
  mu2 <- t(sapply(unique(cluster), function(ii) {
    colMeans(mu[cluster == ii, , drop = F])
  }))
  
  dist2 <- as.matrix(dist(mu2))
  
  jj <- 1
  while (jj <= length(unique(cluster))) {
    #print(length(unique(cluster)))
    #print(jj)
    if (table(cluster)[jj] == 1) {
      new_cluster <- which(rank(dist2[jj, ]) == 2)
      cluster[cluster == jj] <- new_cluster
    }
    jj <- jj + 1
  }
  # print(cluster)
  return(cluster)
}

furthest_classes <- function(data, label, classes) {
  mu <- t(sapply(classes, function(ii) {
    colMeans(data[label == ii,  , drop = F])
  }))
  mu_dist <- as.matrix(dist(mu))
  furthest <- which(mu_dist == max(mu_dist), arr.ind = T)[1, ]
  return(classes[furthest])
}

add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x)
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}


order_classes <- function(data, label, group) {
  # first two elements in group must be the furthest away.
  # this will be the case if group comes from recursive regression
  
  if (length(group) == 1) {
    return(group)
  } else {
    temp <- sapply(group[-1], function(ii) {
      a <- colMeans(data[label == group[1], , drop = F])
      b <- colMeans(data[label == ii, , drop = F])
      return(sum((a - b)^2))
    })
    return(c(group[1], group[-1][order(temp)]))
  }
}



recursive_reg <- function(data, label, k, max_diff = 1.5, keep_all = T) {
  
  # group the class-wise centroids into k groups
  init_group <- group_classes(data, label, k)
  k_new <- length(unique(init_group))
  
  #if (k_new == 1) {
  #  val <- list(group = order_classes(data, label, 1),
  #              line = lm())
  #}
  # for each of the k groups, find a line that incorporates
  # as many of the classes in that group as possible
  val <- lapply(sort(unique(init_group)), function(ii) {
    classes <- which(init_group == ii)
    # print(classes)
    if (keep_all) {
      temp <- add_classes(data, label, classes, max_diff)
      temp$group <- order_classes(data, label, temp$group)
      return(temp$group)
    } else {
      if (length(unique(classes)) == 1) {
        return(NULL)
      } else {
        temp <- add_classes(data, label, classes, max_diff)
        temp$group <- order_classes(data, label, temp$group)
        return(temp$group)
      }
    }
    #add_classes(data, label, classes, max_diff)
    #if (length(unique(classes)) == 1) {
    #  return(NULL)
    #} else {
    #  add_classes(data, label, classes, max_diff)
    #}
  })
  
  if (keep_all) {
    # if keep_all = T, keep lines from  single classes
    return(val)
  } else {
    # If keep_all = F, filter out groups with only one class
    return(val[lengths(val) != 0])
  }
}

@ilia10000
Owner

ilia10000 commented Feb 25, 2022

In the xwx function, the first line is
v <- x[w == 0, ]

Nam suggests this should be changed to
v <- x[w == 0, , drop = F]

Can you try that out and see if it solves this?

@avyavkumar
Author

Hi, thanks for the suggestion. Unfortunately, I am still seeing these issues. Let me know if there is a workaround or a direction I can investigate. Though the failing cases are small in number, I would still like to get to the bottom of this, as it will help provide more accurate metrics.

@avyavkumar
Author

Hi, is there any workaround for this? A good number of runs are failing with the following exception -

R[write to console]: Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments


Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-12-2017da8319ba>", line 14, in <module>
    lines = [line_order_no_endpoints(centroids=labeled_centroids_np, active_classes=np.array(line)) for line in find_lines_R_multiD(dat=labeled_training_data_np, labels=labeled_training_data[1] , dims=dimensions, centroids=labeled_centroids_np, k=total_lines)]
  File "<ipython-input-4-5bd8ab90ad22>", line 499, in find_lines_R_multiD
    get_ipython().magic('R -i df -i k -i max_diff -i dims -o result1 result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)')
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2160, in magic
    return self.run_line_magic(magic_name, magic_arg_s)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2081, in run_line_magic
    result = fn(*args,**kwargs)
  File "<decorator-gen-119>", line 2, in R
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py", line 188, in <lambda>
    call = lambda f, *a, **k: f(*a, **k)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 783, in R
    raise e
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 756, in R
    text_result, result, visible = self.eval(line)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 273, in eval
    warning_or_other_msg)
rpy2.ipython.rmagic.RInterpreterError: Failed to parse and evaluate line 'result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)'.
R error message: 'Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments'

The R code is

%%R
library(MASS)

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, , drop = F]
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
      tryCatch({xtx_in + xtx_in %*% t(v) %*% solve(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in},
               error = function(e) {xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in})
  } 
}

xwy <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, ]) %*% y[w == 0]
  }
}

beta_w <- function(xtx_in, x, y, w) {
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}

two_norm <- function(a, b) {
  sqrt(sum((a - b)^2))
}


group_classes <- function(data, label, k) {
  mu <- t(sapply(unique(label), function(ii) {
    colMeans(data[label == ii, , drop = F])
    }))
  
  mu_dist <- dist(mu)
  cluster <- cutree(hclust(mu_dist, method = "complete"), k = k)
  
  mu2 <- t(sapply(unique(cluster), function(ii) {
    colMeans(mu[cluster == ii, , drop = F])
  }))
  
  dist2 <- as.matrix(dist(mu2))
  
  jj <- 1
  while (jj <= length(unique(cluster))) {
    #print(length(unique(cluster)))
    #print(jj)
    if (table(cluster)[jj] == 1) {
      new_cluster <- which(rank(dist2[jj, ]) == 2)
      cluster[cluster == jj] <- new_cluster
    }
    jj <- jj + 1
  }
  # print(cluster)
  return(cluster)
}

furthest_classes <- function(data, label, classes) {
  mu <- t(sapply(classes, function(ii) {
    colMeans(data[label == ii,  , drop = F])
  }))
  mu_dist <- as.matrix(dist(mu))
  furthest <- which(mu_dist == max(mu_dist), arr.ind = T)[1, ]
  return(classes[furthest])
}

add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x)
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}


order_classes <- function(data, label, group) {
  # first two elements in group must be the furthest away.
  # this will be the case if group comes from recursive regression
  
  if (length(group) == 1) {
    return(group)
  } else {
    temp <- sapply(group[-1], function(ii) {
      a <- colMeans(data[label == group[1], , drop = F])
      b <- colMeans(data[label == ii, , drop = F])
      return(sum((a - b)^2))
    })
    return(c(group[1], group[-1][order(temp)]))
  }
}



recursive_reg <- function(data, label, k, max_diff = 1e-5, keep_all = T) {
  
  # group the class-wise centroids into k groups
  init_group <- group_classes(data, label, k)
  k_new <- length(unique(init_group))
  
  #if (k_new == 1) {
  #  val <- list(group = order_classes(data, label, 1),
  #              line = lm())
  #}
  # for each of the k groups, find a line that incorporates
  # as many of the classes in that group as possible
  val <- lapply(sort(unique(init_group)), function(ii) {
    classes <- which(init_group == ii)
    # print(classes)
    if (keep_all) {
      temp <- add_classes(data, label, classes, max_diff)
      temp$group <- order_classes(data, label, temp$group)
      return(temp$group)
    } else {
      if (length(unique(classes)) == 1) {
        return(NULL)
      } else {
        temp <- add_classes(data, label, classes, max_diff)
        temp$group <- order_classes(data, label, temp$group)
        return(temp$group)
      }
    }
    #add_classes(data, label, classes, max_diff)
    #if (length(unique(classes)) == 1) {
    #  return(NULL)
    #} else {
    #  add_classes(data, label, classes, max_diff)
    #}
  })
  
  if (keep_all) {
    # if keep_all = T, keep lines from  single classes
    return(val)
  } else {
    # If keep_all = F, filter out groups with only one class
    return(val[lengths(val) != 0])
  }
}

@avyavkumar
Author

It looks like the failing matrix multiplication has dimensions

[1] 1 1
[1] 768   1

Is there a workaround for this? Please let me know if so. I added a print statement in

beta_w <- function(xtx_in, x, y, w) {
  print(dim(xwx(xtx_in, x, w)))
  print(dim(xwy(x, y, w)))
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}
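
The 1 x 1 shape suggests that the first branch of xwx ran with exactly one row where w == 1: x[w == 1, ] then collapses to a vector, and t(vec) %*% vec is a 1 x 1 inner product rather than a p x p matrix. The same collapse can happen in xwy via x[w == 0, ] when a single row has w == 0. Below is a hedged sketch of both functions with drop = FALSE applied to every row subset. The _sketch names and the toy data are assumptions for illustration, and this is not necessarily the fix that went into the repo.

```r
library(MASS)  # for ginv(), as in the thread

xwx_sketch <- function(xtx_in, x, w) {
  v  <- x[w == 0, , drop = FALSE]  # rows to remove
  x1 <- x[w == 1, , drop = FALSE]  # rows to keep
  if (nrow(v) >= ncol(x) || nrow(v) == 0) {
    # Direct inverse of t(x1) %*% x1; fall back to the pseudoinverse if singular.
    xtx <- crossprod(x1)
    tryCatch(solve(xtx), error = function(e) ginv(xtx))
  } else {
    # Downdate xtx_in for the removed rows.
    m <- diag(nrow(v)) - v %*% xtx_in %*% t(v)
    m_inv <- tryCatch(solve(m), error = function(e) ginv(m))
    xtx_in + xtx_in %*% t(v) %*% m_inv %*% v %*% xtx_in
  }
}

xwy_sketch <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, , drop = FALSE]) %*% y[w == 0]
  }
}

# Toy data: 6 rows, 3 columns, and only one row with w == 1.
x <- cbind(1, c(1, 2, 3, 4, 5, 6), c(2, 1, 4, 3, 6, 5))
y <- c(1, 3, 2, 5, 4, 6)
w <- c(1, 0, 0, 0, 0, 0)
xtx_in <- ginv(t(x) %*% x)

print(dim(t(x[w == 1, ]) %*% x[w == 1, ]))  # original subsetting
print(dim(xwx_sketch(xtx_in, x, w)))
print(dim(xwx_sketch(xtx_in, x, w) %*% xwy_sketch(x, y, w)))
```

With a single kept row the matrix t(x1) %*% x1 has rank 1, so the result here comes from the ginv fallback; whether a fit on one observation is meaningful for the algorithm is a separate question.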

@avyavkumar
Author

The dataframe looks like

             0         1         2         3         4         5         6  \
0    -0.285893  0.100989 -0.086276 -0.055642  0.386805 -0.071219  0.539185   
1     0.229483  0.433564 -0.169166 -0.373750  0.058643  0.156620  0.517695   
2     0.507505  0.397592 -0.533612  0.012184  0.141688  0.498529  0.245283   
3     0.108507 -0.343783  0.011003 -0.334807  0.436632 -0.295442 -0.064393   
4    -0.092536  0.101324  0.381015 -0.029415 -0.037573  0.209534  0.321506   
...        ...       ...       ...       ...       ...       ...       ...   
6812  0.081752 -0.076248  0.552696 -0.248589  0.123113 -0.457422 -0.211575   
6813 -0.002597  0.236071  0.211104 -0.253087  0.099704 -0.125046  0.015364   
6814  0.292048 -0.054196  0.459010 -0.343681  0.242175 -0.307340  0.013871   
6815  0.376926 -0.285080  0.137277 -0.225266  0.508968  0.213983  0.056722   
6816  0.009812 -0.133948  0.004090  0.038559  0.166507 -0.004390 -0.031691   

             7         8         9  ...       759       760       761  \
0     0.226831 -0.631589  0.104056  ...  0.060103  0.287790 -0.302842   
1     0.084167 -0.619955 -0.031069  ...  0.312966  0.111479 -0.320775   
2     0.426305  0.059289 -0.223517  ... -0.560820  0.158381 -0.133150   
3     0.121892 -0.102805  0.160829  ...  0.127226 -0.128728 -0.527352   
4     0.134741 -0.072664  0.136189  ... -0.109515  0.112521  0.195295   
...        ...       ...       ...  ...       ...       ...       ...   
6812  0.433101  0.015748  0.256297  ...  0.014548 -0.457582 -0.392458   
6813  0.658662  0.135492 -0.014497  ...  0.060105 -0.035580 -0.613886   
6814  0.368436  0.219049 -0.113884  ...  0.215659  0.008811 -0.413046   
6815  0.212861  0.120595 -0.306513  ... -0.241388  0.224027 -0.231764   
6816  0.698086 -0.115483 -0.004676  ... -0.162783  0.021009 -0.512451   

           762       763       764       765       766       767  \
0    -0.304429 -0.145017 -0.040142 -0.040846  0.040223  0.044410   
1    -0.097893 -0.072137 -0.256350 -0.168176 -0.390138  0.014747   
2     0.284558 -0.244452 -0.211685  0.412036  0.448472  0.136268   
3    -0.558983 -0.240184 -0.069946 -0.078860  0.027040 -0.251058   
4     0.073837 -0.361740  0.242135 -0.225458 -0.043142 -0.317183   
...        ...       ...       ...       ...       ...       ...   
6812 -0.004268 -0.017532  0.337242 -0.226368 -0.079176  0.671449   
6813 -0.273185 -0.061682  0.669058 -0.042069  0.012758  0.943183   
6814 -0.668200  0.115001  0.616219 -0.091257 -0.084448  0.729236   
6815  0.136355 -0.046832  0.212380 -0.251151  0.809596  0.316020   
6816 -0.257796 -0.306105  0.732763 -0.280010 -0.016456  0.494144   

      My Hopes And Dreams  
0                       0  
1                       0  
2                       0  
3                       0  
4                       0  
...                   ...  
6812                 1592  
6813                 1592  
6814                 1592  
6815                 1592  
6816                 1592  

[6817 rows x 769 columns]

@avyavkumar
Author

If possible, some documentation about the different functions like xwx, xwy would be helpful to debug issues - it reduces reliance on the original authors of the paper.
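
Reading the code posted earlier in this thread, the helpers appear to implement a least-squares fit restricted to the rows with w = 1, computed by downdating the full-data inverse rather than refitting from scratch. The notation below is introduced here for illustration and is not taken from the repo's documentation: X is the design matrix x, V holds the rows of x with w = 0, X_w the rows with w = 1, y_V and y_w the matching entries of y, and A = (X^T X)^{-1} is xtx_in.

```latex
\begin{align*}
X_w^\top X_w &= X^\top X - V^\top V \\
\texttt{xwx}:\quad (X_w^\top X_w)^{-1} &= A + A V^\top \left(I - V A V^\top\right)^{-1} V A \\
\texttt{xwy}:\quad X_w^\top y_w &= X^\top y - V^\top y_V \\
\texttt{beta\_w}:\quad \hat\beta_w &= (X_w^\top X_w)^{-1} X_w^\top y_w
\end{align*}
```

The second line is the Woodbury identity, which holds when X^T X is invertible. It needs V to be a matrix with one row per removed observation, so a row subset that collapses to a vector breaks the product. When the number of removed rows is at least ncol(x), or zero, the code skips the identity and inverts the kept rows' cross-product directly.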

@ilia10000
Owner

Nam sent me a fix with additional documentation that I pushed to the repo as a separate file just now. Let me know if it helps.
