Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llm generated explanations to TAGDataset #9918

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added LLM-generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
Expand Down
5 changes: 4 additions & 1 deletion examples/llm/glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_n_params(model):
def main(args):
gpu = args.gpu
dataset_name = args.dataset
text_type = args.text_type
root = osp.join('data', 'ogb')
hf_model = args.hf_model
pl_ratio = args.pl_ratio
Expand Down Expand Up @@ -83,7 +84,7 @@ def main(args):

tag_dataset = TAGDataset(root, dataset, hf_model,
token_on_disk=token_on_disk)
text_dataset = tag_dataset.to_text_dataset()
text_dataset = tag_dataset.to_text_dataset(text_type)
print(tag_dataset.num_classes, tag_dataset.raw_file_names)

num_classes = tag_dataset.num_classes
Expand Down Expand Up @@ -395,6 +396,8 @@ def load_model(em_phase):
help='number of iterations')
parser.add_argument("--dataset", type=str, default='products',
help='arxiv or products')
parser.add_argument("--text_type", type=str, default='raw_text',
help='raw_text, llm_explanation or all')
parser.add_argument("--pl_ratio", type=float, default=0.5,
help="pseudo labels ratio")
parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',
Expand Down
22 changes: 22 additions & 0 deletions test/datasets/test_tag_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch_geometric.datasets import TAGDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('ogb')
def test_tag_dataset() -> None:
    """Check node/edge counts of the ``ogbn-arxiv`` based ``TAGDataset``.

    The raw text, LLM explanation and LLM prediction containers must each
    hold exactly one entry per node of the wrapped graph.
    """
    from ogb.nodeproppred import PygNodePropPredDataset

    data_root = './data/ogb'
    ogb_dataset = PygNodePropPredDataset('ogbn-arxiv', root=data_root)
    tag_dataset = TAGDataset(
        data_root,
        ogb_dataset,
        'prajjwal1/bert-tiny',
        token_on_disk=True,
    )

    expected_num_nodes = 169343
    expected_num_edges = 1166243

    # Every per-node container has to line up with the graph itself.
    assert tag_dataset[0].num_nodes == expected_num_nodes
    assert len(tag_dataset.text) == expected_num_nodes
    assert len(tag_dataset.llm_explanation) == expected_num_nodes
    assert len(tag_dataset.llm_prediction) == expected_num_nodes

    assert tag_dataset[0].num_edges == expected_num_edges
199 changes: 168 additions & 31 deletions torch_geometric/datasets/tag_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import os
import os.path as osp
from collections.abc import Sequence
Expand All @@ -10,6 +11,7 @@

from torch_geometric.data import InMemoryDataset, download_google_url
from torch_geometric.data.data import BaseData
from torch_geometric.io import fs

try:
from pandas import DataFrame, read_csv
Expand All @@ -22,14 +24,16 @@

class TAGDataset(InMemoryDataset):
r"""The Text Attributed Graph datasets from the
`"Learning on Large-scale Text-attributed Graphs via Variational Inference
" <https://arxiv.org/abs/2210.14709>`_ paper.
`"Learning on Large-scale Text-attributed Graphs via Variational Inference"
<https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
into Text Attributed Graph that each node in graph is associate with a
raw text, that dataset can be adapt to DataLoader (for LM training) and
NeighborLoader(for GNN training). In addition, this class can be use as a
wrapper class by convert a InMemoryDataset with Tokenizer and text into
Text Attributed Graph.
raw text, LLM prediction and explanation, that dataset can be adapt to
DataLoader (for LM training) and NeighborLoader(for GNN training).
In addition, this class can be use as a wrapper class by convert a
InMemoryDataset with Tokenizer and text into Text Attributed Graph.

Args:
root (str): Root directory where the dataset should be saved.
Expand All @@ -40,6 +44,12 @@ class TAGDataset(InMemoryDataset):
on huggingface.co.
text (List[str]): list of raw text associate with node, the order of
list should be align with node list
llm_explanation (Optional[List[str]]): list of llm explanation
associate with node, which should be align with node list
llm_prediction (Optional[List[str]]): list of llm prediction associate
with node, the order of list should be align with node list
llm_prediction_topk (int): Top K prediction from LLM used as
features for GNN training, default: 5
split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary,
for saving split index, it is required that if your dataset doesn't
have get_split_idx function
Expand All @@ -51,22 +61,40 @@ class TAGDataset(InMemoryDataset):
or not, default: False
force_reload (bool): default: False
.. note::
See `example/llm_plus_gnn/glem.py` for example usage
See `example/llm/glem.py` for example usage
"""
raw_text_id = {
'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
}

def __init__(self, root: str, dataset: InMemoryDataset,
tokenizer_name: str, text: Optional[List[str]] = None,
split_idx: Optional[Dict[str, Tensor]] = None,
tokenize_batch_size: int = 256, token_on_disk: bool = False,
text_on_disk: bool = False,
force_reload: bool = False) -> None:
llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'

llm_explanation_id = {
'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
}

def __init__(
self,
root: str,
dataset: InMemoryDataset,
tokenizer_name: str,
text: Optional[List[str]] = None,
llm_explanation: Optional[List[str]] = None,
llm_prediction: Optional[Tensor] = None,
llm_prediction_topk: int = 5,
split_idx: Optional[Dict[str, Tensor]] = None,
tokenize_batch_size: int = 256,
token_on_disk: bool = False,
text_on_disk: bool = False,
force_reload: bool = False,
) -> None:
# list the vars you want to pass in before run download & process
self.name = dataset.name
self.text = text
self.llm_explanation = llm_explanation
self.llm_prediction = llm_prediction
self.llm_prediction_topk = llm_prediction_topk
self.tokenizer_name = tokenizer_name
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
Expand All @@ -93,8 +121,11 @@ def __init__(self, root: str, dataset: InMemoryDataset,
"is_gold mask, please pass splited index "
"in format of dictionaty with 'train', 'valid' "
"'test' index tensor to 'split_idx'")
if text is not None and text_on_disk:
self.save_node_text(text)
if text_on_disk:
if text is not None:
self.save_node_text(text)
if llm_explanation is not None:
self.save_node_explanation(llm_explanation)
self.text_on_disk = text_on_disk
# init will call download and process
super().__init__(self.root, transform=None, pre_transform=None,
Expand All @@ -116,9 +147,21 @@ def __init__(self, root: str, dataset: InMemoryDataset,
if self.text is not None and len(self.text) != self._data.num_nodes:
raise ValueError("The number of text sequence in 'text' should be "
"equal to number of nodes!")
if self.llm_explanation is not None and len(
self.llm_explanation) != self._data.num_nodes:
raise ValueError("The number of LLM explanation should be "
"equal to number of nodes!")
if self.llm_prediction is not None and len(
self.llm_prediction) != self._data.num_nodes:
raise ValueError("The number of LLM prediction should be "
"equal to number of nodes!")
self.token_on_disk = token_on_disk
self.tokenize_batch_size = tokenize_batch_size
self._token = self.tokenize_graph(self.tokenize_batch_size)
self._llm_explanation_token = self.tokenize_graph(
self.tokenize_batch_size, text_type='llm_explanation')
self._all_token = self.tokenize_graph(self.tokenize_batch_size,
text_type='all')
self.__num_classes__ = dataset.num_classes

@property
Expand All @@ -128,7 +171,7 @@ def num_classes(self) -> int:
@property
def raw_file_names(self) -> List[str]:
    """Base names of every file found below ``<root>/raw``.

    The directory tree is walked recursively; only the file names are
    collected, not the sub-directory they were found in.
    """
    raw_folder = osp.join(self.root, 'raw')
    return [
        file_name
        for _, _, found_files in os.walk(raw_folder)
        for file_name in found_files
    ]
Expand All @@ -146,6 +189,19 @@ def token(self) -> Dict[str, Tensor]:
self._token = self.tokenize_graph()
return self._token

@property
def llm_explanation_token(self) -> Dict[str, Tensor]:
    """Tokens of the LLM explanations, tokenized on first access if unset."""
    cached = self._llm_explanation_token
    if cached is not None:
        return cached
    # Nothing stored yet: tokenize the explanations and remember the result.
    cached = self.tokenize_graph(text_type='llm_explanation')
    self._llm_explanation_token = cached
    return cached

@property
def all_token(self) -> Dict[str, Tensor]:
    """Tokens for text type ``'all'``, tokenized on first access if unset."""
    cached = self._all_token
    if cached is not None:
        return cached
    # Nothing stored yet: tokenize with text_type='all' and keep the result.
    cached = self.tokenize_graph(text_type='all')
    self._all_token = cached
    return cached

# load is_gold after init
@property
def is_gold(self) -> Tensor:
Expand Down Expand Up @@ -194,10 +250,17 @@ def download(self) -> None:
folder=f'{self.root}/raw',
filename='node-text.csv.gz',
log=True)
text_df = read_csv(raw_text_path)
self.text = list(text_df['text'])
self.text = list(read_csv(raw_text_path)['text'])
print('downloading llm explanations')
llm_explanation_path = download_google_url(
id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
filename='node-gpt-response.csv.gz', log=True)
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
print('downloading llm predictions')
fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)

def process(self) -> None:
# process Title and Abstraction
if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
self.text = list(text_df['text'])
Expand All @@ -212,6 +275,43 @@ def process(self) -> None:
"The raw text of each node is not specified"
"Please pass in 'text' when convert your dataset "
"to Text Attribute Graph Dataset")
# process LLM explanation and prediction
llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
if osp.exists(llm_explanation_path) and osp.exists(
llm_prediction_path):
# load LLM explanation
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
# load LLM prediction
preds = []
with open(llm_prediction_path) as file:
reader = csv.reader(file)
for row in reader:
inner_list = []
for value in row:
inner_list.append(int(value))
preds.append(inner_list)

pl = torch.zeros(len(preds), self.llm_prediction_topk,
dtype=torch.long)
for i, pred in enumerate(preds):
pl[i][:len(pred)] = torch.tensor(
pred[:self.llm_prediction_topk], dtype=torch.long) + 1
self.llm_prediction = pl
elif self.name in self.llm_explanation_id:
self.download()
else:
print(
'The dataset is not ogbn-arxiv,'
'please pass in your llm explanation list to `llm_explanation`'
'and llm prediction list to `llm_prediction`')
if self.llm_explanation is None or self.llm_prediction is None:
raise ValueError(
"The TAGDataset only have ogbn-arxiv LLM explanations"
"and predictions in default. The llm explanation and"
"prediction of each node is not specified."
"Please pass in 'llm_explanation' and 'llm_prediction' when"
"convert your dataset to Text Attribute Graph Dataset")

def save_node_text(self, text: List[str]) -> None:
node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
Expand All @@ -224,22 +324,50 @@ def save_node_text(self, text: List[str]) -> None:
text_df.to_csv(osp.join(node_text_path), compression='gzip',
index=False)

def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
def save_node_explanation(self, text: List[str]) -> None:
    """Write the LLM explanations to ``<root>/raw/node-gpt-response.csv.gz``.

    The list is stored as a gzip-compressed CSV with a single ``text``
    column. An already existing file is left untouched.

    Args:
        text (List[str]): LLM explanations to store, one entry per row.
    """
    target_path = osp.join(self.root, 'raw', 'node-gpt-response.csv.gz')
    if osp.exists(target_path):
        # Never overwrite a previously saved file.
        print(f'The llm explanation is existed at {target_path}')
        return

    print(f'Saving llm explanation file at {target_path}')
    os.makedirs(f'{self.root}/raw', exist_ok=True)
    explanation_frame = DataFrame(text, columns=['text'])
    explanation_frame.to_csv(target_path, compression='gzip', index=False)

def tokenize_graph(self, batch_size: int = 256,
text_type: str = 'raw_text') -> Dict[str, Tensor]:
r"""Tokenizing the text associate with each node, running in cpu.

Args:
batch_size (Optional[int]): batch size of list of text for
generating emebdding
text_type (Optional[str]): type of text
Returns:
Dict[str, torch.Tensor]: tokenized graph
"""
assert text_type in ['raw_text', 'llm_explanation', 'all']
if text_type == 'raw_text':
_text = self.text
elif text_type == 'llm_explanation':
_text = self.llm_explanation
elif text_type == 'all':
if self.text is None or self.llm_explanation is None:
raise ValueError("The TAGDataset need text and llm explanation"
"for tokenizing all text")
_text = [
f'{raw_txt} Explanation: {exp_txt}'
for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
]

data_len = 0
if self.text is not None:
data_len = len(self.text)
if _text is not None:
data_len = len(_text)
else:
raise ValueError("The TAGDataset need text for tokenization")
token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
path = os.path.join(self.processed_dir, 'token', text_type,
self.tokenizer_name)
# Check if the .pt files already exist
token_files_exist = any(
os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
Expand All @@ -256,12 +384,12 @@ def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
all_encoded_token = {k: [] for k in token_keys}
pbar = tqdm(total=data_len)

pbar.set_description('Tokenizing Text Attributed Graph')
pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
for i in range(0, data_len, batch_size):
end_index = min(data_len, i + batch_size)
token = self.tokenizer(self.text[i:min(i + batch_size, data_len)],
padding='max_length', truncation=True,
max_length=512, return_tensors="pt")
token = self.tokenizer(_text[i:end_index], padding='max_length',
truncation=True, max_length=512,
return_tensors="pt")
for k in token.keys():
all_encoded_token[k].append(token[k])
pbar.update(end_index - i)
Expand Down Expand Up @@ -289,10 +417,18 @@ class TextDataset(torch.utils.data.Dataset):

Args:
tag_dataset (TAGDataset): the parent dataset
text_type (str): type of text
"""
def __init__(self, tag_dataset: 'TAGDataset') -> None:
def __init__(self, tag_dataset: 'TAGDataset',
text_type: str = 'raw_text') -> None:
assert text_type in ['raw_text', 'llm_explanation', 'all']
self.tag_dataset = tag_dataset
self.token = tag_dataset.token
if text_type == 'raw_text':
self.token = tag_dataset.token
elif text_type == 'llm_explanation':
self.token = tag_dataset.llm_explanation_token
elif text_type == 'all':
self.token = tag_dataset.all_token
assert tag_dataset._data is not None
self._data = tag_dataset._data

Expand All @@ -312,7 +448,8 @@ def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:

# for LM training
def __getitem__(
self, node_id: IndexType
self,
node_id: IndexType,
) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
r"""This function will override the function in
torch.utils.data.Dataset, and will be called when you
Expand Down Expand Up @@ -343,8 +480,8 @@ def get(self, idx: int) -> BaseData:
def __repr__(self) -> str:
    """Return the class name followed by an empty pair of parentheses."""
    class_name = self.__class__.__name__
    return f'{class_name}()'

def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
    r"""Factory that builds a text dataset from this Text Attributed Graph
    Dataset; each data point is a node's associated text token.

    Args:
        text_type (str): type of text handed on to
            :class:`TAGDataset.TextDataset`, default: :obj:`'raw_text'`
    """
    return TAGDataset.TextDataset(tag_dataset=self, text_type=text_type)
Loading