-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
66 lines (55 loc) · 2.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
import pandas as pd
import pprint, logging
import seaborn as sns
import matplotlib.pyplot as plt
#plotting modules
from plotting import violin_plot, corr_matrix, KDE_scatter_plot
from logistic_regression import log_reg_probability
# Load the raw model output. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default locale encoding.
with open('output.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Flatten the nested structure into one row per (topic, argument type).
# NOTE(review): assumes each topic maps to
#   {'positive': {'data': {...}}, 'negative': {'data': {...}}}
# — the shape implied by the indexing below; confirm against the producer.
flattened_data = []
for title, content in data.items():
    for key in ('positive', 'negative'):
        # Copy the record so tagging it does not mutate the loaded `data`.
        entry = dict(content[key]['data'])
        entry['topic'] = title
        entry['argument_type'] = key
        flattened_data.append(entry)

# One DataFrame row per flattened entry; 'topic' and 'argument_type' become columns.
df = pd.DataFrame(flattened_data)
def calc_mean(df, *, verbose=False):
    """Clean metric columns in place and return their mean/median summaries.

    Every column other than ``argument_type``, ``topic`` and ``justification``
    is processed:

    * ``object`` columns are converted with ``pd.to_numeric(errors='coerce')``,
      so text that is not a number (e.g. a stray explanation where a score
      was expected) becomes NaN;
    * NaN values in numeric columns are replaced with that column's mean;
    * the overall mean and median, and the mean and median per
      ``argument_type``, are collected.

    Columns that are still non-numeric after conversion are skipped.

    Args:
        df: DataFrame to clean. It is modified in place.
        verbose: If True, print each column's summary.

    Returns:
        dict mapping each processed column name to a dict with keys
        ``'mean'``, ``'median'``, ``'mean_by_type'`` and ``'median_by_type'``.
        The two ``*_by_type`` values are ``pandas.Series`` indexed by
        ``argument_type``, or ``None`` when *df* has no ``argument_type``
        column.
    """
    skip = ('argument_type', 'topic', 'justification')
    has_type = 'argument_type' in df.columns
    stats = {}
    for label in df.columns:
        if label in skip:
            continue
        if df[label].dtype == 'object':
            try:
                # Some of gpt's output is malformed: strings that are not
                # numbers are coerced to NaN rather than raising.
                df[label] = pd.to_numeric(df[label], errors='coerce')
            except Exception as e:
                # Best effort: values to_numeric cannot interpret at all are
                # logged and the column is left as-is.
                logging.error(f"Error in converting {label} to numbers: {e}")
        if not pd.api.types.is_numeric_dtype(df[label]):
            # Mean imputation and statistics only make sense for numbers.
            continue
        # Replace remaining NaN with the column mean (no-op if all NaN).
        df[label] = df[label].fillna(df[label].mean())

        mean_value = df[label].mean()
        median_value = df[label].median()
        if has_type:
            grouped = df.groupby('argument_type')[label]
            mean_value_by_type = grouped.mean()
            median_value_by_type = grouped.median()
        else:
            mean_value_by_type = median_value_by_type = None
        stats[label] = {
            'mean': mean_value,
            'median': median_value,
            'mean_by_type': mean_value_by_type,
            'median_by_type': median_value_by_type,
        }
        if verbose:
            print(f"{label}:")
            print(f"mean: {mean_value}")
            print(f" {mean_value_by_type}")
            print(f"median: {median_value}")
            print(f" {median_value_by_type}")
            print("------------")
    return stats
# Coerce and impute the metric columns of df in place before anything reads them.
calc_mean(df)

# Exploratory plots over the cleaned frame, in the original order.
for _render in (violin_plot, corr_matrix):
    _render(df)

# Output of log_reg_probability is kept at module level and fed to the KDE plot.
y_prob = log_reg_probability(df)
KDE_scatter_plot(df, y_prob)