diff --git a/CHANGELOG.md b/CHANGELOG.md
index addce364df6d..03d75f244731 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added LLM-generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
+- Updated Dockerfile to use the latest NVIDIA base image ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
 - Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
 - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
 - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
diff --git a/examples/llm/glem.py b/examples/llm/glem.py
index 8a28d7359de6..32aec3df5ffb 100644
--- a/examples/llm/glem.py
+++ b/examples/llm/glem.py
@@ -40,6 +40,7 @@ def get_n_params(model):
 def main(args):
     gpu = args.gpu
     dataset_name = args.dataset
+    text_type = args.text_type if args.dataset == 'arxiv' else 'raw_text'
     root = osp.join('data', 'ogb')
     hf_model = args.hf_model
     pl_ratio = args.pl_ratio
@@ -83,7 +84,7 @@ def main(args):
 
     tag_dataset = TAGDataset(root, dataset, hf_model,
                              token_on_disk=token_on_disk)
-    text_dataset = tag_dataset.to_text_dataset()
+    text_dataset = tag_dataset.to_text_dataset(text_type)
     print(tag_dataset.num_classes, tag_dataset.raw_file_names)
 
     num_classes = tag_dataset.num_classes
@@ -393,8 +394,12 @@ def load_model(em_phase):
                         help='number of runs')
     parser.add_argument('--num_em_iters', type=int, default=1,
                         help='number of iterations')
-    parser.add_argument("--dataset", type=str, default='products',
+    parser.add_argument("--dataset", type=str, default='arxiv',
                         help='arxiv or products')
+    parser.add_argument(
+        "--text_type", type=str, default='llm_explanation',
+        help="type of text: 'raw_text', 'llm_explanation' or 'all' "
+        "for arxiv; only 'raw_text' for products")
     parser.add_argument("--pl_ratio", type=float, default=0.5,
                         help="pseudo labels ratio")
     parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',
diff --git a/test/datasets/test_tag_dataset.py b/test/datasets/test_tag_dataset.py
new file mode 100644
index 000000000000..58b0d3ef66f4
--- /dev/null
+++ b/test/datasets/test_tag_dataset.py
@@ -0,0 +1,22 @@
+from torch_geometric.datasets import TAGDataset
+from torch_geometric.testing import onlyFullTest, withPackage
+
+
+@onlyFullTest
+@withPackage('ogb')
+def test_tag_dataset() -> None:
+    from ogb.nodeproppred import PygNodePropPredDataset
+
+    root = './data/ogb'
+    hf_model = 'prajjwal1/bert-tiny'
+    token_on_disk = True
+
+    dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
+    tag_dataset = TAGDataset(root, dataset, hf_model,
+                             token_on_disk=token_on_disk)
+
+    assert 169343 == tag_dataset[0].num_nodes \
+        == len(tag_dataset.text) \
+        == len(tag_dataset.llm_explanation) \
+        == len(tag_dataset.llm_prediction)
+    assert 1166243 == tag_dataset[0].num_edges
diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py
index f25992ced989..cbc9fb70c30b 100644
--- a/torch_geometric/datasets/tag_dataset.py
+++ b/torch_geometric/datasets/tag_dataset.py
@@ -1,3 +1,4 @@
+import csv
 import os
 import os.path as osp
 from collections.abc import Sequence
@@ -10,6 +11,7 @@
 
 from torch_geometric.data import InMemoryDataset, download_google_url
 from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
 
 try:
     from pandas import DataFrame, read_csv
@@ -22,14 +24,16 @@
 class TAGDataset(InMemoryDataset):
     r"""The Text Attributed Graph datasets from the
-    `"Learning on Large-scale Text-attributed Graphs via Variational Inference
-    " <https://arxiv.org/abs/2210.14709>`_ paper.
+    `"Learning on Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
+    LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
+    Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
     This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
     into Text Attributed Graph that each node in graph is associate with a
-    raw text, that dataset can be adapt to DataLoader (for LM training) and
-    NeighborLoader(for GNN training). In addition, this class can be use as a
-    wrapper class by convert a InMemoryDataset with Tokenizer and text into
-    Text Attributed Graph.
+    raw text, an LLM prediction and an LLM explanation. Such a dataset can be
+    adapted to a DataLoader (for LM training) and a NeighborLoader (for GNN
+    training). In addition, this class can be used as a wrapper that converts
+    an InMemoryDataset with a tokenizer and text into a Text Attributed Graph.
 
     Args:
         root (str): Root directory where the dataset should be saved.
@@ -40,6 +44,12 @@ class TAGDataset(InMemoryDataset):
             on huggingface.co.
         text (List[str]): list of raw text associate with node, the order
             of list should be align with node list
+        llm_explanation (Optional[List[str]]): List of LLM explanations, one
+            per node; the order should align with the node list.
+        llm_prediction (Optional[Tensor]): LLM predictions, one per node; the
+            order should align with the node list.
+        llm_prediction_topk (int): Number of top LLM predictions used as
+            features for GNN training. (default: 5)
         split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary,
             for saving split index, it is required that if your dataset
             doesn't have get_split_idx function
@@ -51,22 +61,40 @@ class TAGDataset(InMemoryDataset):
             or not, default: False
         force_reload (bool): default: False
 
     .. note::
-        See `example/llm_plus_gnn/glem.py` for example usage
+        See `example/llm/glem.py` for example usage
 
     """
     raw_text_id = {
         'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
         'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
     }
 
-    def __init__(self, root: str, dataset: InMemoryDataset,
-                 tokenizer_name: str, text: Optional[List[str]] = None,
-                 split_idx: Optional[Dict[str, Tensor]] = None,
-                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
-                 text_on_disk: bool = False,
-                 force_reload: bool = False) -> None:
+    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
+
+    llm_explanation_id = {
+        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        dataset: InMemoryDataset,
+        tokenizer_name: str,
+        text: Optional[List[str]] = None,
+        llm_explanation: Optional[List[str]] = None,
+        llm_prediction: Optional[Tensor] = None,
+        llm_prediction_topk: int = 5,
+        split_idx: Optional[Dict[str, Tensor]] = None,
+        tokenize_batch_size: int = 256,
+        token_on_disk: bool = False,
+        text_on_disk: bool = False,
+        force_reload: bool = False,
+    ) -> None:
         # list the vars you want to pass in before run download & process
         self.name = dataset.name
         self.text = text
+        self.llm_explanation = llm_explanation
+        self.llm_prediction = llm_prediction
+        self.llm_prediction_topk = llm_prediction_topk
         self.tokenizer_name = tokenizer_name
         from transformers import AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
@@ -93,8 +121,11 @@ def __init__(self, root: str, dataset: InMemoryDataset,
                              "is_gold mask, please pass splited index "
                              "in format of dictionaty with 'train', 'valid' "
                              "'test' index tensor to 'split_idx'")
-        if text is not None and text_on_disk:
-            self.save_node_text(text)
+        if text_on_disk:
+            if text is not None:
+                self.save_node_text(text)
+            if llm_explanation is not None:
+                self.save_node_explanation(llm_explanation)
         self.text_on_disk = text_on_disk
         # init will call download and process
         super().__init__(self.root, transform=None, pre_transform=None,
@@ -116,9 +147,21 @@ def __init__(self, root: str, dataset: InMemoryDataset,
         if self.text is not None and len(self.text) != self._data.num_nodes:
             raise ValueError("The number of text sequence in 'text' should be "
                              "equal to number of nodes!")
+        if self.llm_explanation is not None and len(
+                self.llm_explanation) != self._data.num_nodes:
+            raise ValueError("The number of LLM explanations should be "
+                             "equal to the number of nodes!")
+        if self.llm_prediction is not None and len(
+                self.llm_prediction) != self._data.num_nodes:
+            raise ValueError("The number of LLM predictions should be "
+                             "equal to the number of nodes!")
         self.token_on_disk = token_on_disk
         self.tokenize_batch_size = tokenize_batch_size
         self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self._llm_explanation_token = self.tokenize_graph(
+            self.tokenize_batch_size, text_type='llm_explanation')
+        self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                              text_type='all')
         self.__num_classes__ = dataset.num_classes
 
     @property
@@ -128,7 +171,7 @@ def num_classes(self) -> int:
     @property
     def raw_file_names(self) -> List[str]:
         file_names = []
-        for root, _, files in os.walk(osp.join(self.root, 'raw')):
+        for _, _, files in os.walk(osp.join(self.root, 'raw')):
             for file in files:
                 file_names.append(file)
         return file_names
@@ -146,6 +189,19 @@ def token(self) -> Dict[str, Tensor]:
             self._token = self.tokenize_graph()
         return self._token
 
+    @property
+    def llm_explanation_token(self) -> Dict[str, Tensor]:
+        if self._llm_explanation_token is None:  # lazy load
+            self._llm_explanation_token = self.tokenize_graph(
+                text_type='llm_explanation')
+        return self._llm_explanation_token
+
+    @property
+    def all_token(self) -> Dict[str, Tensor]:
+        if self._all_token is None:  # lazy load
+            self._all_token = self.tokenize_graph(text_type='all')
+        return self._all_token
+
     # load is_gold after init
     @property
     def is_gold(self) -> Tensor:
@@ -194,10 +250,17 @@ def download(self) -> None:
                                             folder=f'{self.root}/raw',
                                             filename='node-text.csv.gz',
                                             log=True)
-            text_df = read_csv(raw_text_path)
-            self.text = list(text_df['text'])
+            self.text = list(read_csv(raw_text_path)['text'])
+        print('Downloading LLM explanations...')
+        llm_explanation_path = download_google_url(
+            id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
+            filename='node-gpt-response.csv.gz', log=True)
+        self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+        print('Downloading LLM predictions...')
+        fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
 
     def process(self) -> None:
+        # Process title and abstract
         if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
             text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
             self.text = list(text_df['text'])
@@ -212,6 +275,43 @@ def process(self) -> None:
                 "The raw text of each node is not specified"
                 "Please pass in 'text' when convert your dataset "
                 "to Text Attribute Graph Dataset")
+        # Process LLM explanations and predictions
+        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
+        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
+        if osp.exists(llm_explanation_path) and osp.exists(
+                llm_prediction_path):
+            # Load LLM explanations
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            # Load LLM predictions
+            preds = []
+            with open(llm_prediction_path) as file:
+                reader = csv.reader(file)
+                for row in reader:
+                    inner_list = []
+                    for value in row:
+                        inner_list.append(int(value))
+                    preds.append(inner_list)
+
+            pl = torch.zeros(len(preds), self.llm_prediction_topk,
+                             dtype=torch.long)
+            for i, pred in enumerate(preds):
+                pl[i][:len(pred)] = torch.tensor(
+                    pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+            self.llm_prediction = pl
+        elif self.name in self.llm_explanation_id:
+            self.download()
+        else:
+            print("The dataset is not ogbn-arxiv, please pass in your LLM "
+                  "explanation list to 'llm_explanation' and your LLM "
+                  "prediction list to 'llm_prediction'")
+        if self.llm_explanation is None or self.llm_prediction is None:
+            raise ValueError(
+                "By default, TAGDataset only provides LLM explanations and "
+                "predictions for ogbn-arxiv. The LLM explanation and "
+                "prediction of each node are not specified. Please pass in "
+                "'llm_explanation' and 'llm_prediction' when converting your "
+                "dataset to a Text Attributed Graph dataset.")
+ "Please pass in 'llm_explanation' and 'llm_prediction' when" + "convert your dataset to Text Attribute Graph Dataset") def save_node_text(self, text: List[str]) -> None: node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') @@ -224,22 +324,50 @@ def save_node_text(self, text: List[str]) -> None: text_df.to_csv(osp.join(node_text_path), compression='gzip', index=False) - def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + def save_node_explanation(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-gpt-response.csv.gz') + if osp.exists(node_text_path): + print(f'The llm explanation is existed at {node_text_path}') + else: + print(f'Saving llm explanation file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + + def tokenize_graph(self, batch_size: int = 256, + text_type: str = 'raw_text') -> Dict[str, Tensor]: r"""Tokenizing the text associate with each node, running in cpu. Args: batch_size (Optional[int]): batch size of list of text for generating emebdding + text_type (Optional[str]): type of text Returns: Dict[str, torch.Tensor]: tokenized graph """ + assert text_type in ['raw_text', 'llm_explanation', 'all'] + if text_type == 'raw_text': + _text = self.text + elif text_type == 'llm_explanation': + _text = self.llm_explanation + elif text_type == 'all': + if self.text is None or self.llm_explanation is None: + raise ValueError("The TAGDataset need text and llm explanation" + "for tokenizing all text") + _text = [ + f'{raw_txt} Explanation: {exp_txt}' + for raw_txt, exp_txt in zip(self.text, self.llm_explanation) + ] + data_len = 0 - if self.text is not None: - data_len = len(self.text) + if _text is not None: + data_len = len(_text) else: raise ValueError("The TAGDataset need text for tokenization") token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] - path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + path = os.path.join(self.processed_dir, 'token', text_type, + self.tokenizer_name) # Check if the .pt files already exist token_files_exist = any( os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) @@ -256,12 +384,12 @@ def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: all_encoded_token = {k: [] for k in token_keys} pbar = tqdm(total=data_len) - pbar.set_description('Tokenizing Text Attributed Graph') + pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}') for i in range(0, data_len, batch_size): end_index = min(data_len, i + batch_size) - token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], - padding='max_length', truncation=True, - max_length=512, return_tensors="pt") + token = self.tokenizer(_text[i:end_index], padding='max_length', + truncation=True, max_length=512, + return_tensors="pt") for k in token.keys(): all_encoded_token[k].append(token[k]) pbar.update(end_index - i) @@ -289,10 +417,18 @@ class TextDataset(torch.utils.data.Dataset): Args: tag_dataset (TAGDataset): the parent dataset + text_type (str): type of text """ - def __init__(self, tag_dataset: 'TAGDataset') -> None: + def __init__(self, tag_dataset: 'TAGDataset', + text_type: str = 'raw_text') -> None: + assert text_type in ['raw_text', 'llm_explanation', 'all'] self.tag_dataset = tag_dataset - self.token = tag_dataset.token + if text_type == 'raw_text': + self.token = tag_dataset.token + elif text_type == 
+            elif text_type == 'llm_explanation':
+                self.token = tag_dataset.llm_explanation_token
+            elif text_type == 'all':
+                self.token = tag_dataset.all_token
             assert tag_dataset._data is not None
             self._data = tag_dataset._data
@@ -312,7 +448,8 @@ def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
 
         # for LM training
         def __getitem__(
-                self, node_id: IndexType
+                self,
+                node_id: IndexType,
         ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
             r"""This function will override the function in
             torch.utils.data.Dataset, and will be called when you
@@ -343,8 +480,8 @@ def get(self, idx: int) -> BaseData:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
-    def to_text_dataset(self) -> TextDataset:
+    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
         r"""Factory Build text dataset from Text Attributed Graph Dataset
         each data point is node's associated text token.
         """
-        return TAGDataset.TextDataset(self)
+        return TAGDataset.TextDataset(self, text_type)
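
A minimal usage sketch of the new `text_type` plumbing (not part of the diff itself), mirroring `test/datasets/test_tag_dataset.py` and `examples/llm/glem.py` above; the root path and the choice of `text_type='all'` are illustrative:

```python
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.datasets import TAGDataset

# Wrap ogbn-arxiv as a Text Attributed Graph. For this dataset, the raw
# text, LLM explanations, and LLM predictions are downloaded automatically.
root = './data/ogb'
dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
tag_dataset = TAGDataset(root, dataset, 'prajjwal1/bert-tiny',
                         token_on_disk=True)

# 'raw_text', 'llm_explanation', or 'all' selects which tokens the resulting
# TextDataset serves for LM training ('all' concatenates the raw text with
# the LLM explanation).
text_dataset = tag_dataset.to_text_dataset(text_type='all')
```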