# analysis_smooth_pursuit.py
"""
This module contains helper functions used for compiling, analysing and plotting Smooth Pursuit trial data.
For more details refer to our project and pre-registration at https://osf.io/qh8kx/
"""
import os
import pandas as pd
import numpy as np
import ffmpeg
import matplotlib.pyplot as plt
import seaborn as sns
import statistics
from scipy.stats.mstats import winsorize
from analysis_module import *
import ruptures
from scipy.stats import circmean as cim, circstd as cis
from sklearn.linear_model import LinearRegression
from scipy.signal import savgol_filter
class Smooth_Pursuit():
    """Per-subject access to the "2. Smooth Pursuit" trials.

    On construction the subject's ``Subjects/<subb>/data.csv`` is read and the
    rows whose ``Task_Name`` equals ``"2. Smooth Pursuit"`` are kept in
    ``task_df``. ``parse_trials`` then slices per-frame gaze predictions into
    one segment per target movement and estimates pursuit onset/duration.
    """

    # Target movement directions; every per-angle result dict is keyed by these.
    angles = [0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330]

    def __init__(self, subb, show=True):
        """Read the subject's data file and keep the Smooth Pursuit rows.

        Args:
            subb: subject folder name below ``Subjects/``.
            show: when true, print a short summary of the first task row.
        """
        self.subb = subb
        self.df = pd.read_csv(f"Subjects/{subb}/data.csv")
        self.task_df = self.df[self.df["Task_Name"] == "2. Smooth Pursuit"]
        self.palette = sns.color_palette('colorblind', len(Smooth_Pursuit.angles))
        if show:
            # summary() prints its own output and returns None, so formatting
            # its return value into the header would only print "None".
            print("sp_df summary ")
            self.summary()

    def summary(self):
        """Print non-null fields 7..21 (positional) of the first task row."""
        print(self.task_df.iloc[0].dropna()[7:22])

    @staticmethod
    def _relative_times(raw, rec_start):
        """Parse a ``...---values="t1;t2;..."`` field into offsets from rec_start."""
        tokens = raw.split("---values=")[1].strip("\"").split(";")
        # Tokens of length <= 1 (e.g. the empty string after a trailing ';')
        # are skipped, exactly as the original inline parsing did.
        return [int(t) - int(rec_start) for t in tokens if len(t) > 1]

    def parse_trials(self, model, colx='pred_x', coly='pred_y', show=True, model_outputs=False):
        """Cut gaze predictions into per-angle segments and detect pursuit onsets.

        Args:
            model: object whose ``.value`` is the root directory of the
                prediction CSV files.
            colx, coly: names of the prediction columns to read.
            show: print per-recording diagnostics and scatter-plot the segments.
            model_outputs: read from ``model_outputs/`` instead of
                ``pred_allcalib/``.

        Returns:
            ``(trial_x, trial_y, onsets, durations)``, four dicts keyed by
            movement angle. ``trial_x``/``trial_y`` hold one prediction column
            slice per movement; ``onsets`` and ``durations`` hold the values
            derived from change-point detection.

        Note:
            ``get_frame_count``, ``time_to_frame`` and ``frame_to_time`` are not
            defined in this file (presumably they come from ``analysis_module``).
        """
        trial_x = {key: [] for key in Smooth_Pursuit.angles}
        trial_y = {key: [] for key in Smooth_Pursuit.angles}
        onsets = {key: [] for key in Smooth_Pursuit.angles}
        durations = {key: [] for key in Smooth_Pursuit.angles}
        pred_dir = "model_outputs" if model_outputs else "pred_allcalib"

        for _, row in self.task_df.iterrows():
            seq = [int(i) for i in row.angles.split(";")]
            rec_id = row.rec_session_id
            fname = f"Subjects/{self.subb}/{rec_id}/blockNr_{row.Block_Nr}_taskNr_{row.Task_Nr}_trialNr_{row.Trial_Nr}_pursuit_rec.webm"
            c = get_frame_count(fname)
            vid_len = float(row.RecStop - row.RecStart)
            # Frames per timestamp unit; presumably frames per millisecond,
            # since it is printed as fps*1000 below -- confirm.
            fps = c / vid_len

            start_times = self._relative_times(row.anim_time, row.RecStart)
            stop_times = self._relative_times(row.ResetTimes, row.RecStart)
            click_times = self._relative_times(row.ClickTimes, row.RecStart)

            # (first_frame, last_frame) of every target movement in this recording.
            windows = [
                (int(i), int(j))
                for i, j in zip(time_to_frame(start_times, fps), time_to_frame(stop_times, fps))
            ]
            click_frames = time_to_frame(click_times, fps)

            if show:
                try:
                    print("Recording length (ffmpeg): ", ffmpeg.probe(fname)["format"]["duration"])
                except Exception as err:
                    # Probing is best-effort diagnostics only; report and go on,
                    # but do not swallow KeyboardInterrupt/SystemExit.
                    print(f"Recording length (ffmpeg): unavailable ({err})")
                print("RecStop - RecStart : ", vid_len)
                print("Total Frame Count : ", c)
                print("used Frames Count : ", sum([pt[1] - pt[0] for pt in windows]))
                print("FPS : ", fps * 1000)
                print("start times : ", start_times)
                print("stop times : ", stop_times)
                print("click_times : ", click_times)
                print("diff : ", [int(i) - int(j) for i, j in zip(stop_times, start_times)])

            pred_df = pd.read_csv(os.path.join(
                model.value,
                f"{self.subb}/{pred_dir}/Block_{row.Block_Nr}/Smooth Pursuit{row.Trial_Id}.csv",
            ))

            for index, pt in enumerate(windows):
                # Movement duration: animation start (pt[0]) -> animation stop (pt[1]).
                sub = pred_df[pred_df.frame.between(pt[0], pt[1])]
                trial_x[seq[index]].append(sub[colx])
                trial_y[seq[index]].append(sub[coly])
                try:
                    # From user click to movement stop.
                    sub2 = pred_df[pred_df.frame.between(click_frames[index], pt[1])]
                    # Window length: a third of the samples, forced odd.
                    # Parameter reference: https://doi.org/10.1016/j.rinp.2018.08.033
                    win_len = len(sub2) // 3
                    win_len = win_len + 1 if win_len % 2 == 0 else win_len
                    # Sample-to-sample displacement of the smoothed gaze path.
                    dist = (
                        np.diff(apply_filter(sub2[colx], win_len=win_len), 1) ** 2
                        + np.diff(apply_filter(sub2[coly], win_len=win_len), 1) ** 2
                    ) ** (1 / 2)
                    algo = ruptures.Dynp(model="rbf", min_size=3, jump=1).fit(dist)
                    result = algo.predict(n_bkps=2)  # onset and offset
                    # np.diff shortened the signal by one sample; shift back.
                    result = [r + 1 for r in result]
                    onset_time = frame_to_time(result[:1], fps)[0]
                    # Re-reference the onset from the click to the animation start.
                    # The "<70 ms" rejection mentioned in the original comment
                    # is not applied in this method.
                    onset_time = click_times[index] + onset_time - start_times[index]
                    onsets[seq[index]].append(onset_time)
                    # Smooth pursuit duration: detected offset minus detected onset.
                    durations[seq[index]].append(frame_to_time([result[1] - result[0]], fps)[0])
                except Exception as e:
                    print(e)
                    # NOTE(review): the failure marker is a (nan, start_time)
                    # tuple while successes are plain numbers, and `durations`
                    # may get no entry here, so the two lists can end up with
                    # different lengths. Kept as-is because callers may rely on it.
                    onsets[seq[index]].append((np.nan, start_times[index]))
                    print(f"{self.subb} adding nan to pt {seq[index]} trial {index}")
                if show:
                    plt.scatter(sub[colx], sub[coly], color=self.palette[Smooth_Pursuit.angles.index(seq[index])])

            if show:
                plt.gca().invert_yaxis()
                plt.show()
                print("-" * 50)

        return trial_x, trial_y, onsets, durations
# Analysis functions
def apply_filter(data, win_len=15, moving_avg=False):
    """Smooth a 1-D signal.

    Args:
        data: the samples to smooth. The moving-average path calls
            ``data.rolling``, so it needs a pandas object there.
        win_len: window length in samples.
        moving_avg: use a rolling mean instead of the Savitzky-Golay filter.

    Returns:
        With ``moving_avg`` the rolling mean with its leading NaN rows dropped
        (so it is shorter than the input); otherwise the output of a
        first-order Savitzky-Golay filter with ``mode='nearest'``.
    """
    if not moving_avg:
        return savgol_filter(
            data,
            window_length=win_len,
            polyorder=1,
            deriv=0,
            mode='nearest',
            cval=0.0,
        )
    rolling_mean = data.rolling(win_len).mean()
    return rolling_mean.dropna()
def process_trials(trial_x, trial_y, angles, show=False, n_trials=10):
    """Estimate the movement direction of each trial from a straight-line fit.

    Each trial's x and y samples are smoothed with ``apply_filter`` (window =
    half the trial length, forced odd), a linear regression of y on x is
    fitted, and the direction between the first and last fitted point is
    converted to degrees with the y-axis inverted.

    Args:
        trial_x, trial_y: dicts mapping angle -> list of per-trial samples.
        angles: keys to create in the result dict.
        show: plot each fit and print the estimated angle.
        n_trials: number of trial slots inspected per angle (default 10, the
            value that used to be hard-coded).

    Returns:
        Dict mapping angle -> list of estimated directions in degrees. Trials
        that are missing or cannot be smoothed are reported and skipped.
    """
    avg = {k: [] for k in angles}
    for angle in trial_x.keys():
        for trial in range(n_trials):
            try:
                win_len = len(trial_x[angle][trial]) // 2
                win_len = win_len + 1 if win_len % 2 == 0 else win_len
                sm_x = apply_filter(trial_x[angle][trial], win_len=win_len)
                sm_y = apply_filter(trial_y[angle][trial], win_len=win_len)
            except Exception as e:
                # Missing trial slot or a window the filter rejects: report, skip.
                print(e)
                continue

            model = LinearRegression()
            X = np.array(sm_x).reshape(-1, 1)
            Y = sm_y
            model.fit(X, Y)
            pred = model.predict(X)

            # Direction from the first to the last fitted point. X is a column
            # vector, so `ang` is a one-element array.
            ang = np.rad2deg(np.arctan2(pred[-1] - pred[0], X[-1] - X[0]))
            ang = -ang if ang < 0 else 360 - ang  # invert y-axis

            if show:
                plt.scatter(X, Y)
                plt.scatter(trial_x[angle][trial], trial_y[angle][trial], color="black")
                plt.plot(X, pred, color="orange")
                plt.gca().invert_yaxis()
                plt.show()
                print("ang ", ang)
            avg[angle].append(ang[0])
    return avg
def mean_angle_preds(trial_x, trial_y, angles, show=False):
    """Summarise the per-trial direction estimates of every reference angle.

    Args:
        trial_x, trial_y: dicts mapping angle -> list of per-trial samples,
            passed straight to ``process_trials``.
        angles: reference angles to summarise.
        show: print the intermediate values for every angle.

    Returns:
        ``(cmean, diff, cstd)``, three dicts keyed by angle: the circular mean
        of the winsorized estimates, its absolute difference from the reference
        angle, and the circular standard deviation.

    Note:
        ``diff`` is a plain absolute difference and is NOT wrapped to
        [0, 180]; only the value printed with ``show`` is wrapped.
        ``circ_winsorize`` is not defined in this file (presumably it comes
        from ``analysis_module``).
    """
    # The call must stay on one logical line: split in two, `reg_angles` would
    # be bound to the function object itself and indexing it would fail.
    reg_angles = process_trials(trial_x, trial_y, angles)
    cmean, diff, cstd = {}, {}, {}
    for angle in angles:
        win_angles = circ_winsorize(reg_angles[angle], angle)  # winsorize around ref angle
        cmean[angle] = cim(win_angles, high=360, low=0)
        diff[angle] = abs(cmean[angle] - angle)  # only magnitude
        cstd[angle] = cis(win_angles, high=360, low=0)
        if show:
            print("*" * 50)
            print("angle: ", angle)
            print("orig: ", np.int64(reg_angles[angle]))
            print("wins: ", np.int64(win_angles))
            print("diff: ", diff[angle] if diff[angle] < 180 else 360 - diff[angle])
            print("mean: ", cmean[angle])
            print("orig mean: ", cim(reg_angles[angle], high=360, low=0))
            print("stdv: ", cstd[angle])
    return cmean, diff, cstd
#circular means
def get_win_sub(SP_cmeans, std_error=False):
    """Aggregate per-subject circular means for every reference angle.

    Args:
        SP_cmeans: table with one column per reference angle; the column label
            is used numerically as the angle (assumes one row per subject --
            TODO confirm against the caller).
        std_error: divide the circular standard deviation by the square root
            of the number of rows, giving a standard error instead.

    Returns:
        ``(win_sub_cmean, win_sub_cse, win_sub_025, win_sub_975)``, four dicts
        keyed by column: the circular mean of the winsorized values, their
        circular std (or standard error), and the distances from that mean
        down to the 2.5 % quantile and up to the 97.5 % quantile.

    Note:
        ``circ_winsorize`` is not defined in this file (presumably it comes
        from ``analysis_module``).
    """
    win_sub_cmean = {}
    win_sub_cse = {}
    win_sub_025 = {}
    win_sub_975 = {}
    for col in SP_cmeans.columns:
        win_means = circ_winsorize(SP_cmeans[col], col)
        win_sub_cmean[col] = cim(win_means, high=360, low=0)  # winsorized circular mean
        if col == 0 and win_sub_cmean[col] > 180:
            # Express a mean just below 360 as a small negative angle so it
            # sits next to the 0-degree reference.
            win_sub_cmean[col] = win_sub_cmean[col] - 360
        win_sub_cse[col] = cis(win_means, high=360, low=0)  # winsorized circular std
        if std_error:
            # np.sqrt rather than math.sqrt: `math` is never imported here.
            win_sub_cse[col] /= np.sqrt(SP_cmeans.shape[0])
        # Unwrap values more than 180 above the reference before taking
        # ordinary (non-circular) quantiles.
        win_means = np.array([x - 360 if x - col > 180 else x for x in win_means])
        win_sub_025[col] = win_sub_cmean[col] - np.quantile(win_means, 0.025)
        win_sub_975[col] = np.quantile(win_means, 0.975) - win_sub_cmean[col]
    return win_sub_cmean, win_sub_cse, win_sub_025, win_sub_975
#Plotting Functions
def sp_plot(ax, angles, win_sub_cmean, win_sub_025, win_sub_975, color_line="black", color_err="teal"):
    """Draw mean predicted angle against target angle on ``ax``.

    A dashed identity line from 0 to 330 is drawn first, followed by one
    marker per angle with asymmetric error bars built from ``win_sub_025``
    (lower) and ``win_sub_975`` (upper). Both axes get a tick at every angle.

    Args:
        ax: axes-like object to draw on.
        angles: x positions and tick locations.
        win_sub_cmean, win_sub_025, win_sub_975: dicts whose values are
            plotted in insertion order.
        color_line: colour of the identity line.
        color_err: colour of the markers and error bars.

    Returns:
        The same ``ax`` that was passed in.
    """
    identity = [0, 330]
    ax.plot(identity, list(identity), linestyle="--", lw=4, color=color_line)

    # Confidence interval for the error bars: [lower distances, upper distances].
    lower = list(win_sub_025.values())
    upper = list(win_sub_975.values())
    ax.errorbar(
        angles,
        win_sub_cmean.values(),
        yerr=[lower, upper],
        ms=10,
        linestyle="None",
        elinewidth=4,
        marker="o",
        color=color_err,
    )

    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks(angles)
    ax.tick_params(axis='both', labelsize=15)
    return ax
def sp_plot_single_trial(subb, block, trial, angle, colx='pred_x', coly='pred_y'):
    """Plot one example movement with the detected pursuit onset and offset.

    The example row is read from ``csv_backup/example_trials/SP_subject_data.csv``.
    For each of the three prediction sets (``pred_path.MPII``, ``pred_path.ETH``,
    ``pred_path.FAZE``) the ``colx`` trace from user click to movement stop is
    plotted, with a dashed vertical line at the detected onset and a solid one
    at the detected offset.

    Args:
        subb: unused; kept so existing callers keep working.
        block, trial: ``Block_Nr`` / ``Trial_Nr`` of the row to plot.
        angle: movement angle to pick from the row's ``angles`` sequence.
        colx, coly: names of the prediction columns to read.

    Note:
        ``pred_path`` and ``time_to_frame`` are not defined in this file
        (presumably they come from ``analysis_module``).
    """
    df = pd.read_csv("csv_backup/example_trials/SP_subject_data.csv")
    task_df = df[df["Task_Name"] == "2. Smooth Pursuit"]
    row = task_df[(task_df["Trial_Nr"] == trial) & (task_df["Block_Nr"] == block)].iloc[0]
    palette = sns.color_palette('colorblind', 3)  # one colour per model
    seq = [int(i) for i in row.angles.split(";")]
    index = seq.index(angle)

    # Constant stand-in for frame_count / recording_length, used only for this
    # plot; presumably frames per millisecond (~30 fps) -- confirm.
    fps = 0.030

    def _offsets(raw):
        """Parse a ``...---values="t1;t2;..."`` field into offsets from RecStart."""
        tokens = raw.split("---values=")[1].strip("\"").split(";")
        return [int(t) - int(row.RecStart) for t in tokens if len(t) > 1]

    start_times = _offsets(row.anim_time)
    stop_times = _offsets(row.ResetTimes)
    click_times = _offsets(row.ClickTimes)
    # (first_frame, last_frame) of every movement in the recording.
    windows = [
        (int(i), int(j))
        for i, j in zip(time_to_frame(start_times, fps), time_to_frame(stop_times, fps))
    ]
    click_frames = time_to_frame(click_times, fps)

    for i, model in enumerate([pred_path.MPII, pred_path.ETH, pred_path.FAZE]):
        print(model)
        pred_df = pd.read_csv(os.path.join(model.value, "smooth_pursuit/Smooth Pursuit.csv"))
        # From user click to movement stop.
        sub2 = pred_df[pred_df.frame.between(click_frames[index], windows[index][1])]
        # Window length: a third of the samples, forced odd.
        # Parameter reference: https://doi.org/10.1016/j.rinp.2018.08.033
        win_len = len(sub2) // 3
        win_len = win_len + 1 if win_len % 2 == 0 else win_len
        # Sample-to-sample displacement, computed after smoothing.
        dist = (
            np.diff(apply_filter(sub2[colx], win_len=win_len), 1) ** 2
            + np.diff(apply_filter(sub2[coly], win_len=win_len), 1) ** 2
        ) ** (1 / 2)
        # Change point detection with dynamic programming,
        # two breakpoints: onset and offset of the smooth pursuit.
        algo = ruptures.Dynp(model="rbf", min_size=3, jump=1).fit(dist)
        result = algo.predict(n_bkps=2)
        result = [r + 1 for r in result]  # np.diff dropped one sample; shift back
        print(result[:-1], f"no. of samples : {result[-1]}")
        sub2[colx].reset_index(drop=True).plot(figsize=(20, 7), marker="o", color=palette[i])
        plt.axvline(result[0], linestyle="--", color=palette[i])
        # Small per-model horizontal offset so the three offset lines do not overlap.
        plt.axvline(result[1] + (0.08 * (i - 1)), linestyle="-", color=palette[i], alpha=1)
    return