forked from AmpliconSuite/AmpliconSuite-pipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnv_prefilter.py
175 lines (142 loc) · 5.57 KB
/
cnv_prefilter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from collections import defaultdict
import os
from intervaltree import IntervalTree
def merge_intervals(usort_intd, cn_cut=4.5, tol=1):
    """Merge high-copy-number intervals per chromosome into IntervalTrees.

    Args:
        usort_intd: mapping of chromosome -> iterable of (start, end, cn)
            tuples, in any order.
        cn_cut: only intervals whose cn is strictly greater than this are used.
        tol: an interval whose start is <= (current merged end + tol) is
            folded into the current merged interval; otherwise it starts a
            new one.

    Returns:
        defaultdict(IntervalTree) mapping chromosome -> merged [start, end)
        intervals (no data attached). Chromosomes with no interval above
        ``cn_cut`` get no entry.
    """
    merged_intd = defaultdict(IntervalTree)
    for chrom, raw_ivals in usort_intd.items():
        kept = sorted(ival for ival in raw_ivals if ival[2] > cn_cut)
        if len(kept) == 0:
            continue

        runs = []
        run_start, run_end = kept[0][0], kept[0][1]
        for ival in kept[1:]:
            start, end = ival[0], ival[1]
            if start <= run_end + tol:
                # close enough: extend the current run
                run_end = max(end, run_end)
            else:
                runs.append((run_start, run_end))
                run_start, run_end = start, end
        runs.append((run_start, run_end))

        tree = merged_intd[chrom]
        for start, end in runs:
            tree.addi(start, end)
    return merged_intd
def compute_cn_median(cnlist, armlen):
    """Return the length-weighted median copy number of a list of calls.

    Args:
        cnlist: list of (chrom, start, end, cn) tuples; each call is weighted
            by its length (end - start).
        armlen: length of the region the calls belong to.

    Returns:
        2.0 when the calls cover less than half of ``armlen``; otherwise the
        cn of the call at which the cumulative length (calls sorted by cn)
        first reaches half of the covered length. Returns 0 when ``cnlist``
        is empty and ``armlen`` <= 0.
    """
    covered = sum(call[2] - call[1] for call in cnlist)
    if covered < 0.5 * armlen:
        # too little of the arm is covered; fall back to a default of 2.0
        # (presumably the diploid baseline)
        return 2.0

    target = covered / 2.0
    ordered = sorted(cnlist, key=lambda call: call[3])
    running = 0
    result = 0
    for call in ordered:
        result = call[3]
        running += call[2] - call[1]
        if running >= target:
            return result
    return result
def read_bed(ifname, keepdat=False):
    """Read a whitespace-delimited BED file into per-chromosome IntervalTrees.

    Args:
        ifname: path to the BED file.
        keepdat: when True, the columns after the third are stored as a tuple
            in each interval's ``data``; otherwise intervals carry no data.

    Returns:
        defaultdict(IntervalTree) mapping chromosome (column 1) -> intervals
        [start, end) taken from columns 2 and 3.

    Blank lines and BED header lines (``#`` comments, ``track`` and
    ``browser`` lines) are skipped. Zero-length and inverted (end < start)
    intervals are reported on stdout and skipped, because IntervalTree does
    not accept empty/null intervals.
    """
    beddict = defaultdict(IntervalTree)
    with open(ifname) as infile:
        for line in infile:
            line = line.rstrip()
            if not line:
                continue
            fields = line.split()
            # header lines carry no coordinates and would make int() fail
            if fields[0].startswith("#") or fields[0] in ("track", "browser"):
                continue
            s, e = int(fields[1]), int(fields[2])
            if e == s:
                print("Size 0 interval found. Skipping: " + line)
                continue
            if e < s:
                print("Inverted interval (end < start) found. Skipping: " + line)
                continue
            if keepdat:
                beddict[fields[0]].addi(s, e, tuple(fields[3:]))
            else:
                beddict[fields[0]].addi(s, e)
    return beddict
def read_gain_regions(ref):
    """Load the conserved-regions BED file for reference ``ref``.

    Reads ``$AA_DATA_REPO/<ref>/file_list.txt`` (whitespace-separated
    key/value lines; blank lines ignored, later keys override earlier ones),
    takes its ``conserved_regions_filename`` entry, and parses that file
    (path relative to the same directory) with ``read_bed``. These regions
    are used to split/filter CNV calls.

    Returns:
        defaultdict(IntervalTree) keyed by chromosome, as built by read_bed.

    Raises:
        KeyError: if AA_DATA_REPO is not set, or file_list.txt has no
            ``conserved_regions_filename`` entry.
    """
    ref_dir = os.environ["AA_DATA_REPO"] + "/" + ref + "/"
    with open(ref_dir + "file_list.txt") as listing:
        entries = [raw.rstrip() for raw in listing]

    file_map = {parts[0]: parts[1] for parts in (entry.split() for entry in entries if entry)}
    return read_bed(ref_dir + file_map["conserved_regions_filename"])
def _read_cnv_calls(bedfile):
    """Parse a tab-separated CNV BED file into (chrom, start, end, cn) tuples.

    Columns 1-3 give chromosome, start and end; cn is taken from the LAST
    column. The end coordinate is incremented by 1. Blank lines and BED
    header lines (``#`` comments, ``track`` and ``browser`` lines) are skipped
    instead of failing in int()/float().
    """
    calls = []
    with open(bedfile) as infile:
        for line in infile:
            line = line.rstrip()
            if not line or line.startswith("#") or line.split(None, 1)[0] in ("track", "browser"):
                continue
            fields = line.split("\t")
            # NOTE(review): end is shifted by +1 and that shifted value is what
            # prefilter_bed writes out -- confirm the intended end convention.
            calls.append((fields[0], int(fields[1]), int(fields[2]) + 1, float(fields[-1])))
    return calls


def _merge_high_regions(calls, cngain):
    """Merge calls with cn > cngain per chromosome, bridging gaps up to 300000."""
    raw_input = defaultdict(list)
    for c, s, e, cn in calls:
        raw_input[c].append((s, e, cn))
    return merge_intervals(raw_input, cn_cut=cngain, tol=300000)


def get_continuous_high_regions(bedfile, cngain):
    """Return merged regions of a CNV BED file whose copy number exceeds cngain.

    Calls with cn > cngain are merged per chromosome by merge_intervals,
    joining neighbours separated by up to 300000.

    Returns:
        defaultdict(IntervalTree) mapping chromosome -> merged intervals.
    """
    return _merge_high_regions(_read_cnv_calls(bedfile), cngain)


def _build_arm_lookup(centromere_dict, chr_sizes):
    """Build chromosome -> IntervalTree of arms, labelled in interval data.

    Chromosomes found in centromere_dict get a "<chrom>p" interval
    [0, cent[0]) and a "<chrom>q" interval [cent[1], size). Others (e.g. a
    mitochondrial contig or viral genomes) get a single interval labelled
    with the chromosome name.
    """
    region_ivald = defaultdict(IntervalTree)
    for chrom, size in chr_sizes.items():
        try:
            cent_tup = centromere_dict[chrom]
        except KeyError:
            region_ivald[chrom].addi(0, int(size), chrom)
            continue
        region_ivald[chrom].addi(0, int(cent_tup[0]), chrom + "p")
        region_ivald[chrom].addi(int(cent_tup[1]), int(size), chrom + "q")
    return region_ivald


def _group_calls_by_arm(calls, region_ivald):
    """Group calls by the chromosome arm they fall on.

    The arm containing the call's midpoint is used; if none does, an arm
    overlapping the call is used. Calls matching no arm go under "other".

    Returns:
        (arm2cns, arm2lens): arm label -> list of calls, and arm label -> arm
        length as a defaultdict(int) (so "other" has length 0).
    """
    arm2cns = defaultdict(list)
    arm2lens = defaultdict(int)
    for c, s, e, cn in calls:
        hits = region_ivald[c][(s + e) // 2]
        if not hits:
            hits = region_ivald[c][s:e]
        if hits:
            # the overlap fallback can return both arms when a call spans the
            # centromere gap; set.pop() then picks one of them arbitrarily
            arm_iv = hits.pop()
            arm2cns[arm_iv.data].append((c, s, e, cn))
            arm2lens[arm_iv.data] = arm_iv.end - arm_iv.begin
        else:
            arm2cns["other"].append((c, s, e, cn))
    return arm2cns, arm2lens


def _filter_by_arm_median(arm2cns, arm2lens, high_region_ivald, cngain, ref):
    """Keep calls whose cn exceeds their arm's median cn by the gain threshold.

    A call passes if cn > arm_median + threshold - 2. The threshold is cngain,
    multiplied by 1.5 when the call overlaps a merged high-cn region longer
    than 10,000,000. For ref "GRCh38_viral", calls on contigs not starting
    with "chr" and with cn > 0.1 are kept as well. Arms are visited in sorted
    label order.
    """
    kept = []
    for arm in sorted(arm2cns):
        arm_calls = arm2cns[arm]
        med_cn = compute_cn_median(arm_calls, arm2lens[arm])
        for call in arm_calls:
            c, s, e, cn = call
            long_hit = any(iv.end - iv.begin > 10000000 for iv in high_region_ivald[c][s:e])
            threshold = cngain * 1.5 if long_hit else cngain
            if cn > med_cn + threshold - 2:
                kept.append(call)
            elif ref == "GRCh38_viral" and not c.startswith("chr") and cn > 0.1:
                kept.append(call)
    return kept


def _split_on_conserved_regions(entries, gain_regions):
    """Split each call at the boundaries of the conserved regions it overlaps.

    Nothing is removed: every resulting piece is kept with the call's cn.
    """
    pieces = []
    for c, s, e, cn in entries:
        cit = IntervalTree()
        cit.addi(s, e)
        # only regions overlapping [s, e) can have a boundary strictly inside
        # it; slicing at any other point would be a no-op
        for region in gain_regions[c][s:e]:
            cit.slice(region.begin)
            cit.slice(region.end)
        for piece in sorted(cit):
            pieces.append((c, piece[0], piece[1], cn))
    return pieces


# TODO: take CNV calls (as bed?) - have to update to not do CNV_GAIN
def prefilter_bed(bedfile, ref, centromere_dict, chr_sizes, cngain, outdir):
    """Pre-filter CNV calls to those elevated relative to their chromosome arm.

    Steps:
      1. Assign each call in ``bedfile`` to a chromosome arm and compute each
         arm's length-weighted median cn.
      2. Keep calls exceeding the arm median by the gain threshold.
      3. Split kept calls at the boundaries of the reference's conserved
         regions (read via read_gain_regions).
      4. Write them to ``<outdir>/<bed basename>_pre_filtered.bed``.

    Args:
        bedfile: path to a tab-separated CNV BED file; cn is the last column.
        ref: reference name; selects the conserved-regions file and enables
            the viral-contig rule for "GRCh38_viral".
        centromere_dict: chromosome -> indexable pair where [0] ends the p arm
            and [1] starts the q arm.
        chr_sizes: chromosome -> chromosome length.
        cngain: copy-number gain threshold.
        outdir: directory the output file is written into.

    Returns:
        Path of the written pre-filtered BED file.
    """
    region_ivald = _build_arm_lookup(centromere_dict, chr_sizes)
    calls = _read_cnv_calls(bedfile)
    arm2cns, arm2lens = _group_calls_by_arm(calls, region_ivald)
    # merged from the already-parsed calls instead of re-reading bedfile
    high_region_ivald = _merge_high_regions(calls, cngain)
    cn_filt_entries = _filter_by_arm_median(arm2cns, arm2lens, high_region_ivald, cngain, ref)

    gain_regions = read_gain_regions(ref)
    final_filt_entries = _split_on_conserved_regions(cn_filt_entries, gain_regions)

    bname = outdir + "/" + bedfile.rsplit("/")[-1].rsplit(".bed")[0] + "_pre_filtered.bed"
    with open(bname, 'w') as outfile:
        outfile.writelines("\t".join(str(field) for field in entry) + "\n" for entry in final_filt_entries)
    return bname