Dataset
UniversalDataset
- class epilearn.data.dataset.UniversalDataset(x=None, states=None, y=None, graph=None, dynamic_graph=None, edge_index=None, edge_weight=None, edge_attr=None)
The UniversalDataset class handles several kinds of graph data. It supports operations on datasets that include node features, node states, static and dynamic graphs, and edge attributes.
- Parameters:
x (torch.Tensor, optional) – Node features tensor of shape (num_samples, num_nodes, num_features). Represents the node features over multiple timesteps.
states (torch.Tensor, optional) – Tensor representing various states of nodes, similar in structure to node features.
y (torch.Tensor, optional) – Tensor representing target labels or values for each node, structured similarly to node features.
graph (torch.Tensor or scipy.sparse matrix, optional) – Static graph structure as an adjacency matrix.
dynamic_graph (torch.Tensor, optional) – Dynamic graph information over time, providing evolving adjacency matrices.
edge_index (torch.LongTensor, optional) – Tensor containing edge indices, typically of shape (2, num_edges), for defining which nodes are connected.
edge_weight (torch.Tensor, optional) – Edge weights corresponding to the edge_index, providing the strength or capacity of connections.
edge_attr (torch.Tensor, optional) – Attributes or features for each edge, aligned with the structure defined in edge_index.
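The relationship between the graph adjacency matrix and the edge_index / edge_weight pair can be illustrated with a minimal sketch. This is not EpiLearn code: the helper adjacency_to_edges is hypothetical, and plain Python lists stand in for the tensors, so the example runs without any dependencies. In practice the resulting lists would presumably be converted to a torch.LongTensor (edge_index) and a torch.Tensor (edge_weight) before being passed to the constructor.

```python
def adjacency_to_edges(adj):
    """Convert a dense adjacency matrix (list of rows) into COO-style edge lists.

    Hypothetical helper for illustration only; not part of EpiLearn.
    """
    sources, targets, weights = [], [], []
    for i, row in enumerate(adj):
        for j, w in enumerate(row):
            if w != 0:  # a nonzero entry means an edge from node i to node j
                sources.append(i)
                targets.append(j)
                weights.append(w)
    edge_index = [sources, targets]  # layout (2, num_edges)
    return edge_index, weights


# Toy undirected graph with 3 nodes: 0 -- 1 (weight 0.5), 1 -- 2 (weight 2.0).
adj = [
    [0.0, 0.5, 0.0],
    [0.5, 0.0, 2.0],
    [0.0, 2.0, 0.0],
]
edge_index, edge_weight = adjacency_to_edges(adj)
print(edge_index)   # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_weight)  # [0.5, 0.5, 2.0, 2.0]
```

Each column of edge_index is one directed edge (source, target), and edge_weight holds one value per column, which is the sense in which edge_weight and edge_attr are "aligned" with edge_index.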
- download()
Download selected files of the dataset.
- generate_dataset(X=None, Y=None, states=None, dynamic_adj=None, lookback_window_size=1, horizon_size=1, permute=False, feat_idx=None, target_idx=None)
Takes node features for the graph and divides them into multiple samples along the time axis by sliding a window of size (num_timesteps_input + num_timesteps_output) across it in steps of 1.
- Parameters:
X – Node features of shape (num_vertices, num_features, num_timesteps).
- Returns:
Node features divided into multiple samples, of shape (num_samples, num_vertices, num_features, num_timesteps_input), and node targets for the samples, of shape (num_samples, num_vertices, num_features, num_timesteps_output).
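The sliding-window split described above can be sketched in a few lines. This is an illustration of the general technique, not EpiLearn's implementation: the function sliding_window_indices is hypothetical, and it assumes that num_timesteps_input and num_timesteps_output correspond to the lookback_window_size and horizon_size arguments, which the source does not state explicitly. To stay dependency-free, it works on a single node and feature stored as a plain list.

```python
def sliding_window_indices(num_timesteps, lookback, horizon):
    """Return (input_range, target_range) pairs for a stride-1 sliding window.

    Hypothetical helper for illustration only; not part of EpiLearn.
    """
    window = lookback + horizon
    num_samples = num_timesteps - window + 1
    if num_samples <= 0:
        return []  # the series is too short for even one full window
    return [
        (range(start, start + lookback),
         range(start + lookback, start + window))
        for start in range(num_samples)
    ]


# Toy series: 6 timesteps for one node and one feature.
series = [10, 11, 12, 13, 14, 15]
pairs = sliding_window_indices(len(series), lookback=3, horizon=2)

# Slice the series into (input, target) samples using the index ranges.
samples = [
    ([series[t] for t in inp], [series[t] for t in tgt])
    for inp, tgt in pairs
]
print(len(samples))  # 2
print(samples[0])    # ([10, 11, 12], [13, 14])
print(samples[1])    # ([11, 12, 13], [14, 15])
```

With a window of lookback + horizon timesteps and a step of 1, a series of T timesteps yields T - (lookback + horizon) + 1 samples, which is the num_samples dimension in the returned shapes.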
- save()
Save current dataset.