-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot.py
106 lines (93 loc) · 3.69 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def load_data(file_path):
    """
    Load a whitespace-delimited data file that has no header row.

    Columns are separated by one or more whitespace characters. The first
    column is assumed to be the class label and the rest are features.

    Returns a pandas DataFrame whose columns are labelled 0, 1, 2, ...
    """
    # sep=r"\s+" is the supported spelling of the deprecated
    # delim_whitespace=True option and parses the file the same way.
    data = pd.read_csv(file_path, header=None, sep=r"\s+")
    return data
def plot_data(data, feature_indices):
    """
    Plot one, two, or three feature columns, colored by class.

    data: DataFrame whose first column is the class label and whose
        remaining columns are features. Colors and markers are only defined
        for the class labels 1 and 2.
    feature_indices: 1-based feature numbers to plot. Because column 0 holds
        the class label, feature k is column k of ``data``. The number of
        indices selects a 1D, 2D, or 3D plot.

    Raises ValueError if the number of indices is not 1, 2, or 3, and
    IndexError if any index is outside 1..number_of_features.
    """
    class_column = 0  # Class column is always the first column
    features = data.columns[1:]  # Exclude the class column
    print("Class column:", class_column)
    print("Feature columns:", features)
    print("Feature indices to plot:", feature_indices)

    # Validate before touching matplotlib so bad input never opens a figure.
    if len(feature_indices) not in (1, 2, 3):
        raise ValueError(
            "Expected 1, 2, or 3 feature indices, "
            f"got {len(feature_indices)}."
        )
    for index in feature_indices:
        # Index 0 would plot the class column itself, and negative indices
        # would wrap around and mislabel the axes, so reject both.
        if not 1 <= index <= len(features):
            raise IndexError(
                f"Feature index {index} is out of range; "
                f"valid indices are 1 to {len(features)}."
            )

    # Read the class column by position, matching how features are read.
    labels = data.iloc[:, class_column]
    columns = [data.iloc[:, index] for index in feature_indices]
    axis_names = [features[index - 1] for index in feature_indices]
    colors = {1: 'blue', 2: 'red'}

    if len(columns) == 3:
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(*columns, c=labels.map(colors))
        ax.set_xlabel(axis_names[0])
        ax.set_ylabel(axis_names[1])
        ax.set_zlabel(axis_names[2])
        plt.title('3D Feature Plot')
    else:
        two_d = len(columns) == 2
        plt.figure(figsize=(10, 6))
        # A 1D plot places every point on the horizontal line y = 0.
        sns.scatterplot(x=columns[0],
                        y=columns[1] if two_d else 0,
                        hue=labels,
                        palette=colors,
                        style=labels,
                        markers={1: 'o', 2: 's'})
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1] if two_d else "Class")
        plt.title('2D Feature Plot' if two_d else '1D Feature Plot')
    plt.show()
def main():
    """
    Interactively choose a dataset and features, then plot them.

    Prompts on standard input for a dataset (1-4), the number of features to
    plot (1-3), and one 1-based feature index per feature. Any invalid
    answer, or a data file that cannot be read, prints a message and returns
    without plotting.
    """
    data_paths = {
        "1": "data/normalizedDataSmall.txt",
        "2": "data/normalizedDataLarge.txt",
        "3": "normalizedDataSmall.txt",
        # NOTE(review): lowercase "l" in "large" differs from entry "2";
        # confirm it matches the real file name on case-sensitive systems.
        "4": "normalizedDatalarge.txt"
    }

    # Step 1: Ask which dataset to load
    print("Choose a dataset to plot:")
    print("1. Plot small general dataset")
    print("2. Plot large general dataset")
    print("3. Plot small personal dataset")
    print("4. Plot large personal dataset")
    choice = input("Enter your choice (1-4): ").strip()
    if choice not in data_paths:
        print("Invalid choice.")
        return

    file_path = data_paths[choice]
    try:
        data = load_data(file_path)
    except OSError as err:
        # Covers a missing or unreadable file; report it instead of a traceback.
        print(f"Could not read {file_path}: {err}")
        return

    # Print data for debugging
    print("Data loaded:")
    print(data.head())

    # The first column is the class label, so it is not a plottable feature.
    available_features = data.shape[1] - 1
    if available_features < 1:
        print("The dataset has no feature columns to plot.")
        return

    # Step 2: Ask how many features to plot
    plot_type = input(
        "How many features would you like to plot? (Enter 1, 2, or 3): "
    ).strip()
    if plot_type not in ['1', '2', '3']:
        print("Invalid input. Please enter 1, 2, or 3.")
        return

    # Step 3: Ask which features to plot
    num_features = int(plot_type)
    feature_indices = []
    for i in range(num_features):
        raw_index = input(
            f"Enter the index of feature {i+1} to plot (1-based index): "
        )
        try:
            feature_index = int(raw_index)
        except ValueError:
            print(f"Invalid feature index: {raw_index!r} is not an integer.")
            return
        if not 1 <= feature_index <= available_features:
            print(
                f"Invalid feature index: {feature_index} is not between "
                f"1 and {available_features}."
            )
            return
        feature_indices.append(feature_index)

    # Step 4: Plot the data
    plot_data(data, feature_indices)
# Start the interactive plotting workflow only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == "__main__":
    main()